diff --git a/packages/core/package.json b/packages/core/package.json
index 72f89fef98065eaafd8edeb2341ed18bb2bb554a..97dc54bf0f3101d26c67348c28b4330709625c4b 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -38,6 +38,7 @@
   "scripts": {
     "lint": "eslint .",
     "test": "jest",
-    "build": "tsup src/index.ts --format esm,cjs --dts"
+    "build": "tsup src/index.ts --format esm,cjs --dts",
+    "dev": "tsup src/index.ts --format esm,cjs --watch"
   }
-}
+}
\ No newline at end of file
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 92501bf5ff9ac86bf9427ac98be9e4dc83fc4460..0d39392b271212cfc44504d660d18ba4d217bed4 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -314,11 +314,11 @@ export class HistoryChatEngine implements ChatEngine {
   ): Promise<R> {
     //Streaming option
     if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
+      return this.streamChat(message) as R;
     }
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response = await this.llm.chat(this.chatHistory.requestMessages);
-    this.chatHistory.addMessage(response.message);
+    await this.chatHistory.addMessage(response.message);
     return new Response(response.message.content) as R;
   }
 
@@ -326,7 +326,7 @@ export class HistoryChatEngine implements ChatEngine {
     message: string,
     chatHistory?: ChatMessage[] | undefined,
   ): AsyncGenerator<string, void, unknown> {
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response_stream = await this.llm.chat(
       this.chatHistory.requestMessages,
       undefined,
@@ -338,7 +338,10 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({
+      content: accumulator,
+      role: "assistant",
+    });
     return;
   }
 
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index 39c19c3905af72c73f32f0574e84424242acf0a1..0ec6c7bb2bf1bfeaa36e3d98ffa25623f5ee576d 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -1,4 +1,10 @@
-import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
+import tiktoken from "tiktoken";
+import {
+  ALL_AVAILABLE_OPENAI_MODELS,
+  ChatMessage,
+  MessageType,
+  OpenAI,
+} from "./llm/LLM";
 import {
   defaultSummaryPrompt,
   messagesToHistoryStr,
@@ -47,59 +53,104 @@ export class SimpleChatHistory implements ChatHistory {
 }
 
 export class SummaryChatHistory implements ChatHistory {
-  messagesToSummarize: number;
+  tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
-  llm: LLM;
+  llm: OpenAI;
 
   constructor(init?: Partial<SummaryChatHistory>) {
-    this.messagesToSummarize = init?.messagesToSummarize ?? 5;
     this.messages = init?.messages ?? [];
     this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt;
     this.llm = init?.llm ?? new OpenAI();
+    if (!this.llm.maxTokens) {
+      throw new Error(
+        "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.",
+      );
+    }
+    // TODO: currently, this only works with OpenAI
+    // to support more LLMs, we have to move the tokenizer and the context window size to the LLM interface
+    this.tokensToSummarize =
+      ALL_AVAILABLE_OPENAI_MODELS[this.llm.model].contextWindow -
+      this.llm.maxTokens;
   }
 
-  private async summarize() {
-    // get all messages after the last summary message (including)
-    const chatHistoryStr = messagesToHistoryStr(
-      this.messages.slice(this.getLastSummaryIndex()),
-    );
-
-    const response = await this.llm.complete(
-      this.summaryPrompt({ context: chatHistoryStr }),
-    );
+  private tokens(messages: ChatMessage[]): number {
+    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    const encoding = tiktoken.encoding_for_model(this.llm.model);
+    const tokensPerMessage = 3;
+    let numTokens = 0;
+    for (const message of messages) {
+      numTokens += tokensPerMessage;
+      for (const value of Object.values(message)) {
+        numTokens += encoding.encode(value).length;
+      }
+    }
+    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
+    return numTokens;
+  }
 
-    this.messages.push({ content: response.message.content, role: "memory" });
+  private async summarize(): Promise<ChatMessage> {
+    // get all messages after the last summary message (including)
+    // if there's no summary message, get all messages (without system messages)
+    const lastSummaryIndex = this.getLastSummaryIndex();
+    const messagesToSummarize = !lastSummaryIndex
+      ? this.nonSystemMessages
+      : this.messages.slice(lastSummaryIndex);
+
+    let promptMessages;
+    do {
+      promptMessages = [
+        {
+          content: this.summaryPrompt({
+            context: messagesToHistoryStr(messagesToSummarize),
+          }),
+          role: "user" as MessageType,
+        },
+      ];
+      // remove oldest message until the chat history is short enough for the context window
+      messagesToSummarize.shift();
+    } while (this.tokens(promptMessages) > this.tokensToSummarize);
+
+    const response = await this.llm.chat(promptMessages);
+    return { content: response.message.content, role: "memory" };
   }
 
   async addMessage(message: ChatMessage) {
-    const lastSummaryIndex = this.getLastSummaryIndex();
-    // if there are more than or equal `messagesToSummarize` messages since the last summary, call summarize
-    if (
-      lastSummaryIndex !== -1 &&
-      this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize
-    ) {
-      // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
-      // for now we just summarize each `messagesToSummarize` messages
-      await this.summarize();
+    // get tokens of current request messages and the new message
+    const tokens = this.tokens([...this.requestMessages, message]);
+    // if there are too many tokens for the next request, call summarize
+    if (tokens > this.tokensToSummarize) {
+      const memoryMessage = await this.summarize();
+      this.messages.push(memoryMessage);
     }
     this.messages.push(message);
   }
 
   // Find last summary message
-  private getLastSummaryIndex() {
-    return this.messages
-      .slice()
-      .reverse()
-      .findIndex((message) => message.role === "memory");
+  private getLastSummaryIndex(): number | null {
+    const reversedMessages = this.messages.slice().reverse();
+    const index = reversedMessages.findIndex(
+      (message) => message.role === "memory",
+    );
+    if (index === -1) {
+      return null;
+    }
+    return this.messages.length - 1 - index;
+  }
+
+  private get systemMessages() {
+    // get array of all system messages
+    return this.messages.filter((message) => message.role === "system");
+  }
+
+  private get nonSystemMessages() {
+    // get array of all non-system messages
+    return this.messages.filter((message) => message.role !== "system");
   }
 
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
-    // get array of all system messages
-    const systemMessages = this.messages.filter(
-      (message) => message.role === "system",
-    );
+    if (!lastSummaryIndex) return this.messages;
     // convert summary message so it can be send to the LLM
     const summaryMessage: ChatMessage = {
       content: `This is a summary of conversation so far: ${this.messages[lastSummaryIndex].content}`,
@@ -107,7 +158,7 @@ export class SummaryChatHistory implements ChatHistory {
     };
     // return system messages, last summary and all messages after the last summary message
     return [
-      ...systemMessages,
+      ...this.systemMessages,
       summaryMessage,
       ...this.messages.slice(lastSummaryIndex + 1),
     ];
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 4b76a63f0bb3f75b9b36fd795c93e8ab197a8a0e..497ed2be81b2e9266db8297509ba8af8a166aa88 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -1,5 +1,6 @@
 export * from "./callbacks/CallbackManager";
 export * from "./ChatEngine";
+export * from "./ChatHistory";
 export * from "./constants";
 export * from "./Embedding";
 export * from "./GlobalsHelper";
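Usage note (not part of the diff): the sketch below shows how the token-based SummaryChatHistory introduced here might be wired up with HistoryChatEngine. It assumes the core package is imported as "llamaindex", that HistoryChatEngine accepts { llm, chatHistory } in its constructor, and that OpenAI takes { model, maxTokens } options; none of those signatures appear in this diff, so treat the snippet as illustrative only.

```ts
import { HistoryChatEngine, OpenAI, SummaryChatHistory } from "llamaindex";

async function main() {
  // maxTokens must be set, otherwise the SummaryChatHistory constructor throws;
  // it uses contextWindow(model) - maxTokens as the summarization threshold.
  const llm = new OpenAI({ model: "gpt-3.5-turbo", maxTokens: 512 });

  const chatHistory = new SummaryChatHistory({ llm });
  const chatEngine = new HistoryChatEngine({ llm, chatHistory });

  // addMessage() is async now, so each turn awaits the history update.
  // Once the pending request grows past tokensToSummarize, older turns are
  // condensed into a single "memory" message that replaces them in requestMessages.
  const response = await chatEngine.chat("Summarize our conversation so far.");
  console.log(response);
}

main().catch(console.error);
```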