From 809a904bc86e5f0e74225e5b35a0e0fbd8b268f8 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Mon, 9 Oct 2023 11:48:15 +0700 Subject: [PATCH] fix: summarizer issues --- packages/core/src/ChatEngine.ts | 10 +++++----- packages/core/src/ChatHistory.ts | 29 ++++++++++++++++++----------- packages/core/src/index.ts | 1 + 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index 92501bf5f..c76d59794 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -314,11 +314,11 @@ export class HistoryChatEngine implements ChatEngine { ): Promise<R> { //Streaming option if (streaming) { - return this.streamChat(message, chatHistory) as R; + return this.streamChat(message) as R; } - this.chatHistory.addMessage({ content: message, role: "user" }); + await this.chatHistory.addMessage({ content: message, role: "user" }); const response = await this.llm.chat(this.chatHistory.requestMessages); - this.chatHistory.addMessage(response.message); + await this.chatHistory.addMessage(response.message); return new Response(response.message.content) as R; } @@ -326,7 +326,7 @@ export class HistoryChatEngine implements ChatEngine { message: string, chatHistory?: ChatMessage[] | undefined, ): AsyncGenerator<string, void, unknown> { - this.chatHistory.addMessage({ content: message, role: "user" }); + await this.chatHistory.addMessage({ content: message, role: "user" }); const response_stream = await this.llm.chat( this.chatHistory.requestMessages, undefined, @@ -338,7 +338,7 @@ export class HistoryChatEngine implements ChatEngine { accumulator += part; yield part; } - this.chatHistory.addMessage({ content: accumulator, role: "user" }); + await this.chatHistory.addMessage({ content: accumulator, role: "user" }); return; } diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts index 39c19c390..f4e2c8a02 100644 --- a/packages/core/src/ChatHistory.ts +++ b/packages/core/src/ChatHistory.ts @@ -61,8 +61,12 @@ export class SummaryChatHistory implements ChatHistory { private async summarize() { // get all messages after the last summary message (including) + // if there's no summary message, get all messages + const lastSummaryIndex = this.getLastSummaryIndex(); const chatHistoryStr = messagesToHistoryStr( - this.messages.slice(this.getLastSummaryIndex()), + lastSummaryIndex === -1 + ? this.messages + : this.messages.slice(lastSummaryIndex), ); const response = await this.llm.complete( @@ -73,12 +77,10 @@ export class SummaryChatHistory implements ChatHistory { } async addMessage(message: ChatMessage) { - const lastSummaryIndex = this.getLastSummaryIndex(); - // if there are more than or equal `messagesToSummarize` messages since the last summary, call summarize - if ( - lastSummaryIndex !== -1 && - this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize - ) { + const messagesSinceLastSummary = + this.messages.length - this.getLastSummaryIndex() - 1; + // if there are too many messages since the last summary, call summarize + if (messagesSinceLastSummary >= this.messagesToSummarize) { // TODO: define what are better conditions, e.g. depending on the context length of the LLM? // for now we just summarize each `messagesToSummarize` messages await this.summarize(); @@ -88,14 +90,19 @@ export class SummaryChatHistory implements ChatHistory { // Find last summary message private getLastSummaryIndex() { - return this.messages - .slice() - .reverse() - .findIndex((message) => message.role === "memory"); + return ( + this.messages.length - + 1 - + this.messages + .slice() + .reverse() + .findIndex((message) => message.role === "memory") + ); } get requestMessages() { const lastSummaryIndex = this.getLastSummaryIndex(); + if (lastSummaryIndex === -1) return this.messages; // get array of all system messages const systemMessages = this.messages.filter( (message) => message.role === "system", diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 4b76a63f0..497ed2be8 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,5 +1,6 @@ export * from "./callbacks/CallbackManager"; export * from "./ChatEngine"; +export * from "./ChatHistory"; export * from "./constants"; export * from "./Embedding"; export * from "./GlobalsHelper"; -- GitLab