diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index cc37aa2e6c0f668bb717ab357678f15a3300830e..62fbf49bf33aab316e28719d3486f307aebcfd86 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -89,12 +89,11 @@ export class CondenseQuestionChatEngine implements ChatEngine {
 
   private async condenseQuestion(chatHistory: ChatMessage[], question: string) {
     const chatHistoryStr = messagesToHistoryStr(chatHistory);
 
-    return this.serviceContext.llmPredictor.predict(
-      defaultCondenseQuestionPrompt,
-      {
+    return this.serviceContext.llm.complete(
+      defaultCondenseQuestionPrompt({
         question: question,
         chat_history: chatHistoryStr,
-      }
+      })
     );
   }
@@ -104,7 +103,9 @@ export class CondenseQuestionChatEngine implements ChatEngine {
   ): Promise<Response> {
     chatHistory = chatHistory ?? this.chatHistory;
 
-    const condensedQuestion = await this.condenseQuestion(chatHistory, message);
+    const condensedQuestion = (
+      await this.condenseQuestion(chatHistory, message)
+    ).message.content;
 
     const response = await this.queryEngine.query(condensedQuestion);
diff --git a/packages/core/src/PromptHelper.ts b/packages/core/src/PromptHelper.ts
index 3ead2e09094d9bdb92288895c8504b4c70a282dc..2d9ae8b3a458b887986e30a6bb7036032f5e5633 100644
--- a/packages/core/src/PromptHelper.ts
+++ b/packages/core/src/PromptHelper.ts
@@ -1,4 +1,3 @@
-import { chunk } from "lodash";
 import { globalsHelper } from "./GlobalsHelper";
 import { SimplePrompt } from "./Prompt";
 import { SentenceSplitter } from "./TextSplitter";
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index dd669c2236edc969fffc2336872f4036689a3e9b..24d46e954938be822d9ca8bac2dc984d6508d2bc 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -1,4 +1,3 @@
-import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./llm/LLMPredictor";
 import {
   BaseOutputParser,
   StructuredOutput,
@@ -10,6 +9,7 @@ import {
   defaultSubQuestionPrompt,
 } from "./Prompt";
 import { ToolMetadata } from "./Tool";
+import { LLM, OpenAI } from "./llm/LLM";
 
 export interface SubQuestion {
   subQuestion: string;
@@ -27,12 +27,12 @@ export interface BaseQuestionGenerator {
  * LLMQuestionGenerator uses the LLM to generate new questions for the LLM using tools and a user query.
  */
 export class LLMQuestionGenerator implements BaseQuestionGenerator {
-  llmPredictor: BaseLLMPredictor;
+  llm: LLM;
   prompt: SimplePrompt;
   outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;
 
   constructor(init?: Partial<LLMQuestionGenerator>) {
-    this.llmPredictor = init?.llmPredictor ?? new ChatGPTLLMPredictor();
+    this.llm = init?.llm ?? new OpenAI();
     this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
     this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
   }
@@ -40,10 +40,14 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
   async generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]> {
     const toolsStr = buildToolsText(tools);
     const queryStr = query;
-    const prediction = await this.llmPredictor.predict(this.prompt, {
-      toolsStr,
-      queryStr,
-    });
+    const prediction = (
+      await this.llm.complete(
+        this.prompt({
+          toolsStr,
+          queryStr,
+        })
+      )
+    ).message.content;
 
     const structuredOutput = this.outputParser.parse(prediction);
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index eeecbbd5958d60755ad28dd33c35a439d91a3d19..8aa60d9bd35e0ad3061b555b44c8acd08a907e69 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -1,4 +1,3 @@
-import { ChatGPTLLMPredictor, BaseLLMPredictor } from "./llm/LLMPredictor";
 import { MetadataMode, NodeWithScore } from "./Node";
 import {
   SimplePrompt,
@@ -9,6 +8,7 @@ import { getBiggestPrompt } from "./PromptHelper";
 import { Response } from "./Response";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { Event } from "./callbacks/CallbackManager";
+import { LLM } from "./llm/LLM";
 /**
  * Response modes of the response synthesizer
  */
@@ -43,11 +43,11 @@ interface BaseResponseBuilder {
  * A response builder that just concatenates responses.
  */
 export class SimpleResponseBuilder implements BaseResponseBuilder {
-  llmPredictor: BaseLLMPredictor;
+  llm: LLM;
   textQATemplate: SimplePrompt;
 
   constructor(serviceContext: ServiceContext) {
-    this.llmPredictor = serviceContext.llmPredictor;
+    this.llm = serviceContext.llm;
     this.textQATemplate = defaultTextQaPrompt;
   }
@@ -62,7 +62,8 @@ export class SimpleResponseBuilder implements BaseResponseBuilder {
     };
 
     const prompt = this.textQATemplate(input);
-    return this.llmPredictor.predict(prompt, {}, parentEvent);
+    const response = await this.llm.complete(prompt, parentEvent);
+    return response.message.content;
   }
 }
@@ -124,13 +125,14 @@ export class Refine implements BaseResponseBuilder {
 
     for (const chunk of textChunks) {
       if (!response) {
-        response = await this.serviceContext.llmPredictor.predict(
-          textQATemplate,
-          {
-            context: chunk,
-          },
-          parentEvent
-        );
+        response = (
+          await this.serviceContext.llm.complete(
+            textQATemplate({
+              context: chunk,
+            }),
+            parentEvent
+          )
+        ).message.content;
       } else {
         response = await this.refineResponseSingle(
           response,
@@ -158,14 +160,15 @@ export class Refine implements BaseResponseBuilder {
     ]);
 
     for (const chunk of textChunks) {
-      response = await this.serviceContext.llmPredictor.predict(
-        refineTemplate,
-        {
-          context: chunk,
-          existingAnswer: response,
-        },
-        parentEvent
-      );
+      response = (
+        await this.serviceContext.llm.complete(
+          refineTemplate({
+            context: chunk,
+            existingAnswer: response,
+          }),
+          parentEvent
+        )
+      ).message.content;
     }
     return response;
   }
@@ -228,27 +231,30 @@ export class TreeSummarize implements BaseResponseBuilder {
     );
 
     if (packedTextChunks.length === 1) {
-      return this.serviceContext.llmPredictor.predict(
-        summaryTemplate,
-        {
-          context: packedTextChunks[0],
-        },
-        parentEvent
-      );
+      return (
+        await this.serviceContext.llm.complete(
+          summaryTemplate({
+            context: packedTextChunks[0],
+          }),
+          parentEvent
+        )
+      ).message.content;
     } else {
       const summaries = await Promise.all(
         packedTextChunks.map((chunk) =>
-          this.serviceContext.llmPredictor.predict(
-            summaryTemplate,
-            {
+          this.serviceContext.llm.complete(
+            summaryTemplate({
               context: chunk,
-            },
+            }),
             parentEvent
           )
         )
       );
 
-      return this.getResponse(query, summaries);
+      return this.getResponse(
+        query,
+        summaries.map((s) => s.message.content)
+      );
     }
   }
 }
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 51423c917bfec3db9774c1b2670599c98299a8d6..92a567b9bf9ad2112c7ebaa1d8674e29a72e9a38 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -1,6 +1,5 @@
 import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
-import { OpenAI } from "./llm/LLM";
-import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./llm/LLMPredictor";
+import { LLM, OpenAI } from "./llm/LLM";
 import { NodeParser, SimpleNodeParser } from "./NodeParser";
 import { PromptHelper } from "./PromptHelper";
 import { CallbackManager } from "./callbacks/CallbackManager";
@@ -9,7 +8,7 @@ import { CallbackManager } from "./callbacks/CallbackManager";
  * The ServiceContext is a collection of components that are used in different parts of the application.
  */
 export interface ServiceContext {
-  llmPredictor: BaseLLMPredictor;
+  llm: LLM;
   promptHelper: PromptHelper;
   embedModel: BaseEmbedding;
   nodeParser: NodeParser;
@@ -18,7 +17,6 @@ export interface ServiceContext {
 }
 
 export interface ServiceContextOptions {
-  llmPredictor?: BaseLLMPredictor;
   llm?: OpenAI;
   promptHelper?: PromptHelper;
   embedModel?: BaseEmbedding;
@@ -32,9 +30,7 @@
 export function serviceContextFromDefaults(options?: ServiceContextOptions) {
   const callbackManager = options?.callbackManager ?? new CallbackManager();
   const serviceContext: ServiceContext = {
-    llmPredictor:
-      options?.llmPredictor ??
-      new ChatGPTLLMPredictor({ callbackManager, languageModel: options?.llm }),
+    llm: options?.llm ?? new OpenAI(),
     embedModel: options?.embedModel ?? new OpenAIEmbedding(),
     nodeParser:
       options?.nodeParser ??
@@ -54,8 +50,8 @@ export function serviceContextFromServiceContext(
   options: ServiceContextOptions
 ) {
   const newServiceContext = { ...serviceContext };
-  if (options.llmPredictor) {
-    newServiceContext.llmPredictor = options.llmPredictor;
+  if (options.llm) {
+    newServiceContext.llm = options.llm;
   }
   if (options.promptHelper) {
     newServiceContext.promptHelper = options.promptHelper;
   }
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 2218e2d5d221e6aedd3e52a9e82625eebb591e2f..1a5aada1961ff418c6ce7489f3ca0a7dbaf62c2b 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -3,7 +3,6 @@ export * from "./constants";
 export * from "./Embedding";
 export * from "./GlobalsHelper";
 export * from "./llm/LLM";
-export * from "./llm/LLMPredictor";
 export * from "./Node";
 export * from "./NodeParser";
 // export * from "./openai"; Don't export OpenAIWrapper
diff --git a/packages/core/src/indices/list/ListIndexRetriever.ts b/packages/core/src/indices/list/ListIndexRetriever.ts
index d0761a269ae6fbee969d40658a23e6fe253c4070..d359e7b3ecf0a33e15a1c0d813a7debda56d7680 100644
--- a/packages/core/src/indices/list/ListIndexRetriever.ts
+++ b/packages/core/src/indices/list/ListIndexRetriever.ts
@@ -88,10 +88,9 @@ export class ListIndexLLMRetriever implements BaseRetriever {
       const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
       const input = { context: fmtBatchStr, query: query };
 
-      const rawResponse = await this.serviceContext.llmPredictor.predict(
-        this.choiceSelectPrompt,
-        input
-      );
+      const rawResponse = (
+        await this.serviceContext.llm.complete(this.choiceSelectPrompt(input))
+      ).message.content;
 
       // parseResult is a map from doc number to relevance score
       const parseResult = this.parseChoiceSelectAnswerFn(
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 5a721093bc4093616c5d898b0840c6fea01d9f36..839deabc61dc200ef29466d70a517bdc1a654a2f 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -32,13 +32,13 @@ export interface LLM {
   * Get a chat response from the LLM
   * @param messages
   */
-  chat(messages: ChatMessage[]): Promise<ChatResponse>;
+  chat(messages: ChatMessage[], parentEvent?: Event): Promise<ChatResponse>;
 
  /**
   * Get a prompt completion from the LLM
   * @param prompt the prompt to complete
   */
-  complete(prompt: string): Promise<CompletionResponse>;
+  complete(prompt: string, parentEvent?: Event): Promise<CompletionResponse>;
 }
 
 export const GPT4_MODELS = {
@@ -151,17 +151,17 @@ export class OpenAI implements LLM {
 
 export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   "Llama-2-70b-chat": {
-    contextWindow: 4000,
+    contextWindow: 4096,
     replicateApi:
       "replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48",
   },
   "Llama-2-13b-chat": {
-    contextWindow: 4000,
+    contextWindow: 4096,
     replicateApi:
       "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
   },
   "Llama-2-7b-chat": {
-    contextWindow: 4000,
+    contextWindow: 4096,
     replicateApi:
       "a16z-infra/llama7b-v2-chat:4f0a4744c7295c024a1de15e1a63c880d3da035fa1f49bfd344fe076074c8eea",
   },
@@ -174,13 +174,13 @@ export class LlamaDeuce implements LLM {
   model: keyof typeof ALL_AVAILABLE_LLAMADEUCE_MODELS;
   temperature: number;
   maxTokens?: number;
-  session: ReplicateSession;
+  replicateSession: ReplicateSession;
 
   constructor(init?: Partial<LlamaDeuce>) {
     this.model = init?.model ?? "Llama-2-70b-chat";
     this.temperature = init?.temperature ?? 0;
     this.maxTokens = init?.maxTokens ?? undefined;
-    this.session = init?.session ?? new ReplicateSession();
+    this.replicateSession = init?.replicateSession ?? new ReplicateSession();
   }
 
   mapMessageType(messageType: MessageType): string {
@@ -202,7 +202,7 @@
   ): Promise<ChatResponse> {
     const api = ALL_AVAILABLE_LLAMADEUCE_MODELS[this.model]
       .replicateApi as `${string}/${string}:${string}`;
 
-    const response = await this.session.replicate.run(api, {
+    const response = await this.replicateSession.replicate.run(api, {
       input: {
         prompt: messages.reduce((acc, message) => {
@@ -215,7 +215,7 @@
     });
     return {
       message: {
-        content: (response as Array<string>).join(""),
+        content: (response as Array<string>).join(""), // We need to do this because replicate returns a list of strings (for streaming functionality)
         role: "assistant",
       },
     };
@@ -225,6 +225,6 @@
     prompt: string,
     parentEvent?: Event
   ): Promise<CompletionResponse> {
-    return this.chat([{ content: prompt, role: "system" }], parentEvent); // Using system prompt here to avoid giving it a previx
+    return this.chat([{ content: prompt, role: "user" }], parentEvent);
   }
 }
diff --git a/packages/core/src/llm/LLMPredictor.ts b/packages/core/src/llm/LLMPredictor.ts
deleted file mode 100644
index 8dcdeb83ed6228b4500ff9acecbe3e1d33d81a22..0000000000000000000000000000000000000000
--- a/packages/core/src/llm/LLMPredictor.ts
+++ /dev/null
@@ -1,65 +0,0 @@
-import { ALL_AVAILABLE_OPENAI_MODELS, OpenAI } from "./LLM";
-import { SimplePrompt } from "../Prompt";
-import { CallbackManager, Event } from "../callbacks/CallbackManager";
-
-/**
- * LLM Predictors are an abstraction to predict the response to a prompt.
- */
-export interface BaseLLMPredictor {
-  getLlmMetadata(): Promise<any>;
-  predict(
-    prompt: string | SimplePrompt,
-    input?: Record<string, string>,
-    parentEvent?: Event
-  ): Promise<string>;
-}
-
-/**
- * ChatGPTLLMPredictor is a predictor that uses GPT.
- */
-export class ChatGPTLLMPredictor implements BaseLLMPredictor {
-  model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS;
-  retryOnThrottling: boolean;
-  languageModel: OpenAI;
-  callbackManager?: CallbackManager;
-
-  constructor(props?: Partial<ChatGPTLLMPredictor>) {
-    const {
-      model = "gpt-3.5-turbo",
-      retryOnThrottling = true,
-      callbackManager,
-      languageModel,
-    } = props || {};
-    this.model = model;
-    this.callbackManager = callbackManager;
-    this.retryOnThrottling = retryOnThrottling;
-
-    this.languageModel =
-      languageModel ??
-      new OpenAI({
-        model: this.model,
-        callbackManager: this.callbackManager,
-      });
-  }
-
-  async getLlmMetadata() {
-    throw new Error("Not implemented yet");
-  }
-
-  async predict(
-    prompt: string | SimplePrompt,
-    input?: Record<string, string>,
-    parentEvent?: Event,
-    logProgress: boolean = false
-  ): Promise<string> {
-    if (typeof prompt === "string") {
-      if (logProgress) {
-        console.log("PROMPT", prompt);
-      }
-      const result = await this.languageModel.complete(prompt, parentEvent);
-      return result.message.content;
-    } else {
-      return this.predict(prompt(input ?? {}), undefined, parentEvent);
-    }
-  }
-}
diff --git a/packages/core/src/readers/PDFReader.ts b/packages/core/src/readers/PDFReader.ts
index 4c0d977b9ad0a47c879fd62d99f30f4078615830..a42d65ca82915c6a7ef4454cb1182082be5c69d2 100644
--- a/packages/core/src/readers/PDFReader.ts
+++ b/packages/core/src/readers/PDFReader.ts
@@ -8,7 +8,7 @@ import _ from "lodash";
 /**
  * Read the text of a PDF
  */
-export default class PDFReader implements BaseReader {
+export class PDFReader implements BaseReader {
   async loadData(
     file: string,
     fs: GenericFileSystem = DEFAULT_FS
diff --git a/packages/core/src/readers/SimpleDirectoryReader.ts b/packages/core/src/readers/SimpleDirectoryReader.ts
index 99d1c4f9974be1b2b9334b4fb344e1319640551f..70b9ea686920521eb97ad0515e0ddcd44f85faff 100644
--- a/packages/core/src/readers/SimpleDirectoryReader.ts
+++ b/packages/core/src/readers/SimpleDirectoryReader.ts
@@ -3,7 +3,7 @@ import { Document } from "../Node";
 import { BaseReader } from "./base";
 import { CompleteFileSystem, walk } from "../storage/FileSystem";
 import { DEFAULT_FS } from "../storage/constants";
-import PDFReader from "./PDFReader";
+import { PDFReader } from "./PDFReader";
 
 /**
  * Read a .txt file
@@ -33,7 +33,7 @@ export type SimpleDirectoryReaderLoadDataProps = {
 /**
  * Read all of the documents in a directory. Currently supports PDF and TXT files.
  */
-export default class SimpleDirectoryReader implements BaseReader {
+export class SimpleDirectoryReader implements BaseReader {
   async loadData({
     directoryPath,
     fs = DEFAULT_FS as CompleteFileSystem,
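
Usage note (not part of the patch): with LLMPredictor removed, callers apply the SimplePrompt function themselves and read message.content from the complete() result, instead of getting a plain string back from llmPredictor.predict(). A minimal sketch of the migrated call pattern follows; the "llamaindex" import path, the inline prompt function, and the query text are assumptions for illustration only (index.ts above suggests OpenAI and the ServiceContext helpers are re-exported from the package entry point):

    import { OpenAI, serviceContextFromDefaults } from "llamaindex";

    // ServiceContext now takes an LLM directly instead of an llmPredictor.
    const serviceContext = serviceContextFromDefaults({ llm: new OpenAI() });

    // A SimplePrompt is a function from an input record to a prompt string.
    const qaPrompt = ({ query }: { query: string }) => `Answer concisely: ${query}`;

    async function main() {
      // complete() resolves to a response object; callers unwrap .message.content themselves.
      const response = await serviceContext.llm.complete(
        qaPrompt({ query: "What does this change do?" })
      );
      console.log(response.message.content);
    }

    main().catch(console.error);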