diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index 92501bf5ff9ac86bf9427ac98be9e4dc83fc4460..c76d59794926e00fb6c064175177dbae74b36c90 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -314,11 +314,11 @@ export class HistoryChatEngine implements ChatEngine { ): Promise<R> { //Streaming option if (streaming) { - return this.streamChat(message, chatHistory) as R; + return this.streamChat(message) as R; } - this.chatHistory.addMessage({ content: message, role: "user" }); + await this.chatHistory.addMessage({ content: message, role: "user" }); const response = await this.llm.chat(this.chatHistory.requestMessages); - this.chatHistory.addMessage(response.message); + await this.chatHistory.addMessage(response.message); return new Response(response.message.content) as R; } @@ -326,7 +326,7 @@ export class HistoryChatEngine implements ChatEngine { message: string, chatHistory?: ChatMessage[] | undefined, ): AsyncGenerator<string, void, unknown> { - this.chatHistory.addMessage({ content: message, role: "user" }); + await this.chatHistory.addMessage({ content: message, role: "user" }); const response_stream = await this.llm.chat( this.chatHistory.requestMessages, undefined, @@ -338,7 +338,7 @@ export class HistoryChatEngine implements ChatEngine { accumulator += part; yield part; } - this.chatHistory.addMessage({ content: accumulator, role: "user" }); + await this.chatHistory.addMessage({ content: accumulator, role: "user" }); return; } diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts index 39c19c3905af72c73f32f0574e84424242acf0a1..f4e2c8a02698583c2d4836c5540f0184fb08a8fe 100644 --- a/packages/core/src/ChatHistory.ts +++ b/packages/core/src/ChatHistory.ts @@ -61,8 +61,12 @@ export class SummaryChatHistory implements ChatHistory { private async summarize() { // get all messages after the last summary message (including) + // if there's no summary message, get all messages + const lastSummaryIndex = this.getLastSummaryIndex(); const chatHistoryStr = messagesToHistoryStr( - this.messages.slice(this.getLastSummaryIndex()), + lastSummaryIndex === -1 + ? this.messages + : this.messages.slice(lastSummaryIndex), ); const response = await this.llm.complete( @@ -73,12 +77,10 @@ export class SummaryChatHistory implements ChatHistory { } async addMessage(message: ChatMessage) { - const lastSummaryIndex = this.getLastSummaryIndex(); - // if there are more than or equal `messagesToSummarize` messages since the last summary, call summarize - if ( - lastSummaryIndex !== -1 && - this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize - ) { + const messagesSinceLastSummary = + this.messages.length - this.getLastSummaryIndex() - 1; + // if there are too many messages since the last summary, call summarize + if (messagesSinceLastSummary >= this.messagesToSummarize) { // TODO: define what are better conditions, e.g. depending on the context length of the LLM? // for now we just summarize each `messagesToSummarize` messages await this.summarize(); @@ -88,14 +90,19 @@ export class SummaryChatHistory implements ChatHistory { // Find last summary message private getLastSummaryIndex() { - return this.messages - .slice() - .reverse() - .findIndex((message) => message.role === "memory"); + return ( + this.messages.length - + 1 - + this.messages + .slice() + .reverse() + .findIndex((message) => message.role === "memory") + ); } get requestMessages() { const lastSummaryIndex = this.getLastSummaryIndex(); + if (lastSummaryIndex === -1) return this.messages; // get array of all system messages const systemMessages = this.messages.filter( (message) => message.role === "system", diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 4b76a63f0bb3f75b9b36fd795c93e8ab197a8a0e..497ed2be81b2e9266db8297509ba8af8a166aa88 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,5 +1,6 @@ export * from "./callbacks/CallbackManager"; export * from "./ChatEngine"; +export * from "./ChatHistory"; export * from "./constants"; export * from "./Embedding"; export * from "./GlobalsHelper";