diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index bc299356e36263a247d047b1b4b807e843162ee1..92501bf5ff9ac86bf9427ac98be9e4dc83fc4460 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -317,7 +317,7 @@ export class HistoryChatEngine implements ChatEngine {
       return this.streamChat(message, chatHistory) as R;
     }
     this.chatHistory.addMessage({ content: message, role: "user" });
-    const response = await this.llm.chat(this.chatHistory.messages);
+    const response = await this.llm.chat(this.chatHistory.requestMessages);
     this.chatHistory.addMessage(response.message);
     return new Response(response.message.content) as R;
   }
@@ -328,7 +328,7 @@ export class HistoryChatEngine implements ChatEngine {
   ): AsyncGenerator<string, void, unknown> {
     this.chatHistory.addMessage({ content: message, role: "user" });
     const response_stream = await this.llm.chat(
-      this.chatHistory.messages,
+      this.chatHistory.requestMessages,
       undefined,
       true,
     );
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index bc2d87b2ccb825b7489f597574c58703ca950bd7..39c19c3905af72c73f32f0574e84424242acf0a1 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -16,6 +16,11 @@ export interface ChatHistory {
    */
  addMessage(message: ChatMessage): Promise<void>;

+  /**
+   * Returns the messages that should be used as input to the LLM.
+   */
+  requestMessages: ChatMessage[];
+
  /**
   * Resets the chat history so that it's empty.
   */
@@ -28,45 +33,86 @@ export class SimpleChatHistory implements ChatHistory {
  constructor(init?: Partial<SimpleChatHistory>) {
    this.messages = init?.messages ?? [];
  }
-
  async addMessage(message: ChatMessage) {
    this.messages.push(message);
  }

+  get requestMessages() {
+    return this.messages;
+  }
+
  reset() {
    this.messages = [];
  }
}

export class SummaryChatHistory implements ChatHistory {
+  messagesToSummarize: number;
  messages: ChatMessage[];
  summaryPrompt: SummaryPrompt;
  llm: LLM;

  constructor(init?: Partial<SummaryChatHistory>) {
+    this.messagesToSummarize = init?.messagesToSummarize ?? 5;
    this.messages = init?.messages ?? [];
    this.summaryPrompt = init?.summaryPrompt ?? defaultSummaryPrompt;
    this.llm = init?.llm ?? new OpenAI();
  }

  private async summarize() {
-    const chatHistoryStr = messagesToHistoryStr(this.messages);
+    const lastSummaryIndex = this.getLastSummaryIndex();
+    // get all messages since the last summary message (including the summary
+    // itself, so the new summary is cumulative); all messages if there is none
+    const chatHistoryStr = messagesToHistoryStr(
+      this.messages.slice(lastSummaryIndex === -1 ? 0 : lastSummaryIndex),
+    );

    const response = await this.llm.complete(
      this.summaryPrompt({ context: chatHistoryStr }),
    );

-    this.messages = [{ content: response.message.content, role: "system" }];
+    this.messages.push({ content: response.message.content, role: "memory" });
  }

  async addMessage(message: ChatMessage) {
-    // TODO: check if summarization is necessary
-    // TBD what are good conditions, e.g. depending on the context length of the LLM?
-    // for now we just have a dummy implementation at always summarizes the messages
-    await this.summarize();
+    const lastSummaryIndex = this.getLastSummaryIndex();
+    // if at least `messagesToSummarize` messages were added since the last
+    // summary (or since the start, if there is no summary yet), summarize
+    if (
+      this.messages.length - lastSummaryIndex - 1 >= this.messagesToSummarize
+    ) {
+      // TODO: define better conditions, e.g. depending on the context length of the LLM?
+      // for now we just summarize every `messagesToSummarize` messages
+      await this.summarize();
+    }
    this.messages.push(message);
  }

+  // Find the index of the last summary message, or -1 if there is none
+  private getLastSummaryIndex() {
+    const reversedIndex = this.messages
+      .slice()
+      .reverse()
+      .findIndex((message) => message.role === "memory");
+    return reversedIndex === -1
+      ? -1
+      : this.messages.length - 1 - reversedIndex;
+  }
+
+  get requestMessages() {
+    const lastSummaryIndex = this.getLastSummaryIndex();
+    // if there is no summary yet, just send the messages as they are
+    if (lastSummaryIndex === -1) return this.messages;
+    // get array of all system messages
+    const systemMessages = this.messages.filter(
+      (message) => message.role === "system",
+    );
+    // convert the summary message so it can be sent to the LLM
+    const summaryMessage: ChatMessage = {
+      content: `This is a summary of the conversation so far: ${this.messages[lastSummaryIndex].content}`,
+      role: "system",
+    };
+    // return the system messages, the last summary, and all messages after
+    // the last summary message
+    return [
+      ...systemMessages,
+      summaryMessage,
+      ...this.messages.slice(lastSummaryIndex + 1),
+    ];
+  }
+
  reset() {
    this.messages = [];
  }
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 8c997f5f076a589e18312af97f991a4a00872266..597dbad399e6094eeec3628becb9fa9d2c9b4acb 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -31,7 +31,8 @@ export type MessageType =
  | "assistant"
  | "system"
  | "generic"
-  | "function";
+  | "function"
+  | "memory";

 export interface ChatMessage {
   content: string;
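
Usage sketch (not part of the diff): wiring `SummaryChatHistory` into `HistoryChatEngine`. This assumes both classes accept a `Partial<...>` init object, as the constructors above suggest, and that they are exported from the `llamaindex` entry point; `messagesToSummarize: 10` is a hypothetical value.

import { HistoryChatEngine, OpenAI, SummaryChatHistory } from "llamaindex";

const llm = new OpenAI();
const chatEngine = new HistoryChatEngine({
  llm,
  chatHistory: new SummaryChatHistory({
    llm,
    // hypothetical threshold: summarize once 10 messages accumulate
    messagesToSummarize: 10,
  }),
});

// chat() stores the user message and the response in the history; once the
// threshold is reached, addMessage() inserts a "memory" summary message and
// requestMessages keeps the LLM input short
const response = await chatEngine.chat("What did we talk about so far?");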
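
To make the shape of `requestMessages` concrete, a sketch with made-up messages that seeds a history already containing a "memory" summary:

import { SummaryChatHistory } from "llamaindex";

const history = new SummaryChatHistory({
  messages: [
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", content: "Hi!" },
    { role: "assistant", content: "Hello! How can I help?" },
    { role: "memory", content: "The user greeted the assistant." },
    { role: "user", content: "What is the weather like?" },
  ],
});

// yields the system message, the summary converted to a system message
// ("This is a summary of the conversation so far: ..."), and every message
// after the summary (here, the weather question)
console.log(history.requestMessages);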