diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index 9df10ce6bc66f8f1cced27c504fae21ae63de5e9..ac54bc1aa5a3af8c6568cd6dd3dbc299d464f81e 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -1,8 +1,12 @@
 import { ChatOpenAI } from "./LanguageModel";
+import { SimplePrompt } from "./Prompt";
 
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
-  apredict(prompt: string, options: any): Promise<string>;
+  apredict(
+    prompt: string | SimplePrompt,
+    input?: { [key: string]: string }
+  ): Promise<string>;
   // stream(prompt: string, options: any): Promise<any>;
 }
 
@@ -25,13 +29,21 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
     throw new Error("Not implemented yet");
   }
 
-  async apredict(prompt: string, options: any) {
-    return this.languageModel.agenerate([
-      {
-        content: prompt,
-        type: "human",
-      },
-    ]);
+  async apredict(
+    prompt: string | SimplePrompt,
+    input?: { [key: string]: string }
+  ): Promise<string> {
+    if (typeof prompt === "string") {
+      const result = await this.languageModel.agenerate([
+        {
+          content: prompt,
+          type: "human",
+        },
+      ]);
+      return result.generations[0][0].text;
+    } else {
+      return this.apredict(prompt(input ?? {}));
+    }
   }
 
   // async stream(prompt: string, options: any) {
diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LanguageModel.ts
index cf0d1307f8d6ef8e9b3d591bf384b5ced3ebfa31..8862e2fd905ad0f56e0ab45d8df95377b6c62dfe 100644
--- a/packages/core/src/LanguageModel.ts
+++ b/packages/core/src/LanguageModel.ts
@@ -6,8 +6,6 @@ import {
   getOpenAISession,
 } from "./openai";
 
-interface LLMResult {}
-
 export interface BaseLanguageModel {}
 
 type MessageType = "human" | "ai" | "system" | "generic" | "function";
@@ -22,7 +20,7 @@ interface Generation {
   generationInfo?: { [key: string]: any };
 }
 
-interface LLMResult {
+export interface LLMResult {
   generations: Generation[][]; // Each input can have more than one generations
 }
 
@@ -62,7 +60,7 @@ export class ChatOpenAI extends BaseChatModel {
     }
   }
 
-  async agenerate(messages: BaseMessage[]) {
+  async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
     const { data } = await this.session.openai.createChatCompletion({
       model: this.model,
       temperature: this.temperature,
@@ -75,6 +73,6 @@ export class ChatOpenAI extends BaseChatModel {
     });
 
     const content = data.choices[0].message?.content ?? "";
-    return content;
+    return { generations: [[{ text: content }]] };
   }
 }
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index 2d90617b3affb93a4826be6b6372abc3ec33d511..8d1db56cf872c51d7d434e926f78b11ecfcc31c5 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -1,6 +1,7 @@
 /**
  * A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
  * NOTE this is a different interface compared to LlamaIndex Python
+ * NOTE 2: we default to empty string to make it easy to calculate prompt sizes
  */
 export type SimplePrompt = (input: { [key: string]: string }) => string;
 
@@ -16,7 +17,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
 */
 
 export const defaultTextQaPrompt: SimplePrompt = (input) => {
-  const { context, query } = input;
+  const { context = "", query = "" } = input;
   return `Context information is below.
 ---------------------
 ${context}
@@ -41,7 +42,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = (
 */
 
 export const defaultSummaryPrompt: SimplePrompt = (input) => {
-  const { context } = input;
+  const { context = "" } = input;
   return `Write a summary of the following.
 Try to use only the information provided.
 Try to include as many key details as possible.
@@ -69,7 +70,7 @@ DEFAULT_REFINE_PROMPT_TMPL = (
 */
 
 export const defaultRefinePrompt: SimplePrompt = (input) => {
-  const { query, existingAnswer, context } = input;
+  const { query = "", existingAnswer = "", context = "" } = input;
 
   return `The original question is as follows: ${query}
 We have provided an existing answer: ${existingAnswer}
diff --git a/packages/core/src/PromptHelper.ts b/packages/core/src/PromptHelper.ts
new file mode 100644
index 0000000000000000000000000000000000000000..eb797f0c4942de06bb3130ddfcc29e70e73ed347
--- /dev/null
+++ b/packages/core/src/PromptHelper.ts
@@ -0,0 +1,14 @@
+import {
+  DEFAULT_CONTEXT_WINDOW,
+  DEFAULT_NUM_OUTPUTS,
+  DEFAULT_CHUNK_OVERLAP_RATIO,
+} from "./constants";
+
+class PromptHelper {
+  contextWindow = DEFAULT_CONTEXT_WINDOW;
+  numOutput = DEFAULT_NUM_OUTPUTS;
+  chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO;
+  chunkSizeLimit?: number;
+  tokenizer?: (text: string) => string[];
+  separator = " ";
+}
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index b7e23f800d2dc1d95afdc1adb045d2de336b78cf..d0c1e2855cda56f10d9e1bd2b7fd97526f032541 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -2,12 +2,13 @@
 import { ChatGPTLLMPredictor } from "./LLMPredictor";
 import { NodeWithScore } from "./Node";
 import { SimplePrompt, defaultTextQaPrompt } from "./Prompt";
 import { Response } from "./Response";
+import { ServiceContext } from "./ServiceContext";
 
 interface BaseResponseBuilder {
   agetResponse(query: string, textChunks: string[]): Promise<string>;
 }
 
-export class SimpleResponseBuilder {
+export class SimpleResponseBuilder implements BaseResponseBuilder {
   llmPredictor: ChatGPTLLMPredictor;
   textQATemplate: SimplePrompt;
@@ -27,6 +28,81 @@ export class SimpleResponseBuilder {
   }
 }
 
+export class Refine implements BaseResponseBuilder {
+  async agetResponse(
+    query: string,
+    textChunks: string[],
+    prevResponse?: any
+  ): Promise<string> {
+    throw new Error("Not implemented yet");
+  }
+
+  private giveResponseSingle(queryStr: string, textChunk: string) {
+    const textQATemplate = defaultTextQaPrompt;
+  }
+
+  private refineResponseSingle(
+    response: string,
+    queryStr: string,
+    textChunk: string
+  ) {
+    throw new Error("Not implemented yet");
+  }
+}
+export class CompactAndRefine extends Refine {
+  async agetResponse(
+    query: string,
+    textChunks: string[],
+    prevResponse?: any
+  ): Promise<string> {
+    throw new Error("Not implemented yet");
+  }
+}
+
+export class TreeSummarize implements BaseResponseBuilder {
+  serviceContext: ServiceContext;
+
+  constructor(serviceContext: ServiceContext) {
+    this.serviceContext = serviceContext;
+  }
+
+  async agetResponse(query: string, textChunks: string[]): Promise<string> {
+    const summaryTemplate: SimplePrompt = (input) =>
+      defaultTextQaPrompt({ ...input, query: query });
+
+    if (!textChunks || textChunks.length === 0) {
+      throw new Error("Must have at least one text chunk");
+    }
+
+    // TODO repack more intelligently
+    // Combine text chunks in pairs into packedTextChunks
+    let packedTextChunks: string[] = [];
+    for (let i = 0; i < textChunks.length; i += 2) {
+      if (i + 1 < textChunks.length) {
+        packedTextChunks.push(textChunks[i] + "\n\n" + textChunks[i + 1]);
+      } else {
+        packedTextChunks.push(textChunks[i]);
+      }
+    }
+
+    if (packedTextChunks.length === 1) {
+      return this.serviceContext.llmPredictor.apredict(summaryTemplate, {
+        context: packedTextChunks[0],
+      });
+    } else {
+      const summaries = await Promise.all(
+        packedTextChunks.map((chunk) =>
+          this.serviceContext.llmPredictor.apredict(summaryTemplate, {
+            context: chunk,
+          })
+        )
+      );
+
+      return this.agetResponse(query, summaries);
+    }
+  }
+}
+
 export function getResponseBuilder(): BaseResponseBuilder {
   return new SimpleResponseBuilder();
 }
diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts
index def343a7945c57e0e96ea87fd729e1e0c2f4903c..2c3a3c92dd88e9100e62f27d3f80fc385cf11cd7 100644
--- a/packages/core/src/constants.ts
+++ b/packages/core/src/constants.ts
@@ -3,6 +3,7 @@ export const DEFAULT_NUM_OUTPUTS = 256;
 
 export const DEFAULT_CHUNK_SIZE = 1024;
 export const DEFAULT_CHUNK_OVERLAP = 20;
+export const DEFAULT_CHUNK_OVERLAP_RATIO = 0.1;
 export const DEFAULT_SIMILARITY_TOP_K = 2;
 
 // NOTE: for text-embedding-ada-002
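Usage sketch (hypothetical, not part of the diff): the change lets apredict take either a plain string or a SimplePrompt function plus an input map, and adds a TreeSummarize builder that packs chunks pairwise and summarizes recursively through ServiceContext.llmPredictor. The TypeScript below is a minimal illustration of those call shapes; it assumes ChatGPTLLMPredictor can be constructed with defaults and uses a partial object in place of a real ServiceContext, since only its llmPredictor field is exercised by TreeSummarize.

import { ChatGPTLLMPredictor } from "./LLMPredictor";
import { defaultSummaryPrompt } from "./Prompt";
import { TreeSummarize } from "./ResponseSynthesizer";
import { ServiceContext } from "./ServiceContext";

async function demo() {
  // Assumption: a ChatGPTLLMPredictor with default settings.
  const llmPredictor = new ChatGPTLLMPredictor();

  // String prompt: sent as a single "human" message; the text of the first
  // generation is returned.
  const greeting = await llmPredictor.apredict("Say hello in five words.");

  // SimplePrompt + input map: the predictor renders prompt(input) and then
  // recurses into the string branch. Missing keys default to "".
  const summary = await llmPredictor.apredict(defaultSummaryPrompt, {
    context: "LlamaIndex.TS lets applications index and query documents with LLMs.",
  });

  // TreeSummarize combines chunks in pairs, summarizes each pack in parallel,
  // and recurses until a single summary remains. Only serviceContext.llmPredictor
  // is used here; the cast below is a stand-in for a real ServiceContext.
  const serviceContext = { llmPredictor } as unknown as ServiceContext;
  const builder = new TreeSummarize(serviceContext);
  const answer = await builder.agetResponse("What is LlamaIndex.TS?", [
    "Chunk 1: LlamaIndex.TS is a data framework for LLM applications.",
    "Chunk 2: Response builders include SimpleResponseBuilder and TreeSummarize.",
    "Chunk 3: Prompts are plain functions from input maps to strings.",
  ]);

  console.log({ greeting, summary, answer });
}

demo().catch(console.error);

If ServiceContext gains a proper constructor or factory, that should replace the cast above; the cast only exists to keep the sketch self-contained.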