From 809a904bc86e5f0e74225e5b35a0e0fbd8b268f8 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 9 Oct 2023 11:48:15 +0700
Subject: [PATCH] fix: summarizer issues

---
 packages/core/src/ChatEngine.ts  | 10 +++++-----
 packages/core/src/ChatHistory.ts | 29 ++++++++++++++++++-----------
 packages/core/src/index.ts       |  1 +
 3 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 92501bf5f..c76d59794 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -314,11 +314,11 @@ export class HistoryChatEngine implements ChatEngine {
   ): Promise<R> {
     //Streaming option
     if (streaming) {
-      return this.streamChat(message, chatHistory) as R;
+      return this.streamChat(message) as R;
     }
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response = await this.llm.chat(this.chatHistory.requestMessages);
-    this.chatHistory.addMessage(response.message);
+    await this.chatHistory.addMessage(response.message);
     return new Response(response.message.content) as R;
   }
 
@@ -326,7 +326,7 @@ export class HistoryChatEngine implements ChatEngine {
     message: string,
     chatHistory?: ChatMessage[] | undefined,
   ): AsyncGenerator<string, void, unknown> {
-    this.chatHistory.addMessage({ content: message, role: "user" });
+    await this.chatHistory.addMessage({ content: message, role: "user" });
     const response_stream = await this.llm.chat(
       this.chatHistory.requestMessages,
       undefined,
@@ -338,7 +338,7 @@ export class HistoryChatEngine implements ChatEngine {
       accumulator += part;
       yield part;
     }
-    this.chatHistory.addMessage({ content: accumulator, role: "user" });
+    await this.chatHistory.addMessage({ content: accumulator, role: "user" });
     return;
   }
 
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index 39c19c390..f4e2c8a02 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -61,8 +61,12 @@ export class SummaryChatHistory implements ChatHistory {
 
   private async summarize() {
     // get all messages after the last summary message (including)
+    // if there's no summary message, get all messages
+    const lastSummaryIndex = this.getLastSummaryIndex();
     const chatHistoryStr = messagesToHistoryStr(
-      this.messages.slice(this.getLastSummaryIndex()),
+      lastSummaryIndex === -1
+        ? this.messages
+        : this.messages.slice(lastSummaryIndex),
     );
 
     const response = await this.llm.complete(
@@ -73,12 +77,10 @@ export class SummaryChatHistory implements ChatHistory {
   }
 
   async addMessage(message: ChatMessage) {
-    const lastSummaryIndex = this.getLastSummaryIndex();
-    // if there are more than or equal `messagesToSummarize` messages since the last summary, call summarize
-    if (
-      lastSummaryIndex !== -1 &&
-      this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize
-    ) {
+    const messagesSinceLastSummary =
+      this.messages.length - this.getLastSummaryIndex() - 1;
+    // if there are too many messages since the last summary, call summarize
+    if (messagesSinceLastSummary >= this.messagesToSummarize) {
       // TODO: define what are better conditions, e.g. depending on the context length of the LLM?
       // for now we just summarize each `messagesToSummarize` messages
       await this.summarize();
@@ -88,14 +90,19 @@ export class SummaryChatHistory implements ChatHistory {
 
   // Find last summary message
   private getLastSummaryIndex() {
-    return this.messages
-      .slice()
-      .reverse()
-      .findIndex((message) => message.role === "memory");
+    return (
+      this.messages.length -
+      1 -
+      this.messages
+        .slice()
+        .reverse()
+        .findIndex((message) => message.role === "memory")
+    );
   }
 
   get requestMessages() {
     const lastSummaryIndex = this.getLastSummaryIndex();
+    if (lastSummaryIndex === -1) return this.messages;
     // get array of all system messages
     const systemMessages = this.messages.filter(
       (message) => message.role === "system",
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 4b76a63f0..497ed2be8 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -1,5 +1,6 @@
 export * from "./callbacks/CallbackManager";
 export * from "./ChatEngine";
+export * from "./ChatHistory";
 export * from "./constants";
 export * from "./Embedding";
 export * from "./GlobalsHelper";
-- 
GitLab