From c0062746eb88768166594b3dac75abbc34987b87 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 9 Oct 2023 16:55:05 +0700
Subject: [PATCH] feat: use tokenizer to ensure we're not running over the
 context window

---
 packages/core/package.json       |   5 +-
 packages/core/src/ChatEngine.ts  |   5 +-
 packages/core/src/ChatHistory.ts | 120 +++++++++++++++++++++----------
 3 files changed, 89 insertions(+), 41 deletions(-)

diff --git a/packages/core/package.json b/packages/core/package.json
index 72f89fef9..97dc54bf0 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -38,6 +38,7 @@
   "scripts": {
     "lint": "eslint .",
     "test": "jest",
-    "build": "tsup src/index.ts --format esm,cjs --dts"
+    "build": "tsup src/index.ts --format esm,cjs --dts",
+    "dev": "tsup src/index.ts --format esm,cjs --watch"
   }
-}
+}
\ No newline at end of file
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index c76d59794..0d39392b2 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -338,7 +338,10 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    await this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({
+      content: accumulator,
+      role: "assistant",
+    });
     return;
   }
 
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index f4e2c8a02..761d43e6e 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -1,4 +1,10 @@
-import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
+import tiktoken from "tiktoken";
+import {
+  ALL_AVAILABLE_OPENAI_MODELS,
+  ChatMessage,
+  MessageType,
+  OpenAI,
+} from "./llm/LLM";
 import {
   defaultSummaryPrompt,
   messagesToHistoryStr,
@@ -47,66 +53,104 @@ export class SimpleChatHistory implements ChatHistory {
 }
 
 export class SummaryChatHistory implements ChatHistory {
-  messagesToSummarize: number;
+  tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
-  llm: LLM;
+  llm: OpenAI;
 
   constructor(init?: Partial<SummaryChatHistory>) {
-    this.messagesToSummarize = init?.messagesToSummarize ?? 5;
    this.messages = init?.messages ?? [];
     this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt;
     this.llm = init?.llm ?? new OpenAI();
+    if (!this.llm.maxTokens) {
+      throw new Error(
+        "LLM maxTokens is not set. Needed so the summarizer can ensure that the context window of the LLM is not exceeded.",
+      );
+    }
+    // TODO: currently, this only works with OpenAI
+    // to support more LLMs, we have to move the tokenizer and the context window size to the LLM interface
+    this.tokensToSummarize =
+      ALL_AVAILABLE_OPENAI_MODELS[this.llm.model].contextWindow -
+      this.llm.maxTokens;
+  }
+
+  private tokens(messages: ChatMessage[]): number {
+    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    const encoding = tiktoken.encoding_for_model(this.llm.model);
+    const tokensPerMessage = 3;
+    let numTokens = 0;
+    for (const message of messages) {
+      numTokens += tokensPerMessage;
+      for (const value of Object.values(message)) {
+        numTokens += encoding.encode(value).length;
+      }
+    }
+    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
+    return numTokens;
   }
 
-  private async summarize() {
+  private async summarize(): Promise<ChatMessage> {
     // get all messages after the last summary message (including)
-    // if there's no summary message, get all messages
+    // if there's no summary message, get all messages (without system messages)
     const lastSummaryIndex = this.getLastSummaryIndex();
-    const chatHistoryStr = messagesToHistoryStr(
-      lastSummaryIndex === -1
-        ? this.messages
-        : this.messages.slice(lastSummaryIndex),
-    );
-
-    const response = await this.llm.complete(
-      this.summaryPrompt({ context: chatHistoryStr }),
-    );
-
-    this.messages.push({ content: response.message.content, role: "memory" });
+    const messagesToSummarize = !lastSummaryIndex
+      ? this.nonSystemMessages
+      : this.messages.slice(lastSummaryIndex);
+
+    let promptMessages;
+    do {
+      promptMessages = [
+        {
+          content: this.summaryPrompt({
+            context: messagesToHistoryStr(messagesToSummarize),
+          }),
+          role: "user" as MessageType,
+        },
+      ];
+      // remove oldest message until the chat history is short enough for the context window
+      messagesToSummarize.shift();
+    } while (this.tokens(promptMessages) > this.tokensToSummarize);
+
+    const response = await this.llm.chat(promptMessages);
+    return { content: response.message.content, role: "memory" };
   }
 
   async addMessage(message: ChatMessage) {
-    const messagesSinceLastSummary =
-      this.messages.length - this.getLastSummaryIndex() - 1;
-    // if there are too many messages since the last summary, call summarize
-    if (messagesSinceLastSummary >= this.messagesToSummarize) {
-      // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
-      // for now we just summarize each `messagesToSummarize` messages
-      await this.summarize();
+    // get tokens of current request messages and the new message
+    const tokens = this.tokens([...this.requestMessages, message]);
+    // if there are too many tokens for the next request, call summarize
+    if (tokens > this.tokensToSummarize) {
+      const memoryMessage = await this.summarize();
+      this.messages.push(memoryMessage);
     }
     this.messages.push(message);
   }
 
   // Find last summary message
-  private getLastSummaryIndex() {
-    return (
-      this.messages.length -
-      1 -
-      this.messages
-        .slice()
-        .reverse()
-        .findIndex((message) => message.role === "memory")
+  private getLastSummaryIndex(): number | null {
+    const reversedMessages = this.messages.slice().reverse();
+    const index = reversedMessages.findIndex(
+      (message) => message.role === "memory",
     );
+    if (index === -1) {
+      return null;
+    }
+    return this.messages.length - 1 - index;
+  }
+
+  private get systemMessages() {
+    // get array of all system messages
+    return this.messages.filter((message) => message.role === "system");
+  }
+
+  private get nonSystemMessages() {
+    // get array of all non-system messages
+    return this.messages.filter((message) => message.role !== "system");
   }
 
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
-    if (lastSummaryIndex === -1) return this.messages;
-    // get array of all system messages
-    const systemMessages = this.messages.filter(
-      (message) => message.role === "system",
-    );
+    if (!lastSummaryIndex) return this.messages;
     // convert summary message so it can be send to the LLM
     const summaryMessage: ChatMessage = {
       content: `This is a summary of conversation so far: ${this.messages[lastSummaryIndex].content}`,
@@ -114,7 +158,7 @@ export class SummaryChatHistory implements ChatHistory {
     };
     // return system messages, last summary and all messages after the last summary message
     return [
-      ...systemMessages,
+      ...this.systemMessages,
       summaryMessage,
       ...this.messages.slice(lastSummaryIndex + 1),
     ];
-- 
GitLab
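
For reference, the `tokens()` method added above follows the counting rule from the OpenAI cookbook notebook it links to: a fixed overhead of 3 tokens per message, the encoded length of every message field, and 3 extra tokens that prime the assistant reply. The sketch below restates that rule as a standalone function; it is a minimal illustration only, assuming the npm `tiktoken` package, and `countChatTokens`, the hard-coded 4096-token window, and the sample messages are not part of the patch.

import { encoding_for_model } from "tiktoken";

// Counting rule used by SummaryChatHistory.tokens(): 3 overhead tokens per
// message, plus the encoded length of every field, plus 3 tokens that prime
// the assistant reply (<|im_start|>assistant<|im_sep|>).
function countChatTokens(
  messages: { role: string; content: string }[],
  model: "gpt-3.5-turbo" | "gpt-4" = "gpt-3.5-turbo",
): number {
  const encoding = encoding_for_model(model);
  try {
    let numTokens = 3; // reply priming
    for (const message of messages) {
      numTokens += 3; // per-message overhead
      numTokens += encoding.encode(message.role).length;
      numTokens += encoding.encode(message.content).length;
    }
    return numTokens;
  } finally {
    encoding.free(); // tiktoken encodings wrap WASM memory and must be freed
  }
}

// Example: the summarization threshold mirrors the patch's
// ALL_AVAILABLE_OPENAI_MODELS[model].contextWindow - llm.maxTokens.
const tokensToSummarize = 4096 - 256; // assumed gpt-3.5-turbo window, 256 tokens reserved for the reply
const needsSummary =
  countChatTokens([
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", content: "Summarize our conversation so far." },
  ]) > tokensToSummarize;
console.log(needsSummary); // false for such a short history

This is also why the constructor now throws when `maxTokens` is unset: without it, the threshold `contextWindow - maxTokens` used by `addMessage` and `summarize` cannot be computed.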