diff --git a/examples/toolsStream.ts b/examples/toolsStream.ts
index f52d9049ac978527bfcddf34ab147018a228e935..6aced425b094fe525828d8c60ca6f8b573597673 100644
--- a/examples/toolsStream.ts
+++ b/examples/toolsStream.ts
@@ -1,4 +1,4 @@
-import { ChatResponseChunk, OpenAI } from "llamaindex";
+import { OpenAI } from "llamaindex";

 async function main() {
   const llm = new OpenAI({ model: "gpt-4-turbo" });
@@ -34,11 +34,10 @@ async function main() {
   };

   const stream = await llm.chat({ ...args, stream: true });
-  let chunk: ChatResponseChunk | null = null;
-  for await (chunk of stream) {
+  for await (const chunk of stream) {
     process.stdout.write(chunk.delta);
+    console.log(chunk.options?.toolCalls?.[0]);
   }
-  console.log(chunk?.additionalKwargs?.toolCalls[0]);
 }

 (async function () {
diff --git a/packages/core/e2e/fixtures/llm/open_ai.ts b/packages/core/e2e/fixtures/llm/open_ai.ts
index 201af668c701c14b5f085164e592dbe7eda6f3d3..f1c84d2e2e9a92c57023ad3967bba9720348dd08 100644
--- a/packages/core/e2e/fixtures/llm/open_ai.ts
+++ b/packages/core/e2e/fixtures/llm/open_ai.ts
@@ -50,6 +50,9 @@ export class OpenAI implements LLM {
       };
     }
     return {
+      get raw(): never {
+        throw new Error("not implemented");
+      },
       message: {
         content: faker.lorem.paragraph(),
         role: "assistant",
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index d727b248291c0181f253a1098a1d1ef5f09e7492..9c82ca9d412fa079d4a35a14ac58be842ebd7684 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -58,9 +58,9 @@ async function callFunction(

   return [
     {
-      content: String(output),
+      content: `${output}`,
       role: "tool",
-      additionalKwargs: {
+      options: {
         name,
         tool_call_id: id_,
       },
@@ -138,7 +138,8 @@ export class OpenAIAgentWorker
       return null;
     }

-    return chatHistory[chatHistory.length - 1].additionalKwargs?.toolCalls;
+    // fixme
+    return chatHistory[chatHistory.length - 1].options?.toolCalls as any;
   }

   private _getLlmChatParams(
@@ -184,7 +185,7 @@ export class OpenAIAgentWorker
     const iterator = stream[Symbol.asyncIterator]();
     let { value } = await iterator.next();
     let content = value.delta;
-    const hasToolCalls = value.additionalKwargs?.toolCalls.length > 0;
+    const hasToolCalls = value.options?.toolCalls.length > 0;

     if (hasToolCalls) {
       // consume stream until we have all the tool calls and return a non-streamed response
@@ -194,7 +195,7 @@ export class OpenAIAgentWorker
     return this._processMessage(task, {
       content,
       role: "assistant",
-      additionalKwargs: value.additionalKwargs,
+      options: value.options,
     });
   }

diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 0af646d2e71bdf707d3c095d5dd0629ef5da7479..7c8ff98d7238879e184964cb7652f92595278227 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -273,6 +273,7 @@ If a question does not make any sense, or is not factually coherent, explain why
       replicateOptions,
     );
     return {
+      raw: response,
       message: {
         content: (response as Array<string>).join("").trimStart(),
         //^ We need to do this because Replicate returns a list of strings (for streaming functionality which is not exposed by the run function)
@@ -330,7 +331,7 @@ export class Portkey extends BaseLLM {

     const content = response.choices[0].message?.content ?? "";
""; const role = response.choices[0].message?.role || "assistant"; - return { message: { content, role: role as MessageType } }; + return { raw: response, message: { content, role: role as MessageType } }; } } diff --git a/packages/core/src/llm/anthropic.ts b/packages/core/src/llm/anthropic.ts index 3fffecee131427798d84001997870c0a62913e2e..9c471eebe24170e51a9aa7d6c593500cc4db58f5 100644 --- a/packages/core/src/llm/anthropic.ts +++ b/packages/core/src/llm/anthropic.ts @@ -185,6 +185,7 @@ export class Anthropic extends BaseLLM { }); return { + raw: response, message: { content: response.content[0].text, role: "assistant" }, }; } diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts index 8e1a4ec382457ff294f08003a2b89bdff4010771..8c245f4373296c96c95b17b55227842fd50c3361 100644 --- a/packages/core/src/llm/base.ts +++ b/packages/core/src/llm/base.ts @@ -16,6 +16,10 @@ export abstract class BaseLLM< string, unknown >, + AdditionalMessageOptions extends Record<string, unknown> = Record< + string, + unknown + >, > implements LLM<AdditionalChatOptions> { abstract metadata: LLMMetadata; @@ -55,5 +59,5 @@ export abstract class BaseLLM< ): Promise<AsyncIterable<ChatResponseChunk>>; abstract chat( params: LLMChatParamsNonStreaming<AdditionalChatOptions>, - ): Promise<ChatResponse>; + ): Promise<ChatResponse<AdditionalMessageOptions>>; } diff --git a/packages/core/src/llm/mistral.ts b/packages/core/src/llm/mistral.ts index ffbf19263618e787a907682891e92e292b0cc5de..28a11ab41c202859b9029333285ca04e87f47975 100644 --- a/packages/core/src/llm/mistral.ts +++ b/packages/core/src/llm/mistral.ts @@ -106,6 +106,7 @@ export class MistralAI extends BaseLLM { const response = await client.chat(this.buildParams(messages)); const message = response.choices[0].message; return { + raw: response, message, }; } diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts index 00426c2215e5d537db802d8c1b807df74eae5267..0c3b4fafbe19745e17c9e5deb375cbe4c6a3cd3e 100644 --- a/packages/core/src/llm/open_ai.ts +++ b/packages/core/src/llm/open_ai.ts @@ -7,7 +7,15 @@ import type { } from "openai"; import { OpenAI as OrigOpenAI } from "openai"; -import type { ChatCompletionTool } from "openai/resources/chat/completions"; +import type { + ChatCompletionAssistantMessageParam, + ChatCompletionFunctionMessageParam, + ChatCompletionMessageToolCall, + ChatCompletionRole, + ChatCompletionSystemMessageParam, + ChatCompletionTool, + ChatCompletionUserMessageParam, +} from "openai/resources/chat/completions"; import type { ChatCompletionMessageParam } from "openai/resources/index.js"; import { Tokenizers } from "../GlobalsHelper.js"; import { wrapEventCaller } from "../internal/context/EventCaller.js"; @@ -32,7 +40,7 @@ import type { MessageToolCall, MessageType, } from "./types.js"; -import { wrapLLMEvent } from "./utils.js"; +import { extractText, wrapLLMEvent } from "./utils.js"; export class AzureOpenAI extends OrigOpenAI { protected override authHeaders() { @@ -135,6 +143,15 @@ export function isFunctionCallingModel(llm: LLM): llm is OpenAI { return isChatModel && !isOld; } +export type OpenAIAdditionalMetadata = { + isFunctionCallingModel: boolean; +}; + +export type OpenAIAdditionalMessageOptions = { + functionName?: string; + toolCalls?: ChatCompletionMessageToolCall[]; +}; + export type OpenAIAdditionalChatOptions = Omit< Partial<OpenAILLM.Chat.ChatCompletionCreateParams>, | "max_tokens" @@ -147,11 +164,10 @@ export type OpenAIAdditionalChatOptions = Omit< | "toolChoice" >; -export type 
-  isFunctionCallingModel: boolean;
-};
-
-export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
+export class OpenAI extends BaseLLM<
+  OpenAIAdditionalChatOptions,
+  OpenAIAdditionalMessageOptions
+> {
   // Per completion OpenAI params
   model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string;
   temperature: number;
@@ -238,9 +254,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
     };
   }

-  mapMessageType(
-    messageType: MessageType,
-  ): "user" | "assistant" | "system" | "function" | "tool" {
+  static toOpenAIRole(messageType: MessageType): ChatCompletionRole {
     switch (messageType) {
       case "user":
         return "user";
@@ -257,43 +271,76 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
     }
   }

-  toOpenAIMessage(messages: ChatMessage[]) {
+  static toOpenAIMessage(
+    messages: ChatMessage<OpenAIAdditionalMessageOptions>[],
+  ): ChatCompletionMessageParam[] {
     return messages.map((message) => {
-      const additionalKwargs = message.additionalKwargs ?? {};
-
-      if (message.additionalKwargs?.toolCalls) {
-        additionalKwargs.tool_calls = message.additionalKwargs.toolCalls;
-        delete additionalKwargs.toolCalls;
+      const options: OpenAIAdditionalMessageOptions = message.options ?? {};
+      if (message.role === "user") {
+        return {
+          role: "user",
+          content: message.content,
+        } satisfies ChatCompletionUserMessageParam;
+      }
+      if (typeof message.content !== "string") {
+        console.warn("Message content is not a string");
+      }
+      if (message.role === "function") {
+        if (!options.functionName) {
+          console.warn("Function message does not have a name");
+        }
+        return {
+          role: "function",
+          name: options.functionName ?? "UNKNOWN",
+          content: extractText(message.content),
+          // todo: remove this since this is deprecated in the OpenAI API
+        } satisfies ChatCompletionFunctionMessageParam;
+      }
+      if (message.role === "assistant") {
+        return {
+          role: "assistant",
+          content: extractText(message.content),
+          tool_calls: options.toolCalls,
+        } satisfies ChatCompletionAssistantMessageParam;
       }

-      return {
-        role: this.mapMessageType(message.role),
-        content: message.content,
-        ...additionalKwargs,
+      const response:
+        | ChatCompletionSystemMessageParam
+        | ChatCompletionUserMessageParam
+        | ChatCompletionMessageToolCall = {
+        // fixme(alex): type assertion
+        role: OpenAI.toOpenAIRole(message.role) as never,
+        // fixme: should not extract text, but assert content is string
+        content: extractText(message.content),
+        ...options,
       };
+      return response;
     });
   }

   chat(
     params: LLMChatParamsStreaming<OpenAIAdditionalChatOptions>,
-  ): Promise<AsyncIterable<ChatResponseChunk>>;
+  ): Promise<AsyncIterable<ChatResponseChunk<OpenAIAdditionalMessageOptions>>>;
   chat(
     params: LLMChatParamsNonStreaming<OpenAIAdditionalChatOptions>,
-  ): Promise<ChatResponse>;
+  ): Promise<ChatResponse<OpenAIAdditionalMessageOptions>>;
   @wrapEventCaller
   @wrapLLMEvent
   async chat(
     params:
       | LLMChatParamsNonStreaming<OpenAIAdditionalChatOptions>
       | LLMChatParamsStreaming<OpenAIAdditionalChatOptions>,
-  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
+  ): Promise<
+    | ChatResponse<OpenAIAdditionalMessageOptions>
+    | AsyncIterable<ChatResponseChunk<OpenAIAdditionalMessageOptions>>
+  > {
     const { messages, stream, tools, additionalChatOptions } = params;
     const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = {
       model: this.model,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
       tools: tools?.map(OpenAI.toTool),
-      messages: this.toOpenAIMessage(messages) as ChatCompletionMessageParam[],
+      messages: OpenAI.toOpenAIMessage(messages),
       top_p: this.topP,
       ...Object.assign({}, this.additionalChatOptions, additionalChatOptions),
     };
@@ -311,17 +358,18 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {

       const content = response.choices[0].message?.content ?? "";

-      const kwargsOutput: Record<string, any> = {};
+      const options: OpenAIAdditionalMessageOptions = {};

       if (response.choices[0].message?.tool_calls) {
-        kwargsOutput.toolCalls = response.choices[0].message.tool_calls;
+        options.toolCalls = response.choices[0].message.tool_calls;
       }

       return {
+        raw: response,
         message: {
           content,
           role: response.choices[0].message.role,
-          additionalKwargs: kwargsOutput,
+          options,
         },
       };
     }
@@ -329,7 +377,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
   @wrapEventCaller
   protected async *streamChat(
     baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams,
-  ): AsyncIterable<ChatResponseChunk> {
+  ): AsyncIterable<ChatResponseChunk<OpenAIAdditionalMessageOptions>> {
     const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> =
       await this.session.openai.chat.completions.create({
         ...baseRequestParams,
@@ -355,8 +403,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {

       yield {
         // add tool calls to final chunk
-        additionalKwargs:
-          toolCalls.length > 0 ? { toolCalls: toolCalls } : undefined,
+        options: toolCalls.length > 0 ? { toolCalls: toolCalls } : {},
         delta: choice.delta.content ?? "",
       };
     }
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index db039f58edcecb1c959a3333fd8950afbe1c7553..1231885b1f350b41077e3396a09931b71822dea1 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -27,13 +27,22 @@ export type LLMEndEvent = LLMBaseEvent<
  * @internal
  */
 export interface LLMChat<
-  ExtraParams extends Record<string, unknown> = Record<string, unknown>,
+  AdditionalChatOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
 > {
   chat(
     params:
-      | LLMChatParamsStreaming<ExtraParams>
-      | LLMChatParamsNonStreaming<ExtraParams>,
-  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>>;
+      | LLMChatParamsStreaming<AdditionalChatOptions>
+      | LLMChatParamsNonStreaming<AdditionalChatOptions>,
+  ): Promise<
+    ChatResponse<AdditionalMessageOptions> | AsyncIterable<ChatResponseChunk>
+  >;
 }

 /**
@@ -44,6 +53,10 @@ export interface LLM<
     string,
     unknown
   >,
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
 > extends LLMChat<AdditionalChatOptions> {
   metadata: LLMMetadata;
   /**
@@ -54,7 +67,7 @@ export interface LLM<
   ): Promise<AsyncIterable<ChatResponseChunk>>;
   chat(
     params: LLMChatParamsNonStreaming<AdditionalChatOptions>,
-  ): Promise<ChatResponse>;
+  ): Promise<ChatResponse<AdditionalMessageOptions>>;

   /**
    * Get a prompt completion from the LLM
@@ -67,31 +80,71 @@ export interface LLM<
   ): Promise<CompletionResponse>;
 }

+// todo: remove "generic", "function", "memory";
 export type MessageType =
   | "user"
   | "assistant"
   | "system"
+  /**
+   * @deprecated
+   */
   | "generic"
+  /**
+   * @deprecated
+   */
   | "function"
+  /**
+   * @deprecated
+   */
   | "memory"
   | "tool";

-export interface ChatMessage {
-  content: MessageContent;
-  role: MessageType;
-  additionalKwargs?: Record<string, any>;
-}
-
-export interface ChatResponse {
-  message: ChatMessage;
-  raw?: Record<string, any>;
-  additionalKwargs?: Record<string, any>;
+export type ChatMessage<
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> =
+  AdditionalMessageOptions extends Record<string, unknown>
+    ? {
+        content: MessageContent;
+        role: MessageType;
+        options?: AdditionalMessageOptions;
+      }
+    : {
+        content: MessageContent;
+        role: MessageType;
+        options: AdditionalMessageOptions;
+      };
+
+export interface ChatResponse<
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> {
+  message: ChatMessage<AdditionalMessageOptions>;
+  /**
+   * Raw response from the LLM
+   */
+  raw: object;
 }

-export interface ChatResponseChunk {
-  delta: string;
-  additionalKwargs?: Record<string, any>;
-}
+export type ChatResponseChunk<
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> =
+  AdditionalMessageOptions extends Record<string, unknown>
+    ? {
+        delta: string;
+        options?: AdditionalMessageOptions;
+      }
+    : {
+        delta: string;
+        options: AdditionalMessageOptions;
+      };

 export interface CompletionResponse {
   text: string;
@@ -112,8 +165,12 @@ export interface LLMChatParamsBase<
     string,
     unknown
   >,
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
 > {
-  messages: ChatMessage[];
+  messages: ChatMessage<AdditionalMessageOptions>[];
   additionalChatOptions?: AdditionalChatOptions;
   tools?: BaseTool[];
   additionalKwargs?: Record<string, unknown>;
diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts
index 49a671b1a40eb10882ae63c282d0d8bfad24ac89..fbbc0c6c796516b377e94f2c89efdbc91da0ac6d 100644
--- a/packages/core/src/llm/utils.ts
+++ b/packages/core/src/llm/utils.ts
@@ -84,6 +84,7 @@ export function wrapLLMEvent(
   };
   response[Symbol.asyncIterator] = async function* () {
     const finalResponse: ChatResponse = {
+      raw: response,
       message: {
         content: "",
         role: "assistant",
diff --git a/packages/core/src/memory/ChatMemoryBuffer.ts b/packages/core/src/memory/ChatMemoryBuffer.ts
index fbd8708a8e72aa53f18f337677ad902a446909a1..54367a84aa384fc2f209b0ea31a2997f2bcac50b 100644
--- a/packages/core/src/memory/ChatMemoryBuffer.ts
+++ b/packages/core/src/memory/ChatMemoryBuffer.ts
@@ -6,28 +6,36 @@ import type { BaseMemory } from "./types.js";
 const DEFAULT_TOKEN_LIMIT_RATIO = 0.75;
 const DEFAULT_TOKEN_LIMIT = 3000;

-type ChatMemoryBufferParams = {
+type ChatMemoryBufferParams<
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> = {
   tokenLimit?: number;
-  chatStore?: BaseChatStore;
+  chatStore?: BaseChatStore<AdditionalMessageOptions>;
   chatStoreKey?: string;
-  chatHistory?: ChatMessage[];
-  llm?: LLM;
+  chatHistory?: ChatMessage<AdditionalMessageOptions>[];
+  llm?: LLM<Record<string, unknown>, AdditionalMessageOptions>;
 };

-/**
- * Chat memory buffer.
- */
-export class ChatMemoryBuffer implements BaseMemory {
+export class ChatMemoryBuffer<
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> implements BaseMemory<AdditionalMessageOptions>
+{
   tokenLimit: number;
-  chatStore: BaseChatStore;
+  chatStore: BaseChatStore<AdditionalMessageOptions>;
   chatStoreKey: string;

-  /**
-   * Initialize.
-   */
-  constructor(init?: Partial<ChatMemoryBufferParams>) {
-    this.chatStore = init?.chatStore ?? new SimpleChatStore();
+  constructor(
+    init?: Partial<ChatMemoryBufferParams<AdditionalMessageOptions>>,
+  ) {
+    this.chatStore =
+      init?.chatStore ?? new SimpleChatStore<AdditionalMessageOptions>();
     this.chatStoreKey = init?.chatStoreKey ?? "chat_history";
"chat_history"; if (init?.llm) { const contextWindow = init.llm.metadata.contextWindow; @@ -43,11 +51,7 @@ export class ChatMemoryBuffer implements BaseMemory { } } - /** - Get chat history. - @param initialTokenCount: number of tokens to start with - */ - get(initialTokenCount: number = 0): ChatMessage[] { + get(initialTokenCount: number = 0) { const chatHistory = this.getAll(); if (initialTokenCount > this.tokenLimit) { @@ -79,42 +83,22 @@ export class ChatMemoryBuffer implements BaseMemory { return chatHistory.slice(-messageCount); } - /** - * Get all chat history. - * @returns {ChatMessage[]} chat history - */ - getAll(): ChatMessage[] { + getAll() { return this.chatStore.getMessages(this.chatStoreKey); } - /** - * Put chat history. - * @param message - */ - put(message: ChatMessage): void { + put(message: ChatMessage<AdditionalMessageOptions>) { this.chatStore.addMessage(this.chatStoreKey, message); } - /** - * Set chat history. - * @param messages - */ - set(messages: ChatMessage[]): void { + set(messages: ChatMessage<AdditionalMessageOptions>[]) { this.chatStore.setMessages(this.chatStoreKey, messages); } - /** - * Reset chat history. - */ - reset(): void { + reset() { this.chatStore.deleteMessages(this.chatStoreKey); } - /** - * Get token count for message count. - * @param messageCount - * @returns {number} token count - */ private _tokenCountForMessageCount(messageCount: number): number { if (messageCount <= 0) { return 0; diff --git a/packages/core/src/memory/types.ts b/packages/core/src/memory/types.ts index c000e734d453892292b4b13176882f6a54e6695a..5a19c431bb51c1d0aa43a429e5c7b48639923f35 100644 --- a/packages/core/src/memory/types.ts +++ b/packages/core/src/memory/types.ts @@ -1,24 +1,15 @@ import type { ChatMessage } from "../llm/index.js"; -export interface BaseMemory { - /* - Get chat history. - */ - get(...args: any): ChatMessage[]; - /* - Get all chat history. - */ - getAll(): ChatMessage[]; - /* - Put chat history. - */ - put(message: ChatMessage): void; - /* - Set chat history. - */ - set(messages: ChatMessage[]): void; - /* - Reset chat history. - */ +export interface BaseMemory< + AdditionalMessageOptions extends Record<string, unknown> = Record< + string, + unknown + >, +> { + tokenLimit: number; + get(...args: unknown[]): ChatMessage<AdditionalMessageOptions>[]; + getAll(): ChatMessage<AdditionalMessageOptions>[]; + put(message: ChatMessage<AdditionalMessageOptions>): void; + set(messages: ChatMessage<AdditionalMessageOptions>[]): void; reset(): void; } diff --git a/packages/core/src/storage/chatStore/SimpleChatStore.ts b/packages/core/src/storage/chatStore/SimpleChatStore.ts index 43567abaf28b3b927d4c34a4408b49337129e9a4..4c09f2550e6e35498ba89665a59052426bdcc72a 100644 --- a/packages/core/src/storage/chatStore/SimpleChatStore.ts +++ b/packages/core/src/storage/chatStore/SimpleChatStore.ts @@ -2,47 +2,38 @@ import type { ChatMessage } from "../../llm/index.js"; import type { BaseChatStore } from "./types.js"; /** - * Simple chat store. + * fixme: User could carry object references in the messages. + * This could lead to memory leaks if the messages are not properly cleaned up. */ -export class SimpleChatStore implements BaseChatStore { - store: { [key: string]: ChatMessage[] } = {}; +export class SimpleChatStore< + AdditionalMessageOptions extends Record<string, unknown> = Record< + string, + unknown + >, +> implements BaseChatStore<AdditionalMessageOptions> +{ + store: { [key: string]: ChatMessage<AdditionalMessageOptions>[] } = {}; - /** - * Set messages. 
-   * @param key: key
-   * @param messages: messages
-   * @returns: void
-   */
-  public setMessages(key: string, messages: ChatMessage[]): void {
+  public setMessages(
+    key: string,
+    messages: ChatMessage<AdditionalMessageOptions>[],
+  ) {
     this.store[key] = messages;
   }

-  /**
-   * Get messages.
-   * @param key: key
-   * @returns: messages
-   */
-  public getMessages(key: string): ChatMessage[] {
+  public getMessages(key: string): ChatMessage<AdditionalMessageOptions>[] {
     return this.store[key] || [];
   }

-  /**
-   * Add message.
-   * @param key: key
-   * @param message: message
-   * @returns: void
-   */
-  public addMessage(key: string, message: ChatMessage): void {
+  public addMessage(
+    key: string,
+    message: ChatMessage<AdditionalMessageOptions>,
+  ) {
     this.store[key] = this.store[key] || [];
     this.store[key].push(message);
   }

-  /**
-   * Delete messages.
-   * @param key: key
-   * @returns: messages
-   */
-  public deleteMessages(key: string): ChatMessage[] | null {
+  public deleteMessages(key: string) {
     if (!(key in this.store)) {
       return null;
     }
@@ -51,13 +42,7 @@ export class SimpleChatStore implements BaseChatStore {
     return messages;
   }

-  /**
-   * Delete message.
-   * @param key: key
-   * @param idx: idx
-   * @returns: message
-   */
-  public deleteMessage(key: string, idx: number): ChatMessage | null {
+  public deleteMessage(key: string, idx: number) {
     if (!(key in this.store)) {
       return null;
     }
@@ -67,12 +52,7 @@ export class SimpleChatStore implements BaseChatStore {
     return this.store[key].splice(idx, 1)[0];
   }

-  /**
-   * Delete last message.
-   * @param key: key
-   * @returns: message
-   */
-  public deleteLastMessage(key: string): ChatMessage | null {
+  public deleteLastMessage(key: string) {
     if (!(key in this.store)) {
       return null;
     }
@@ -82,10 +62,6 @@ export class SimpleChatStore implements BaseChatStore {
     return lastMessage || null;
   }

-  /**
-   * Get keys.
-   * @returns: keys
-   */
   public getKeys(): string[] {
     return Object.keys(this.store);
   }
diff --git a/packages/core/src/storage/chatStore/types.ts b/packages/core/src/storage/chatStore/types.ts
index 6607ebac8c674fe229cc1610cd3c4d8db8bb2974..7b3a2acda0d23e8e51a17345c4adb06e1c42367e 100644
--- a/packages/core/src/storage/chatStore/types.ts
+++ b/packages/core/src/storage/chatStore/types.ts
@@ -1,11 +1,22 @@
 import type { ChatMessage } from "../../llm/index.js";

-export interface BaseChatStore {
-  setMessages(key: string, messages: ChatMessage[]): void;
-  getMessages(key: string): ChatMessage[];
-  addMessage(key: string, message: ChatMessage): void;
-  deleteMessages(key: string): ChatMessage[] | null;
-  deleteMessage(key: string, idx: number): ChatMessage | null;
-  deleteLastMessage(key: string): ChatMessage | null;
+export interface BaseChatStore<
+  AdditionalMessageOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> {
+  setMessages(
+    key: string,
+    messages: ChatMessage<AdditionalMessageOptions>[],
+  ): void;
+  getMessages(key: string): ChatMessage<AdditionalMessageOptions>[];
+  addMessage(key: string, message: ChatMessage<AdditionalMessageOptions>): void;
+  deleteMessages(key: string): ChatMessage<AdditionalMessageOptions>[] | null;
+  deleteMessage(
+    key: string,
+    idx: number,
+  ): ChatMessage<AdditionalMessageOptions> | null;
+  deleteLastMessage(key: string): ChatMessage<AdditionalMessageOptions> | null;
   getKeys(): string[];
 }
diff --git a/packages/core/tests/utility/mockOpenAI.ts b/packages/core/tests/utility/mockOpenAI.ts
index ddb066bba86d1a1b75a577103bcc7242c472a5fa..d3dd737587bab3aa96ab3d6113a3b417827303af 100644
--- a/packages/core/tests/utility/mockOpenAI.ts
+++ b/packages/core/tests/utility/mockOpenAI.ts
@@ -54,9 +54,13 @@ export function mockLlmGeneration({
     }
     return new Promise((resolve) => {
       resolve({
+        get raw() {
+          return {};
+        },
         message: {
           content: text,
           role: "assistant",
+          options: {},
         },
       });
     });
@@ -75,9 +79,13 @@ export function mockLlmToolCallGeneration({
     () =>
       new Promise((resolve) =>
         resolve({
+          get raw() {
+            return {};
+          },
           message: {
             content: "The sum is 2",
             role: "assistant",
+            options: {},
          },
        }),
    ),
@@ -155,9 +163,13 @@ export function mocStructuredkLlmGeneration({
     }
     return new Promise((resolve) => {
       resolve({
+        get raw() {
+          return {};
+        },
         message: {
           content: text,
           role: "assistant",
+          options: {},
         },
       });
     });
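
Usage sketch of the API shape this patch introduces, following the pattern of examples/toolsStream.ts. This is a minimal, hypothetical example, not part of the patch: the prompt and model name are placeholders, the tool definition is omitted, and an OPENAI_API_KEY is assumed in the environment. The point it illustrates is that tool calls and other per-message metadata now live on the typed options field (message.options / chunk.options) instead of additionalKwargs, and every ChatResponse carries the provider's raw payload.

import { OpenAI } from "llamaindex";

async function demo() {
  const llm = new OpenAI({ model: "gpt-4-turbo" });
  // Non-streaming call; pass `tools` (as in examples/toolsStream.ts) to receive tool calls.
  const response = await llm.chat({
    messages: [{ content: "What is the weather in Tokyo?", role: "user" }],
  });
  // Tool calls, if any, are exposed via the typed options field.
  console.log(response.message.options?.toolCalls?.[0]);
  // The untouched provider response is available as `raw`.
  console.log(response.raw);
}

demo().catch(console.error);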