diff --git a/.changeset/cyan-brooms-share.md b/.changeset/cyan-brooms-share.md new file mode 100644 index 0000000000000000000000000000000000000000..4d28cbde84cc6908e528c177ff0aca705ab30c18 --- /dev/null +++ b/.changeset/cyan-brooms-share.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Remove HistoryChatEngine and use ChatHistory for all chat engines diff --git a/.eslintrc.js b/.eslintrc.js index 077771bc0c716fcd35971848830e6245205e6213..5362d4f038fe609e95a795277b5562558fb1fe8e 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -10,4 +10,5 @@ module.exports = { rules: { "max-params": ["error", 4], }, + ignorePatterns: ["dist/"], }; diff --git a/.gitignore b/.gitignore index 782d93b772de24426bdc2e27511fe628f35f93cc..3e17caf04ec627162858e41a93d680b590db2e77 100644 --- a/.gitignore +++ b/.gitignore @@ -39,9 +39,6 @@ yarn-error.log* dist/ lib/ -# vs code -.vscode/launch.json - .cache test-results/ playwright-report/ diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..91fc7023c6d240c434ea1159ed8c611fd5c985da --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "node", + "request": "launch", + "name": "Debug Example", + "skipFiles": ["<node_internals>/**"], + "runtimeExecutable": "pnpm", + "cwd": "${workspaceFolder}/examples", + "runtimeArgs": ["ts-node", "${fileBasename}"] + } + ] +} diff --git a/packages/core/jest.config.cjs b/packages/core/jest.config.cjs index 3abcbd94670c03671629f5b54f5fcae1b960f872..b359e058b5274f085faafdd9d5f7cb91c123ebdd 100644 --- a/packages/core/jest.config.cjs +++ b/packages/core/jest.config.cjs @@ -2,4 +2,5 @@ module.exports = { preset: "ts-jest", testEnvironment: "node", + testPathIgnorePatterns: ["/lib/"], }; diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts deleted file mode 100644 index 2f18a15a49511a419c9f7d0260eeb49bf5a494af..0000000000000000000000000000000000000000 --- a/packages/core/src/ChatEngine.ts +++ /dev/null @@ -1,420 +0,0 @@ -import { randomUUID } from "node:crypto"; -import { ChatHistory, SimpleChatHistory } from "./ChatHistory"; -import { NodeWithScore, TextNode } from "./Node"; -import { - CondenseQuestionPrompt, - ContextSystemPrompt, - defaultCondenseQuestionPrompt, - defaultContextSystemPrompt, - messagesToHistoryStr, -} from "./Prompt"; -import { BaseQueryEngine } from "./QueryEngine"; -import { Response } from "./Response"; -import { BaseRetriever } from "./Retriever"; -import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; -import { Event } from "./callbacks/CallbackManager"; -import { ChatMessage, ChatResponseChunk, LLM, OpenAI } from "./llm"; -import { streamConverter, streamReducer } from "./llm/utils"; -import { BaseNodePostprocessor } from "./postprocessors"; - -/** - * Represents the base parameters for ChatEngine. - */ -export interface ChatEngineParamsBase { - message: MessageContent; - /** - * Optional chat history if you want to customize the chat history. 
- */ - chatHistory?: ChatMessage[]; - history?: ChatHistory; -} - -export interface ChatEngineParamsStreaming extends ChatEngineParamsBase { - stream: true; -} - -export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase { - stream?: false | null; -} - -/** - * A ChatEngine is used to handle back and forth chats between the application and the LLM. - */ -export interface ChatEngine { - /** - * Send message along with the class's current chat history to the LLM. - * @param params - */ - chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; - chat(params: ChatEngineParamsNonStreaming): Promise<Response>; - - /** - * Resets the chat history so that it's empty. - */ - reset(): void; -} - -/** - * SimpleChatEngine is the simplest possible chat engine. Useful for using your own custom prompts. - */ -export class SimpleChatEngine implements ChatEngine { - chatHistory: ChatMessage[]; - llm: LLM; - - constructor(init?: Partial<SimpleChatEngine>) { - this.chatHistory = init?.chatHistory ?? []; - this.llm = init?.llm ?? new OpenAI(); - } - - chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; - chat(params: ChatEngineParamsNonStreaming): Promise<Response>; - async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise<Response | AsyncIterable<Response>> { - const { message, stream } = params; - - const chatHistory = params.chatHistory ?? this.chatHistory; - chatHistory.push({ content: message, role: "user" }); - - if (stream) { - const stream = await this.llm.chat({ - messages: chatHistory, - stream: true, - }); - return streamConverter( - streamReducer({ - stream, - initialValue: "", - reducer: (accumulator, part) => (accumulator += part.delta), - finished: (accumulator) => { - chatHistory.push({ content: accumulator, role: "assistant" }); - }, - }), - (r: ChatResponseChunk) => new Response(r.delta), - ); - } - - const response = await this.llm.chat({ messages: chatHistory }); - chatHistory.push(response.message); - return new Response(response.message.content); - } - - reset() { - this.chatHistory = []; - } -} - -/** - * CondenseQuestionChatEngine is used in conjunction with a Index (for example VectorStoreIndex). - * It does two steps on taking a user's chat message: first, it condenses the chat message - * with the previous chat history into a question with more context. - * Then, it queries the underlying Index using the new question with context and returns - * the response. - * CondenseQuestionChatEngine performs well when the input is primarily questions about the - * underlying data. It performs less well when the chat messages are not questions about the - * data, or are very referential to previous context. - */ -export class CondenseQuestionChatEngine implements ChatEngine { - queryEngine: BaseQueryEngine; - chatHistory: ChatMessage[]; - llm: LLM; - condenseMessagePrompt: CondenseQuestionPrompt; - - constructor(init: { - queryEngine: BaseQueryEngine; - chatHistory: ChatMessage[]; - serviceContext?: ServiceContext; - condenseMessagePrompt?: CondenseQuestionPrompt; - }) { - this.queryEngine = init.queryEngine; - this.chatHistory = init?.chatHistory ?? []; - this.llm = init?.serviceContext?.llm ?? serviceContextFromDefaults().llm; - this.condenseMessagePrompt = - init?.condenseMessagePrompt ?? 
defaultCondenseQuestionPrompt; - } - - private async condenseQuestion(chatHistory: ChatMessage[], question: string) { - const chatHistoryStr = messagesToHistoryStr(chatHistory); - - return this.llm.complete({ - prompt: defaultCondenseQuestionPrompt({ - question: question, - chatHistory: chatHistoryStr, - }), - }); - } - - chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; - chat(params: ChatEngineParamsNonStreaming): Promise<Response>; - async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise<Response | AsyncIterable<Response>> { - const { message, stream } = params; - const chatHistory = params.chatHistory ?? this.chatHistory; - - const condensedQuestion = ( - await this.condenseQuestion(chatHistory, extractText(message)) - ).text; - chatHistory.push({ content: message, role: "user" }); - - if (stream) { - const stream = await this.queryEngine.query({ - query: condensedQuestion, - stream: true, - }); - return streamReducer({ - stream, - initialValue: "", - reducer: (accumulator, part) => (accumulator += part.response), - finished: (accumulator) => { - chatHistory.push({ content: accumulator, role: "assistant" }); - }, - }); - } - const response = await this.queryEngine.query({ - query: condensedQuestion, - }); - chatHistory.push({ content: response.response, role: "assistant" }); - - return response; - } - - reset() { - this.chatHistory = []; - } -} - -export interface Context { - message: ChatMessage; - nodes: NodeWithScore[]; -} - -export interface ContextGenerator { - generate(message: string, parentEvent?: Event): Promise<Context>; -} - -export class DefaultContextGenerator implements ContextGenerator { - retriever: BaseRetriever; - contextSystemPrompt: ContextSystemPrompt; - nodePostprocessors: BaseNodePostprocessor[]; - - constructor(init: { - retriever: BaseRetriever; - contextSystemPrompt?: ContextSystemPrompt; - nodePostprocessors?: BaseNodePostprocessor[]; - }) { - this.retriever = init.retriever; - this.contextSystemPrompt = - init?.contextSystemPrompt ?? defaultContextSystemPrompt; - this.nodePostprocessors = init.nodePostprocessors || []; - } - - private applyNodePostprocessors(nodes: NodeWithScore[]) { - return this.nodePostprocessors.reduce( - (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), - nodes, - ); - } - - async generate(message: string, parentEvent?: Event): Promise<Context> { - if (!parentEvent) { - parentEvent = { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - } - const sourceNodesWithScore = await this.retriever.retrieve( - message, - parentEvent, - ); - - const nodes = this.applyNodePostprocessors(sourceNodesWithScore); - - return { - message: { - content: this.contextSystemPrompt({ - context: nodes.map((r) => (r.node as TextNode).text).join("\n\n"), - }), - role: "system", - }, - nodes, - }; - } -} - -/** - * ContextChatEngine uses the Index to get the appropriate context for each query. - * The context is stored in the system prompt, and the chat history is preserved, - * ideally allowing the appropriate context to be surfaced for each query. - */ -export class ContextChatEngine implements ChatEngine { - chatModel: LLM; - chatHistory: ChatMessage[]; - contextGenerator: ContextGenerator; - - constructor(init: { - retriever: BaseRetriever; - chatModel?: LLM; - chatHistory?: ChatMessage[]; - contextSystemPrompt?: ContextSystemPrompt; - nodePostprocessors?: BaseNodePostprocessor[]; - }) { - this.chatModel = - init.chatModel ?? 
new OpenAI({ model: "gpt-3.5-turbo-16k" }); - this.chatHistory = init?.chatHistory ?? []; - this.contextGenerator = new DefaultContextGenerator({ - retriever: init.retriever, - contextSystemPrompt: init?.contextSystemPrompt, - }); - } - - chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; - chat(params: ChatEngineParamsNonStreaming): Promise<Response>; - async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise<Response | AsyncIterable<Response>> { - const { message, stream } = params; - const chatHistory = params.chatHistory ?? this.chatHistory; - const parentEvent: Event = { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - const context = await this.contextGenerator.generate( - extractText(message), - parentEvent, - ); - const nodes = context.nodes.map((r) => r.node); - chatHistory.push({ content: message, role: "user" }); - - if (stream) { - const stream = await this.chatModel.chat({ - messages: [context.message, ...chatHistory], - parentEvent, - stream: true, - }); - return streamConverter( - streamReducer({ - stream, - initialValue: "", - reducer: (accumulator, part) => (accumulator += part.delta), - finished: (accumulator) => { - chatHistory.push({ content: accumulator, role: "assistant" }); - }, - }), - (r: ChatResponseChunk) => new Response(r.delta, nodes), - ); - } - const response = await this.chatModel.chat({ - messages: [context.message, ...chatHistory], - parentEvent, - }); - chatHistory.push(response.message); - return new Response(response.message.content, nodes); - } - - reset() { - this.chatHistory = []; - } -} - -export interface MessageContentDetail { - type: "text" | "image_url"; - text?: string; - image_url?: { url: string }; -} - -/** - * Extended type for the content of a message that allows for multi-modal messages. - */ -export type MessageContent = string | MessageContentDetail[]; - -/** - * Extracts just the text from a multi-modal message or the message itself if it's just text. - * - * @param message The message to extract text from. - * @returns The extracted text - */ -function extractText(message: MessageContent): string { - if (Array.isArray(message)) { - // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them - // so we can pass them to the context generator - return (message as MessageContentDetail[]) - .filter((c) => c.type === "text") - .map((c) => c.text) - .join("\n\n"); - } - return message; -} - -/** - * HistoryChatEngine is a ChatEngine that uses a `ChatHistory` object - * to keeps track of chat's message history. - * A `ChatHistory` object is passed as a parameter for each call to the `chat` method, - * so the state of the chat engine is preserved between calls. - * Optionally, a `ContextGenerator` can be used to generate an additional context for each call to `chat`. - */ -export class HistoryChatEngine { - llm: LLM; - contextGenerator?: ContextGenerator; - - constructor(init?: Partial<HistoryChatEngine>) { - this.llm = init?.llm ?? new OpenAI(); - this.contextGenerator = init?.contextGenerator; - } - - chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; - chat(params: ChatEngineParamsNonStreaming): Promise<Response>; - async chat( - params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, - ): Promise<Response | AsyncIterable<Response>> { - const { message, stream, history } = params; - const chatHistory = history ?? 
new SimpleChatHistory(); - const requestMessages = await this.prepareRequestMessages( - message, - chatHistory, - ); - - if (stream) { - const stream = await this.llm.chat({ - messages: requestMessages, - stream: true, - }); - return streamConverter( - streamReducer({ - stream, - initialValue: "", - reducer: (accumulator, part) => (accumulator += part.delta), - finished: (accumulator) => { - chatHistory.addMessage({ content: accumulator, role: "assistant" }); - }, - }), - (r: ChatResponseChunk) => new Response(r.delta), - ); - } - const response = await this.llm.chat({ messages: requestMessages }); - chatHistory.addMessage(response.message); - return new Response(response.message.content); - } - - private async prepareRequestMessages( - message: MessageContent, - chatHistory: ChatHistory, - ) { - chatHistory.addMessage({ - content: message, - role: "user", - }); - let requestMessages; - let context; - if (this.contextGenerator) { - const textOnly = extractText(message); - context = await this.contextGenerator.generate(textOnly); - } - requestMessages = await chatHistory.requestMessages( - context ? [context.message] : undefined, - ); - return requestMessages; - } -} diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts index f49ef318b9db3390a25189df03a0ad635ec751b8..4eb89bbed4cde8bf1159f58438b15cd22559544e 100644 --- a/packages/core/src/ChatHistory.ts +++ b/packages/core/src/ChatHistory.ts @@ -1,4 +1,5 @@ -import { ChatMessage, LLM, MessageType, OpenAI } from "./llm/LLM"; +import { OpenAI } from "./llm/LLM"; +import { ChatMessage, LLM, MessageType } from "./llm/types"; import { defaultSummaryPrompt, messagesToHistoryStr, @@ -8,35 +9,38 @@ import { /** * A ChatHistory is used to keep the state of back and forth chat messages */ -export interface ChatHistory { - messages: ChatMessage[]; +export abstract class ChatHistory { + abstract get messages(): ChatMessage[]; /** * Adds a message to the chat history. * @param message */ - addMessage(message: ChatMessage): void; + abstract addMessage(message: ChatMessage): void; /** * Returns the messages that should be used as input to the LLM. */ - requestMessages(transientMessages?: ChatMessage[]): Promise<ChatMessage[]>; + abstract requestMessages( + transientMessages?: ChatMessage[], + ): Promise<ChatMessage[]>; /** * Resets the chat history so that it's empty. */ - reset(): void; + abstract reset(): void; /** * Returns the new messages since the last call to this function (or since calling the constructor) */ - newMessages(): ChatMessage[]; + abstract newMessages(): ChatMessage[]; } -export class SimpleChatHistory implements ChatHistory { +export class SimpleChatHistory extends ChatHistory { messages: ChatMessage[]; private messagesBefore: number; constructor(init?: Partial<SimpleChatHistory>) { + super(); this.messages = init?.messages ?? []; this.messagesBefore = this.messages.length; } @@ -60,7 +64,7 @@ export class SimpleChatHistory implements ChatHistory { } } -export class SummaryChatHistory implements ChatHistory { +export class SummaryChatHistory extends ChatHistory { tokensToSummarize: number; messages: ChatMessage[]; summaryPrompt: SummaryPrompt; @@ -68,6 +72,7 @@ export class SummaryChatHistory implements ChatHistory { private messagesBefore: number; constructor(init?: Partial<SummaryChatHistory>) { + super(); this.messages = init?.messages ?? []; this.messagesBefore = this.messages.length; this.summaryPrompt = init?.summaryPrompt ?? 
defaultSummaryPrompt; @@ -79,6 +84,11 @@ export class SummaryChatHistory implements ChatHistory { } this.tokensToSummarize = this.llm.metadata.contextWindow - this.llm.metadata.maxTokens; + if (this.tokensToSummarize < this.llm.metadata.contextWindow * 0.25) { + throw new Error( + "The number of tokens that trigger the summarize process are less than 25% of the context window. Try lowering maxTokens or use a model with a larger context window.", + ); + } } private async summarize(): Promise<ChatMessage> { @@ -198,3 +208,12 @@ export class SummaryChatHistory implements ChatHistory { return newMessages; } } + +export function getHistory( + chatHistory?: ChatMessage[] | ChatHistory, +): ChatHistory { + if (chatHistory instanceof ChatHistory) { + return chatHistory; + } + return new SimpleChatHistory({ messages: chatHistory }); +} diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index be03ef76c775d21cbbcb900a21d0e6118cc774d7..b89fe772080f5a2878d7cbc9926b26e370f2e5a9 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -1,4 +1,4 @@ -import { ChatMessage } from "./llm/LLM"; +import { ChatMessage } from "./llm/types"; import { SubQuestion } from "./QuestionGenerator"; import { ToolMetadata } from "./Tool"; diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts index 12a590832edc38ac3656f7dc1409786f890f9f39..eceac0303a7332d9e20a571428acd910d82136f4 100644 --- a/packages/core/src/QuestionGenerator.ts +++ b/packages/core/src/QuestionGenerator.ts @@ -9,7 +9,8 @@ import { defaultSubQuestionPrompt, } from "./Prompt"; import { ToolMetadata } from "./Tool"; -import { LLM, OpenAI } from "./llm/LLM"; +import { OpenAI } from "./llm/LLM"; +import { LLM } from "./llm/types"; export interface SubQuestion { subQuestion: string; diff --git a/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..07ce7aad45b9b114f09ba7c975d011e5dbc89301 --- /dev/null +++ b/packages/core/src/engines/chat/CondenseQuestionChatEngine.ts @@ -0,0 +1,104 @@ +import { ChatHistory, getHistory } from "../../ChatHistory"; +import { + CondenseQuestionPrompt, + defaultCondenseQuestionPrompt, + messagesToHistoryStr, +} from "../../Prompt"; +import { BaseQueryEngine } from "../../QueryEngine"; +import { Response } from "../../Response"; +import { + ServiceContext, + serviceContextFromDefaults, +} from "../../ServiceContext"; +import { ChatMessage, LLM } from "../../llm"; +import { extractText, streamReducer } from "../../llm/utils"; +import { + ChatEngine, + ChatEngineParamsNonStreaming, + ChatEngineParamsStreaming, +} from "./types"; + +/** + * CondenseQuestionChatEngine is used in conjunction with a Index (for example VectorStoreIndex). + * It does two steps on taking a user's chat message: first, it condenses the chat message + * with the previous chat history into a question with more context. + * Then, it queries the underlying Index using the new question with context and returns + * the response. + * CondenseQuestionChatEngine performs well when the input is primarily questions about the + * underlying data. It performs less well when the chat messages are not questions about the + * data, or are very referential to previous context. 
+ */ + +export class CondenseQuestionChatEngine implements ChatEngine { + queryEngine: BaseQueryEngine; + chatHistory: ChatHistory; + llm: LLM; + condenseMessagePrompt: CondenseQuestionPrompt; + + constructor(init: { + queryEngine: BaseQueryEngine; + chatHistory: ChatMessage[]; + serviceContext?: ServiceContext; + condenseMessagePrompt?: CondenseQuestionPrompt; + }) { + this.queryEngine = init.queryEngine; + this.chatHistory = getHistory(init?.chatHistory); + this.llm = init?.serviceContext?.llm ?? serviceContextFromDefaults().llm; + this.condenseMessagePrompt = + init?.condenseMessagePrompt ?? defaultCondenseQuestionPrompt; + } + + private async condenseQuestion(chatHistory: ChatHistory, question: string) { + const chatHistoryStr = messagesToHistoryStr( + await chatHistory.requestMessages(), + ); + + return this.llm.complete({ + prompt: defaultCondenseQuestionPrompt({ + question: question, + chatHistory: chatHistoryStr, + }), + }); + } + + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream } = params; + const chatHistory = params.chatHistory + ? getHistory(params.chatHistory) + : this.chatHistory; + + const condensedQuestion = ( + await this.condenseQuestion(chatHistory, extractText(message)) + ).text; + chatHistory.addMessage({ content: message, role: "user" }); + + if (stream) { + const stream = await this.queryEngine.query({ + query: condensedQuestion, + stream: true, + }); + return streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.response), + finished: (accumulator) => { + chatHistory.addMessage({ content: accumulator, role: "assistant" }); + }, + }); + } + const response = await this.queryEngine.query({ + query: condensedQuestion, + }); + chatHistory.addMessage({ content: response.response, role: "assistant" }); + + return response; + } + + reset() { + this.chatHistory.reset(); + } +} diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..23500659715bd4572b36bda5fdc58a5b8e62703c --- /dev/null +++ b/packages/core/src/engines/chat/ContextChatEngine.ts @@ -0,0 +1,113 @@ +import { randomUUID } from "node:crypto"; +import { ChatHistory, getHistory } from "../../ChatHistory"; +import { ContextSystemPrompt } from "../../Prompt"; +import { Response } from "../../Response"; +import { BaseRetriever } from "../../Retriever"; +import { Event } from "../../callbacks/CallbackManager"; +import { ChatMessage, ChatResponseChunk, LLM, OpenAI } from "../../llm"; +import { MessageContent } from "../../llm/types"; +import { extractText, streamConverter, streamReducer } from "../../llm/utils"; +import { BaseNodePostprocessor } from "../../postprocessors"; +import { DefaultContextGenerator } from "./DefaultContextGenerator"; +import { + ChatEngine, + ChatEngineParamsNonStreaming, + ChatEngineParamsStreaming, + ContextGenerator, +} from "./types"; + +/** + * ContextChatEngine uses the Index to get the appropriate context for each query. + * The context is stored in the system prompt, and the chat history is preserved, + * ideally allowing the appropriate context to be surfaced for each query. 
+ */ +export class ContextChatEngine implements ChatEngine { + chatModel: LLM; + chatHistory: ChatHistory; + contextGenerator: ContextGenerator; + + constructor(init: { + retriever: BaseRetriever; + chatModel?: LLM; + chatHistory?: ChatMessage[]; + contextSystemPrompt?: ContextSystemPrompt; + nodePostprocessors?: BaseNodePostprocessor[]; + }) { + this.chatModel = + init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" }); + this.chatHistory = getHistory(init?.chatHistory); + this.contextGenerator = new DefaultContextGenerator({ + retriever: init.retriever, + contextSystemPrompt: init?.contextSystemPrompt, + nodePostprocessors: init?.nodePostprocessors, + }); + } + + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream } = params; + const chatHistory = params.chatHistory + ? getHistory(params.chatHistory) + : this.chatHistory; + const parentEvent: Event = { + id: randomUUID(), + type: "wrapper", + tags: ["final"], + }; + const requestMessages = await this.prepareRequestMessages( + message, + chatHistory, + parentEvent, + ); + + if (stream) { + const stream = await this.chatModel.chat({ + messages: requestMessages.messages, + parentEvent, + stream: true, + }); + return streamConverter( + streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.delta), + finished: (accumulator) => { + chatHistory.addMessage({ content: accumulator, role: "assistant" }); + }, + }), + (r: ChatResponseChunk) => new Response(r.delta, requestMessages.nodes), + ); + } + const response = await this.chatModel.chat({ + messages: requestMessages.messages, + parentEvent, + }); + chatHistory.addMessage(response.message); + return new Response(response.message.content, requestMessages.nodes); + } + + reset() { + this.chatHistory.reset(); + } + + private async prepareRequestMessages( + message: MessageContent, + chatHistory: ChatHistory, + parentEvent?: Event, + ) { + chatHistory.addMessage({ + content: message, + role: "user", + }); + const textOnly = extractText(message); + const context = await this.contextGenerator.generate(textOnly, parentEvent); + const nodes = context.nodes.map((r) => r.node); + const messages = await chatHistory.requestMessages( + context ? 
[context.message] : undefined, + ); + return { nodes, messages }; + } +} diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts new file mode 100644 index 0000000000000000000000000000000000000000..ea5604f2cf74076c349ce63ada7429d17f750d5a --- /dev/null +++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts @@ -0,0 +1,57 @@ +import { randomUUID } from "node:crypto"; +import { NodeWithScore, TextNode } from "../../Node"; +import { ContextSystemPrompt, defaultContextSystemPrompt } from "../../Prompt"; +import { BaseRetriever } from "../../Retriever"; +import { Event } from "../../callbacks/CallbackManager"; +import { BaseNodePostprocessor } from "../../postprocessors"; +import { Context, ContextGenerator } from "./types"; + +export class DefaultContextGenerator implements ContextGenerator { + retriever: BaseRetriever; + contextSystemPrompt: ContextSystemPrompt; + nodePostprocessors: BaseNodePostprocessor[]; + + constructor(init: { + retriever: BaseRetriever; + contextSystemPrompt?: ContextSystemPrompt; + nodePostprocessors?: BaseNodePostprocessor[]; + }) { + this.retriever = init.retriever; + this.contextSystemPrompt = + init?.contextSystemPrompt ?? defaultContextSystemPrompt; + this.nodePostprocessors = init.nodePostprocessors || []; + } + + private applyNodePostprocessors(nodes: NodeWithScore[]) { + return this.nodePostprocessors.reduce( + (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), + nodes, + ); + } + + async generate(message: string, parentEvent?: Event): Promise<Context> { + if (!parentEvent) { + parentEvent = { + id: randomUUID(), + type: "wrapper", + tags: ["final"], + }; + } + const sourceNodesWithScore = await this.retriever.retrieve( + message, + parentEvent, + ); + + const nodes = this.applyNodePostprocessors(sourceNodesWithScore); + + return { + message: { + content: this.contextSystemPrompt({ + context: nodes.map((r) => (r.node as TextNode).text).join("\n\n"), + }), + role: "system", + }, + nodes, + }; + } +} diff --git a/packages/core/src/engines/chat/SimpleChatEngine.ts b/packages/core/src/engines/chat/SimpleChatEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..370e32cfe91d0ad851d431aebcc0e696c4ce1653 --- /dev/null +++ b/packages/core/src/engines/chat/SimpleChatEngine.ts @@ -0,0 +1,64 @@ +import { ChatHistory, getHistory } from "../../ChatHistory"; +import { Response } from "../../Response"; +import { ChatResponseChunk, LLM, OpenAI } from "../../llm"; +import { streamConverter, streamReducer } from "../../llm/utils"; +import { + ChatEngine, + ChatEngineParamsNonStreaming, + ChatEngineParamsStreaming, +} from "./types"; + +/** + * SimpleChatEngine is the simplest possible chat engine. Useful for using your own custom prompts. + */ + +export class SimpleChatEngine implements ChatEngine { + chatHistory: ChatHistory; + llm: LLM; + + constructor(init?: Partial<SimpleChatEngine>) { + this.chatHistory = getHistory(init?.chatHistory); + this.llm = init?.llm ?? new OpenAI(); + } + + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + async chat( + params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { message, stream } = params; + + const chatHistory = params.chatHistory + ? 
getHistory(params.chatHistory) + : this.chatHistory; + chatHistory.addMessage({ content: message, role: "user" }); + + if (stream) { + const stream = await this.llm.chat({ + messages: await chatHistory.requestMessages(), + stream: true, + }); + return streamConverter( + streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.delta), + finished: (accumulator) => { + chatHistory.addMessage({ content: accumulator, role: "assistant" }); + }, + }), + (r: ChatResponseChunk) => new Response(r.delta), + ); + } + + const response = await this.llm.chat({ + messages: await chatHistory.requestMessages(), + }); + chatHistory.addMessage(response.message); + return new Response(response.message.content); + } + + reset() { + this.chatHistory.reset(); + } +} diff --git a/packages/core/src/engines/chat/index.ts b/packages/core/src/engines/chat/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..3049bfa5780b8da2e6981098037bb8a3ae655a2d --- /dev/null +++ b/packages/core/src/engines/chat/index.ts @@ -0,0 +1,4 @@ +export { CondenseQuestionChatEngine } from "./CondenseQuestionChatEngine"; +export { ContextChatEngine } from "./ContextChatEngine"; +export { SimpleChatEngine } from "./SimpleChatEngine"; +export * from "./types"; diff --git a/packages/core/src/engines/chat/types.ts b/packages/core/src/engines/chat/types.ts new file mode 100644 index 0000000000000000000000000000000000000000..1d8c18378450d764a6df3168b1ffa64ae2871edc --- /dev/null +++ b/packages/core/src/engines/chat/types.ts @@ -0,0 +1,54 @@ +import { ChatHistory } from "../../ChatHistory"; +import { NodeWithScore } from "../../Node"; +import { Response } from "../../Response"; +import { Event } from "../../callbacks/CallbackManager"; +import { ChatMessage } from "../../llm"; +import { MessageContent } from "../../llm/types"; + +/** + * Represents the base parameters for ChatEngine. + */ +export interface ChatEngineParamsBase { + message: MessageContent; + /** + * Optional chat history if you want to customize the chat history. + */ + chatHistory?: ChatMessage[] | ChatHistory; +} + +export interface ChatEngineParamsStreaming extends ChatEngineParamsBase { + stream: true; +} + +export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase { + stream?: false | null; +} + +/** + * A ChatEngine is used to handle back and forth chats between the application and the LLM. + */ +export interface ChatEngine { + /** + * Send message along with the class's current chat history to the LLM. + * @param params + */ + chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>; + chat(params: ChatEngineParamsNonStreaming): Promise<Response>; + + /** + * Resets the chat history so that it's empty. 
+ */ + reset(): void; +} + +export interface Context { + message: ChatMessage; + nodes: NodeWithScore[]; +} + +/** + * A ContextGenerator is used to generate a context based on a message's text content + */ +export interface ContextGenerator { + generate(message: string, parentEvent?: Event): Promise<Context>; +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index c0eda73767afac895eb462b69e07d8b47548bdd8..3b0dfef57fc33cce1cb6ae330005895ac17b9971 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,4 +1,3 @@ -export * from "./ChatEngine"; export * from "./ChatHistory"; export * from "./GlobalsHelper"; export * from "./Node"; @@ -15,6 +14,7 @@ export * from "./Tool"; export * from "./callbacks/CallbackManager"; export * from "./constants"; export * from "./embeddings"; +export * from "./engines/chat"; export * from "./indices"; export * from "./llm"; export * from "./nodeParsers"; diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index e6ee82e4daaf09017ee5b5092606743920a7dd40..28a16c6355682b0bc41c3aa3d1296a10325ed3eb 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -24,146 +24,19 @@ import { getAzureModel, shouldUseAzure, } from "./azure"; +import { BaseLLM } from "./base"; import { OpenAISession, getOpenAISession } from "./openai"; import { PortkeySession, getPortkeySession } from "./portkey"; import { ReplicateSession } from "./replicate"; -import { streamConverter } from "./utils"; - -export type MessageType = - | "user" - | "assistant" - | "system" - | "generic" - | "function" - | "memory"; - -export interface ChatMessage { - content: any; - role: MessageType; -} - -export interface ChatResponse { - message: ChatMessage; - raw?: Record<string, any>; -} - -export interface ChatResponseChunk { - delta: string; -} - -export interface CompletionResponse { - text: string; - raw?: Record<string, any>; -} - -export interface LLMMetadata { - model: string; - temperature: number; - topP: number; - maxTokens?: number; - contextWindow: number; - tokenizer: Tokenizers | undefined; -} - -export interface LLMChatParamsBase { - messages: ChatMessage[]; - parentEvent?: Event; - extraParams?: Record<string, any>; -} - -export interface LLMChatParamsStreaming extends LLMChatParamsBase { - stream: true; -} - -export interface LLMChatParamsNonStreaming extends LLMChatParamsBase { - stream?: false | null; -} - -export interface LLMCompletionParamsBase { - prompt: any; - parentEvent?: Event; -} - -export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase { - stream: true; -} - -export interface LLMCompletionParamsNonStreaming - extends LLMCompletionParamsBase { - stream?: false | null; -} - -/** - * Unified language model interface - */ -export interface LLM { - metadata: LLMMetadata; - /** - * Get a chat response from the LLM - * - * @param params - */ - chat( - params: LLMChatParamsStreaming, - ): Promise<AsyncIterable<ChatResponseChunk>>; - chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; - - /** - * Get a prompt completion from the LLM - * @param params - */ - complete( - params: LLMCompletionParamsStreaming, - ): Promise<AsyncIterable<CompletionResponse>>; - complete( - params: LLMCompletionParamsNonStreaming, - ): Promise<CompletionResponse>; - - /** - * Calculates the number of tokens needed for the given chat messages - */ - tokens(messages: ChatMessage[]): number; -} - -export abstract class BaseLLM implements LLM { - abstract metadata: LLMMetadata; - - complete( - 
params: LLMCompletionParamsStreaming, - ): Promise<AsyncIterable<CompletionResponse>>; - complete( - params: LLMCompletionParamsNonStreaming, - ): Promise<CompletionResponse>; - async complete( - params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, - ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> { - const { prompt, parentEvent, stream } = params; - if (stream) { - const stream = await this.chat({ - messages: [{ content: prompt, role: "user" }], - parentEvent, - stream: true, - }); - return streamConverter(stream, (chunk) => { - return { - text: chunk.delta, - }; - }); - } - const chatResponse = await this.chat({ - messages: [{ content: prompt, role: "user" }], - parentEvent, - }); - return { text: chatResponse.message.content as string }; - } - - abstract chat( - params: LLMChatParamsStreaming, - ): Promise<AsyncIterable<ChatResponseChunk>>; - abstract chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; - - abstract tokens(messages: ChatMessage[]): number; -} +import { + ChatMessage, + ChatResponse, + ChatResponseChunk, + LLMChatParamsNonStreaming, + LLMChatParamsStreaming, + LLMMetadata, + MessageType, +} from "./types"; export const GPT4_MODELS = { "gpt-4": { contextWindow: 8192 }, diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts new file mode 100644 index 0000000000000000000000000000000000000000..574710b4cd9900e00d5d45ef85518101507d5869 --- /dev/null +++ b/packages/core/src/llm/base.ts @@ -0,0 +1,53 @@ +import { + ChatMessage, + ChatResponse, + ChatResponseChunk, + CompletionResponse, + LLM, + LLMChatParamsNonStreaming, + LLMChatParamsStreaming, + LLMCompletionParamsNonStreaming, + LLMCompletionParamsStreaming, + LLMMetadata, +} from "./types"; +import { streamConverter } from "./utils"; + +export abstract class BaseLLM implements LLM { + abstract metadata: LLMMetadata; + + complete( + params: LLMCompletionParamsStreaming, + ): Promise<AsyncIterable<CompletionResponse>>; + complete( + params: LLMCompletionParamsNonStreaming, + ): Promise<CompletionResponse>; + async complete( + params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, + ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> { + const { prompt, parentEvent, stream } = params; + if (stream) { + const stream = await this.chat({ + messages: [{ content: prompt, role: "user" }], + parentEvent, + stream: true, + }); + return streamConverter(stream, (chunk) => { + return { + text: chunk.delta, + }; + }); + } + const chatResponse = await this.chat({ + messages: [{ content: prompt, role: "user" }], + parentEvent, + }); + return { text: chatResponse.message.content as string }; + } + + abstract chat( + params: LLMChatParamsStreaming, + ): Promise<AsyncIterable<ChatResponseChunk>>; + abstract chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; + + abstract tokens(messages: ChatMessage[]): number; +} diff --git a/packages/core/src/llm/index.ts b/packages/core/src/llm/index.ts index 74e0b91d939065e4d566517c614271e645f58677..a762899892dfa768af5646701d2616c0200a77ea 100644 --- a/packages/core/src/llm/index.ts +++ b/packages/core/src/llm/index.ts @@ -2,3 +2,4 @@ export * from "./LLM"; export * from "./mistral"; export { Ollama } from "./ollama"; export { TogetherLLM } from "./together"; +export * from "./types"; diff --git a/packages/core/src/llm/mistral.ts b/packages/core/src/llm/mistral.ts index 35b891e5e796ff8f4a01fc6a8547bba64d061ba4..61b979b740a9cf6aade2306cc825a47dcd23dd74 100644 --- 
a/packages/core/src/llm/mistral.ts +++ b/packages/core/src/llm/mistral.ts @@ -4,14 +4,14 @@ import { EventType, StreamCallbackResponse, } from "../callbacks/CallbackManager"; +import { BaseLLM } from "./base"; import { - BaseLLM, ChatMessage, ChatResponse, ChatResponseChunk, LLMChatParamsNonStreaming, LLMChatParamsStreaming, -} from "./LLM"; +} from "./types"; export const ALL_AVAILABLE_MISTRAL_MODELS = { "mistral-tiny": { contextWindow: 32000 }, diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts index 29b24b1b28026acd4531455dc9fdf78320a0f851..96a9aba4401a0cf4a7ca981925920a22b52ba328 100644 --- a/packages/core/src/llm/ollama.ts +++ b/packages/core/src/llm/ollama.ts @@ -12,7 +12,7 @@ import { LLMCompletionParamsNonStreaming, LLMCompletionParamsStreaming, LLMMetadata, -} from "./LLM"; +} from "./types"; const messageAccessor = (data: any): ChatResponseChunk => { return { diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts new file mode 100644 index 0000000000000000000000000000000000000000..97ead4a7541ed5f6aacb814f9c155fbfd6886863 --- /dev/null +++ b/packages/core/src/llm/types.ts @@ -0,0 +1,110 @@ +import { Tokenizers } from "../GlobalsHelper"; +import { Event } from "../callbacks/CallbackManager"; + +/** + * Unified language model interface + */ +export interface LLM { + metadata: LLMMetadata; + /** + * Get a chat response from the LLM + * + * @param params + */ + chat( + params: LLMChatParamsStreaming, + ): Promise<AsyncIterable<ChatResponseChunk>>; + chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; + + /** + * Get a prompt completion from the LLM + * @param params + */ + complete( + params: LLMCompletionParamsStreaming, + ): Promise<AsyncIterable<CompletionResponse>>; + complete( + params: LLMCompletionParamsNonStreaming, + ): Promise<CompletionResponse>; + + /** + * Calculates the number of tokens needed for the given chat messages + */ + tokens(messages: ChatMessage[]): number; +} + +export type MessageType = + | "user" + | "assistant" + | "system" + | "generic" + | "function" + | "memory"; + +export interface ChatMessage { + // TODO: use MessageContent + content: any; + role: MessageType; +} + +export interface ChatResponse { + message: ChatMessage; + raw?: Record<string, any>; +} + +export interface ChatResponseChunk { + delta: string; +} + +export interface CompletionResponse { + text: string; + raw?: Record<string, any>; +} + +export interface LLMMetadata { + model: string; + temperature: number; + topP: number; + maxTokens?: number; + contextWindow: number; + tokenizer: Tokenizers | undefined; +} + +export interface LLMChatParamsBase { + messages: ChatMessage[]; + parentEvent?: Event; + extraParams?: Record<string, any>; +} + +export interface LLMChatParamsStreaming extends LLMChatParamsBase { + stream: true; +} + +export interface LLMChatParamsNonStreaming extends LLMChatParamsBase { + stream?: false | null; +} + +export interface LLMCompletionParamsBase { + prompt: any; + parentEvent?: Event; +} + +export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase { + stream: true; +} + +export interface LLMCompletionParamsNonStreaming + extends LLMCompletionParamsBase { + stream?: false | null; +} + +export interface MessageContentDetail { + type: "text" | "image_url"; + text?: string; + image_url?: { url: string }; +} + +/** + * Extended type for the content of a message that allows for multi-modal messages. 
+ */ +export type MessageContent = string | MessageContentDetail[]; diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts index 55fc5f78d4765bb7f7c78775075ff81fb5f2c18e..39bf27ba329de416a561cafcafa618a5ace946de 100644 --- a/packages/core/src/llm/utils.ts +++ b/packages/core/src/llm/utils.ts @@ -1,3 +1,5 @@ +import { MessageContent, MessageContentDetail } from "./types"; + export async function* streamConverter<S, D>( stream: AsyncIterable<S>, converter: (s: S) => D, @@ -22,3 +24,21 @@ export async function* streamReducer<S, D>(params: { params.finished(value); } } +/** + * Extracts just the text from a multi-modal message or the message itself if it's just text. + * + * @param message The message to extract text from. + * @returns The extracted text + */ + +export function extractText(message: MessageContent): string { + if (Array.isArray(message)) { + // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them + // so we can pass them to the context generator + return (message as MessageContentDetail[]) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join("\n\n"); + } + return message; +} diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts index b6d87ea4d8972f2ff4720c1f45c06ce729d50d82..307138350605cb30b2180921138ed4752c317ab7 100644 --- a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -1,8 +1,8 @@ -import { MessageContentDetail } from "../ChatEngine"; import { ImageNode, MetadataMode, splitNodesByType } from "../Node"; import { Response } from "../Response"; import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; import { imageToDataUrl } from "../embeddings"; +import { MessageContentDetail } from "../llm/types"; import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt"; import { BaseSynthesizer, diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts index 8b19972c6fc60a19452afadb819ca859c73dc239..27e20e3bf2d5927c31eaa10b51b4fd9b468f447d 100644 --- a/packages/core/src/tests/utility/mockOpenAI.ts +++ b/packages/core/src/tests/utility/mockOpenAI.ts @@ -1,7 +1,8 @@ import { CallbackManager } from "../../callbacks/CallbackManager"; import { OpenAIEmbedding } from "../../embeddings"; import { globalsHelper } from "../../GlobalsHelper"; -import { LLMChatParamsBase, OpenAI } from "../../llm/LLM"; +import { OpenAI } from "../../llm/LLM"; +import { LLMChatParamsBase } from "../../llm/types"; export function mockLlmGeneration({ languageModel,