From c0062746eb88768166594b3dac75abbc34987b87 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 9 Oct 2023 16:55:05 +0700
Subject: [PATCH] feat: use tokenizer to ensure we're not running over the
 context window

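Replace the fixed messagesToSummarize count in SummaryChatHistory with a
token budget: the history is measured with tiktoken and summarized as
soon as the next request (system messages, the last summary and all
messages after it, plus the new message) would exceed the model's
context window minus the tokens reserved for the completion
(llm.maxTokens). The summarization prompt itself is trimmed the same
way, dropping the oldest messages until it fits. Also store the
streamed response in HistoryChatEngine under the "assistant" role
instead of "user".

Rough worked example (assuming gpt-3.5-turbo with a 4096-token context
window and maxTokens = 256): tokensToSummarize = 4096 - 256 = 3840, so
a summary is generated once a pending request would exceed 3840 tokens.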
---
 packages/core/package.json       |   5 +-
 packages/core/src/ChatEngine.ts  |   5 +-
 packages/core/src/ChatHistory.ts | 120 +++++++++++++++++++++----------
 3 files changed, 89 insertions(+), 41 deletions(-)
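
A minimal usage sketch, kept outside the diff body (the constructor
options and the HistoryChatEngine wiring are assumptions based on this
patch and the surrounding classes, not a documented API):

    import { OpenAI } from "./llm/LLM";
    import { SummaryChatHistory } from "./ChatHistory";
    import { HistoryChatEngine } from "./ChatEngine";

    // maxTokens must be set; otherwise SummaryChatHistory throws in its
    // constructor, since it reserves that many tokens for the completion.
    const llm = new OpenAI({ model: "gpt-3.5-turbo", maxTokens: 256 });
    const chatHistory = new SummaryChatHistory({ llm });
    const chatEngine = new HistoryChatEngine({ llm, chatHistory });

    // once the accumulated messages exceed contextWindow - maxTokens,
    // the history summarizes itself before the next request
    // (inside an async function)
    const response = await chatEngine.chat("What did we talk about so far?");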

diff --git a/packages/core/package.json b/packages/core/package.json
index 72f89fef9..97dc54bf0 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -38,6 +38,7 @@
   "scripts": {
     "lint": "eslint .",
     "test": "jest",
-    "build": "tsup src/index.ts --format esm,cjs --dts"
+    "build": "tsup src/index.ts --format esm,cjs --dts",
+    "dev": "tsup src/index.ts --format esm,cjs --watch"
   }
-}
+}
\ No newline at end of file
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index c76d59794..0d39392b2 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -338,7 +338,10 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    await this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({
+      content: accumulator,
+      role: "assistant",
+    });
     return;
   }
 
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index f4e2c8a02..761d43e6e 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -1,4 +1,10 @@
-import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
+import tiktoken from "tiktoken";
+import {
+  ALL_AVAILABLE_OPENAI_MODELS,
+  ChatMessage,
+  MessageType,
+  OpenAI,
+} from "./llm/LLM";
 import {
   defaultSummaryPrompt,
   messagesToHistoryStr,
@@ -47,66 +53,104 @@ export class SimpleChatHistory implements ChatHistory {
 }
 
 export class SummaryChatHistory implements ChatHistory {
-  messagesToSummarize: number;
+  tokensToSummarize: number;
   messages: ChatMessage[];
   summaryPrompt: SummaryPrompt;
-  llm: LLM;
+  llm: OpenAI;
 
   constructor(init?: Partial<SummaryChatHistory>) {
-    this.messagesToSummarize = init?.messagesToSummarize ?? 5;
     this.messages = init?.messages ?? [];
     this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt;
     this.llm = init?.llm ?? new OpenAI();
+    if (!this.llm.maxTokens) {
+      throw new Error(
+        "LLM maxTokens is not set. Needed so the summarizer ensures the context window size of the LLM.",
+      );
+    }
+    // TODO: currently, this only works with OpenAI
+    // to support more LLMs, we have to move the tokenizer and the context window size to the LLM interface
+    this.tokensToSummarize =
+      ALL_AVAILABLE_OPENAI_MODELS[this.llm.model].contextWindow -
+      this.llm.maxTokens;
+  }
+
+  private tokens(messages: ChatMessage[]): number {
+    // for latest OpenAI models, see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    const encoding = tiktoken.encoding_for_model(this.llm.model);
+    const tokensPerMessage = 3;
+    let numTokens = 0;
+    for (const message of messages) {
+      numTokens += tokensPerMessage;
+      for (const value of Object.values(message)) {
+        numTokens += encoding.encode(value).length;
+      }
+    }
+    numTokens += 3; // every reply is primed with <|im_start|>assistant<|im_sep|>
+    return numTokens;
   }
 
-  private async summarize() {
+  private async summarize(): Promise<ChatMessage> {
     // get all messages after the last summary message (including)
-    // if there's no summary message, get all messages
+    // if there's no summary message, get all messages (without system messages)
     const lastSummaryIndex = this.getLastSummaryIndex();
-    const chatHistoryStr = messagesToHistoryStr(
-      lastSummaryIndex === -1
-        ? this.messages
-        : this.messages.slice(lastSummaryIndex),
-    );
-
-    const response = await this.llm.complete(
-      this.summaryPrompt({ context: chatHistoryStr }),
-    );
-
-    this.messages.push({ content: response.message.content, role: "memory" });
+    const messagesToSummarize = lastSummaryIndex === null
+      ? this.nonSystemMessages
+      : this.messages.slice(lastSummaryIndex);
+
+    let promptMessages;
+    do {
+      promptMessages = [
+        {
+          content: this.summaryPrompt({
+            context: messagesToHistoryStr(messagesToSummarize),
+          }),
+          role: "user" as MessageType,
+        },
+      ];
+      // remove oldest message until the chat history is short enough for the context window
+      messagesToSummarize.shift();
+    } while (this.tokens(promptMessages) > this.tokensToSummarize);
+
+    const response = await this.llm.chat(promptMessages);
+    return { content: response.message.content, role: "memory" };
   }
 
   async addMessage(message: ChatMessage) {
-    const messagesSinceLastSummary =
-      this.messages.length - this.getLastSummaryIndex() - 1;
-    // if there are too many messages since the last summary, call summarize
-    if (messagesSinceLastSummary >= this.messagesToSummarize) {
-      // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
-      // for now we just summarize each `messagesToSummarize` messages
-      await this.summarize();
+    // get tokens of current request messages and the new message
+    const tokens = this.tokens([...this.requestMessages, message]);
+    // if there are too many tokens for the next request, call summarize
+    if (tokens > this.tokensToSummarize) {
+      const memoryMessage = await this.summarize();
+      this.messages.push(memoryMessage);
     }
     this.messages.push(message);
   }
 
   // Find last summary message
-  private getLastSummaryIndex() {
-    return (
-      this.messages.length -
-      1 -
-      this.messages
-        .slice()
-        .reverse()
-        .findIndex((message) => message.role === "memory")
+  private getLastSummaryIndex(): number | null {
+    const reversedMessages = this.messages.slice().reverse();
+    const index = reversedMessages.findIndex(
+      (message) => message.role === "memory",
     );
+    if (index === -1) {
+      return null;
+    }
+    return this.messages.length - 1 - index;
+  }
+
+  private get systemMessages() {
+    // get array of all system messages
+    return this.messages.filter((message) => message.role === "system");
+  }
+
+  private get nonSystemMessages() {
+    // get array of all non-system messages
+    return this.messages.filter((message) => message.role !== "system");
   }
 
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
-    if (lastSummaryIndex === -1) return this.messages;
-    // get array of all system messages
-    const systemMessages = this.messages.filter(
-      (message) => message.role === "system",
-    );
+    if (lastSummaryIndex === null) return this.messages;
     // convert summary message so it can be send to the LLM
     const summaryMessage: ChatMessage = {
       content: `This is a summary of conversation so far: ${this.messages[lastSummaryIndex].content}`,
@@ -114,7 +158,7 @@ export class SummaryChatHistory implements ChatHistory {
     };
     // return system messages, last summary and all messages after the last summary message
     return [
-      ...systemMessages,
+      ...this.systemMessages,
       summaryMessage,
       ...this.messages.slice(lastSummaryIndex + 1),
     ];
-- 
GitLab