diff --git a/.changeset/cold-gifts-help.md b/.changeset/cold-gifts-help.md
new file mode 100644
index 0000000000000000000000000000000000000000..cd98a3329d0c83e7a96122a9ec2d7e9a9bc82fe6
--- /dev/null
+++ b/.changeset/cold-gifts-help.md
@@ -0,0 +1,6 @@
+---
+"llamaindex": patch
+"@llamaindex/core": patch
+---
+
+Fix context not being sent using ContextChatEngine
diff --git a/packages/core/src/memory/base.ts b/packages/core/src/memory/base.ts
index 0302d7d75f468d66659fad4976afc345b10b21ce..e89e4822dfc8945f12f87a5579492689e4c9d4be 100644
--- a/packages/core/src/memory/base.ts
+++ b/packages/core/src/memory/base.ts
@@ -1,5 +1,5 @@
 import { Settings } from "../global";
-import type { ChatMessage, MessageContent } from "../llms";
+import type { ChatMessage } from "../llms";
 import { type BaseChatStore, SimpleChatStore } from "../storage/chat-store";
 import { extractText } from "../utils";
 
@@ -12,15 +12,36 @@ export const DEFAULT_CHAT_STORE_KEY = "chat_history";
 export abstract class BaseMemory<
   AdditionalMessageOptions extends object = object,
 > {
+  /**
+   * Retrieves messages from the memory, optionally including transient messages.
+   * Compared to getAllMessages, this method a) allows for transient messages to be included in the retrieval and b) may return a subset of the total messages by applying a token limit.
+   * @param transientMessages Optional array of temporary messages to be included in the retrieval.
+   * These messages are not stored in the memory but are considered for the current interaction.
+   * @returns An array of chat messages, either synchronously or as a Promise.
+   */
   abstract getMessages(
-    input?: MessageContent | undefined,
+    transientMessages?: ChatMessage<AdditionalMessageOptions>[] | undefined,
   ):
     | ChatMessage<AdditionalMessageOptions>[]
     | Promise<ChatMessage<AdditionalMessageOptions>[]>;
+
+  /**
+   * Retrieves all messages stored in the memory.
+   * @returns An array of all chat messages, either synchronously or as a Promise.
+   */
   abstract getAllMessages():
     | ChatMessage<AdditionalMessageOptions>[]
     | Promise<ChatMessage<AdditionalMessageOptions>[]>;
+
+  /**
+   * Adds a new message to the memory.
+   * @param messages The chat message to be added to the memory.
+   */
   abstract put(messages: ChatMessage<AdditionalMessageOptions>): void;
+
+  /**
+   * Clears all messages from the memory.
+   */
   abstract reset(): void;
 
   protected _tokenCountForMessages(messages: ChatMessage[]): number {
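The signature change above replaces the loosely typed `input?: MessageContent` parameter with an optional list of transient `ChatMessage`s: messages that should be part of the current request but must not be persisted in the memory. A minimal sketch of the intended call pattern, using `ChatMemoryBuffer` as the concrete implementation (the messages are illustrative, and an LLM is assumed to be configured on `Settings` so the buffer can derive its default token limit):

```ts
import type { ChatMessage } from "@llamaindex/core/llms";
import { ChatMemoryBuffer } from "@llamaindex/core/memory";

function buildRequestMessages() {
  // Persisted history lives in the memory implementation.
  const memory = new ChatMemoryBuffer({
    chatHistory: [
      { role: "user", content: "What does the retriever return?" },
      { role: "assistant", content: "A list of scored nodes." },
    ],
  });

  // Transient messages (e.g. a retrieved-context system prompt) are included
  // in the returned request but are never stored via put().
  const transient: ChatMessage[] = [
    { role: "system", content: "Context information is below. ..." },
  ];

  // ChatMemoryBuffer returns the messages synchronously; other BaseMemory
  // implementations may return a Promise instead.
  return memory.getMessages(transient);
}
```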
diff --git a/packages/core/src/memory/chat-memory-buffer.ts b/packages/core/src/memory/chat-memory-buffer.ts
index c84a837c77d1c647300cc70907029cbe333a5102..9fd0673681e5d211c5129fcd8476dd292d766c19 100644
--- a/packages/core/src/memory/chat-memory-buffer.ts
+++ b/packages/core/src/memory/chat-memory-buffer.ts
@@ -1,5 +1,5 @@
 import { Settings } from "../global";
-import type { ChatMessage, LLM, MessageContent } from "../llms";
+import type { ChatMessage, LLM } from "../llms";
 import { type BaseChatStore } from "../storage/chat-store";
 
 import { BaseChatStoreMemory, DEFAULT_TOKEN_LIMIT_RATIO } from "./base";
@@ -34,7 +34,7 @@ export class ChatMemoryBuffer<
   }
 
   getMessages(
-    input?: MessageContent | undefined,
+    transientMessages?: ChatMessage<AdditionalMessageOptions>[] | undefined,
     initialTokenCount: number = 0,
   ) {
     const messages = this.getAllMessages();
@@ -43,16 +43,22 @@ export class ChatMemoryBuffer<
       throw new Error("Initial token count exceeds token limit");
     }
 
-    let messageCount = messages.length;
-    let currentMessages = messages.slice(-messageCount);
-    let tokenCount = this._tokenCountForMessages(messages) + initialTokenCount;
+    // Add input messages as transient messages
+    const messagesWithInput = transientMessages
+      ? [...transientMessages, ...messages]
+      : messages;
+
+    let messageCount = messagesWithInput.length;
+    let currentMessages = messagesWithInput.slice(-messageCount);
+    let tokenCount =
+      this._tokenCountForMessages(messagesWithInput) + initialTokenCount;
 
     while (tokenCount > this.tokenLimit && messageCount > 1) {
       messageCount -= 1;
-      if (messages.at(-messageCount)!.role === "assistant") {
+      if (messagesWithInput.at(-messageCount)!.role === "assistant") {
         messageCount -= 1;
       }
-      currentMessages = messages.slice(-messageCount);
+      currentMessages = messagesWithInput.slice(-messageCount);
       tokenCount =
         this._tokenCountForMessages(currentMessages) + initialTokenCount;
     }
@@ -60,6 +66,6 @@ export class ChatMemoryBuffer<
     if (tokenCount > this.tokenLimit && messageCount <= 0) {
       return [];
     }
-    return messages.slice(-messageCount);
+    return messagesWithInput.slice(-messageCount);
   }
 }
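In the buffer implementation, transient messages are prepended to the stored history and the combined list is trimmed from the front until it fits `tokenLimit`, so the most recent messages survive. A rough sketch of the observable behavior (values mirror the new test added further below; `Settings.llm` is assumed to be configured so the constructor can resolve its defaults):

```ts
import type { ChatMessage } from "@llamaindex/core/llms";
import { ChatMemoryBuffer } from "@llamaindex/core/memory";

const stored: ChatMessage[] = [
  { role: "user", content: "Hello" },
  { role: "assistant", content: "Hi there!" },
];
const buffer = new ChatMemoryBuffer({ tokenLimit: 50, chatHistory: stored });

// Transient messages come first, followed by the stored history. If the
// combined list exceeded the token limit, entries would be dropped from the
// front until the request fits.
const request = buffer.getMessages([{ role: "user", content: "New message" }]);
// request => [user "New message", user "Hello", assistant "Hi there!"]
```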
diff --git a/packages/core/src/memory/summary-memory.ts b/packages/core/src/memory/summary-memory.ts
index 87b5c11108e2c7d40940606467cc65a7e92d1751..e750aa3671d3c1e601275c53a292b3be8bd84c2d 100644
--- a/packages/core/src/memory/summary-memory.ts
+++ b/packages/core/src/memory/summary-memory.ts
@@ -114,18 +114,22 @@ export class ChatSummaryMemoryBuffer extends BaseMemory {
     }
   }
 
-  private calcCurrentRequestMessages() {
-    // TODO: check order: currently, we're sending:
+  private calcCurrentRequestMessages(transientMessages?: ChatMessage[]) {
+    // currently, we're sending:
     // system messages first, then transient messages and then the messages that describe the conversation so far
-    return [...this.systemMessages, ...this.calcConversationMessages(true)];
+    return [
+      ...this.systemMessages,
+      ...(transientMessages ? transientMessages : []),
+      ...this.calcConversationMessages(true),
+    ];
   }
 
   reset() {
     this.messages = [];
   }
 
-  async getMessages(): Promise<ChatMessage[]> {
-    const requestMessages = this.calcCurrentRequestMessages();
+  async getMessages(transientMessages?: ChatMessage[]): Promise<ChatMessage[]> {
+    const requestMessages = this.calcCurrentRequestMessages(transientMessages);
 
     // get tokens of current request messages and the transient messages
     const tokens = requestMessages.reduce(
@@ -149,7 +153,7 @@ export class ChatSummaryMemoryBuffer extends BaseMemory {
       // TODO: we still might have too many tokens
       // e.g. too large system messages or transient messages
      // how should we deal with that?
-      return this.calcCurrentRequestMessages();
+      return this.calcCurrentRequestMessages(transientMessages);
     }
     return requestMessages;
   }
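For the summary memory, the request is now assembled in a fixed order: system messages first, then the transient messages for the current request, then the (possibly summarized) conversation. The following standalone sketch mirrors that ordering without depending on the class internals (`buildSummaryRequest` is an illustrative helper, not part of the patch):

```ts
import type { ChatMessage } from "@llamaindex/core/llms";

// Mirrors calcCurrentRequestMessages: system prompts, then the per-request
// transient messages, then the conversation (or its summary) so far.
function buildSummaryRequest(
  systemMessages: ChatMessage[],
  conversation: ChatMessage[],
  transientMessages?: ChatMessage[],
): ChatMessage[] {
  return [...systemMessages, ...(transientMessages ?? []), ...conversation];
}
```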
diff --git a/packages/core/tests/memory/chat-memory-buffer.test.ts b/packages/core/tests/memory/chat-memory-buffer.test.ts
new file mode 100644
index 0000000000000000000000000000000000000000..0118d24ad898be47a94c725f424e2da3b1cc1949
--- /dev/null
+++ b/packages/core/tests/memory/chat-memory-buffer.test.ts
@@ -0,0 +1,74 @@
+import { Settings } from "@llamaindex/core/global";
+import type { ChatMessage } from "@llamaindex/core/llms";
+import { ChatMemoryBuffer } from "@llamaindex/core/memory";
+import { beforeEach, describe, expect, test } from "vitest";
+
+describe("ChatMemoryBuffer", () => {
+  beforeEach(() => {
+    // Mock the Settings.llm
+    (Settings.llm as any) = {
+      metadata: {
+        contextWindow: 1000,
+      },
+    };
+  });
+
+  test("constructor initializes with custom token limit", () => {
+    const buffer = new ChatMemoryBuffer({ tokenLimit: 500 });
+    expect(buffer.tokenLimit).toBe(500);
+  });
+
+  test("getMessages returns all messages when under token limit", () => {
+    const messages: ChatMessage[] = [
+      { role: "user", content: "Hello" },
+      { role: "assistant", content: "Hi there!" },
+      { role: "user", content: "How are you?" },
+    ];
+    const buffer = new ChatMemoryBuffer({
+      tokenLimit: 1000,
+      chatHistory: messages,
+    });
+
+    const result = buffer.getMessages();
+    expect(result).toEqual(messages);
+  });
+
+  test("getMessages truncates messages when over token limit", () => {
+    const messages: ChatMessage[] = [
+      { role: "user", content: "This is a long message" },
+      { role: "assistant", content: "This is also a long reply" },
+      { role: "user", content: "Short" },
+    ];
+    const buffer = new ChatMemoryBuffer({
+      tokenLimit: 5, // limit to only allow the last message
+      chatHistory: messages,
+    });
+
+    const result = buffer.getMessages();
+    expect(result).toEqual([{ role: "user", content: "Short" }]);
+  });
+
+  test("getMessages handles input messages", () => {
+    const storedMessages: ChatMessage[] = [
+      { role: "user", content: "Hello" },
+      { role: "assistant", content: "Hi there!" },
+    ];
+    const buffer = new ChatMemoryBuffer({
+      tokenLimit: 50,
+      chatHistory: storedMessages,
+    });
+
+    const inputMessages: ChatMessage[] = [
+      { role: "user", content: "New message" },
+    ];
+    const result = buffer.getMessages(inputMessages);
+    expect(result).toEqual([...inputMessages, ...storedMessages]);
+  });
+
+  test("getMessages throws error when initial token count exceeds limit", () => {
+    const buffer = new ChatMemoryBuffer({ tokenLimit: 10 });
+    expect(() => buffer.getMessages(undefined, 20)).toThrow(
+      "Initial token count exceeds token limit",
+    );
+  });
+});
diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts
index fb2d20f6bc7402eedccf1a5a1c963f0cf5ac404a..2715b4ffebf357876ac3a421cf23b91e6fcfbed2 100644
--- a/packages/llamaindex/src/agent/base.ts
+++ b/packages/llamaindex/src/agent/base.ts
@@ -356,9 +356,8 @@ export abstract class AgentRunner<
     let chatHistory: ChatMessage<AdditionalMessageOptions>[] = [];
 
     if (params.chatHistory instanceof BaseMemory) {
-      chatHistory = (await params.chatHistory.getMessages(
-        params.message,
-      )) as ChatMessage<AdditionalMessageOptions>[];
+      chatHistory =
+        (await params.chatHistory.getMessages()) as ChatMessage<AdditionalMessageOptions>[];
     } else {
       chatHistory =
         params.chatHistory as ChatMessage<AdditionalMessageOptions>[];
diff --git a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
index 1ff9fdfe4fb1bf73dc3edf78b6611233fbf8767e..2be3eb548a5595fe61772628dfffc94612a959a1 100644
--- a/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/CondenseQuestionChatEngine.ts
@@ -78,9 +78,7 @@ export class CondenseQuestionChatEngine
   }
 
   private async condenseQuestion(chatHistory: BaseMemory, question: string) {
-    const chatHistoryStr = messagesToHistory(
-      await chatHistory.getMessages(question),
-    );
+    const chatHistoryStr = messagesToHistory(await chatHistory.getMessages());
 
     return this.llm.complete({
       prompt: this.condenseMessagePrompt.format({
@@ -103,7 +101,7 @@ export class CondenseQuestionChatEngine
       ? new ChatMemoryBuffer({
           chatHistory:
             params.chatHistory instanceof BaseMemory
-              ? await params.chatHistory.getMessages(message)
+              ? await params.chatHistory.getMessages()
              : params.chatHistory,
         })
       : this.chatHistory;
diff --git a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts
index 3ea12161ab5d056cb8fdca77e88e8640ab55e44c..23d56a00a324fb11d3853d500e8701a9b81b5229 100644
--- a/packages/llamaindex/src/engines/chat/ContextChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/ContextChatEngine.ts
@@ -92,7 +92,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
       ? new ChatMemoryBuffer({
           chatHistory:
             params.chatHistory instanceof BaseMemory
-              ? await params.chatHistory.getMessages(message)
+              ? await params.chatHistory.getMessages()
              : params.chatHistory,
         })
       : this.chatHistory;
@@ -139,7 +139,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
     const textOnly = extractText(message);
     const context = await this.contextGenerator.generate(textOnly);
     const systemMessage = this.prependSystemPrompt(context.message);
-    const messages = await chatHistory.getMessages(systemMessage.content);
+    const messages = await chatHistory.getMessages([systemMessage]);
 
     return { nodes: context.nodes, messages };
   }
diff --git a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts
index a123881b8f6e2af35499f1d8ba4c366e641753d9..5fba0250144edac7e32a1a5daa9baf8efde7049b 100644
--- a/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts
+++ b/packages/llamaindex/src/engines/chat/SimpleChatEngine.ts
@@ -40,7 +40,7 @@ export class SimpleChatEngine implements ChatEngine {
       ? new ChatMemoryBuffer({
           chatHistory:
             params.chatHistory instanceof BaseMemory
-              ? await params.chatHistory.getMessages(message)
+              ? await params.chatHistory.getMessages()
              : params.chatHistory,
         })
       : this.chatHistory;
@@ -48,7 +48,7 @@ export class SimpleChatEngine implements ChatEngine {
 
     if (stream) {
       const stream = await this.llm.chat({
-        messages: await chatHistory.getMessages(params.message),
+        messages: await chatHistory.getMessages(),
         stream: true,
       });
       return streamConverter(
@@ -66,7 +66,7 @@ export class SimpleChatEngine implements ChatEngine {
 
     const response = await this.llm.chat({
       stream: false,
-      messages: await chatHistory.getMessages(params.message),
+      messages: await chatHistory.getMessages(),
     });
     chatHistory.put(response.message);
     return EngineResponse.fromChatResponse(response);
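Taken together, the engine changes stop passing the user's message (or the generated context) to `getMessages` as `MessageContent`; the engines now call `getMessages()` for the persisted history, and `ContextChatEngine` passes the generated context system message as a transient message so it reaches the LLM without being written to the chat store. A rough end-to-end usage sketch (document text and the question are illustrative; an embedding model and LLM are assumed to be configured globally):

```ts
import { ContextChatEngine, Document, VectorStoreIndex } from "llamaindex";

async function main() {
  const index = await VectorStoreIndex.fromDocuments([
    new Document({ text: "LlamaIndex.TS ships several chat engines." }),
  ]);

  const chatEngine = new ContextChatEngine({ retriever: index.asRetriever() });

  // The retrieved context is injected as a transient system message via
  // chatHistory.getMessages([systemMessage]) and therefore reaches the LLM,
  // but it is not persisted in the chat history.
  const response = await chatEngine.chat({
    message: "Which chat engines does LlamaIndex.TS ship?",
  });
  console.log(response.message.content);
}

void main();
```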