diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index 322bfd9b5ce016c8ba62882b6f284354a5f5ede0..6089a4ac4c4bb052b22e80113bc4801c6085094a 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -328,6 +328,17 @@ export class ContextChatEngine implements ChatEngine { } } +export interface MessageContentDetail { + type: "text" | "image_url"; + text: string; + image_url: { url: string }; +} + +/** + * Extended type for the content of a message that allows for multi-modal messages. + */ +export type MessageContent = string | MessageContentDetail[]; + /** * HistoryChatEngine is a ChatEngine that uses a `ChatHistory` object * to keeps track of chat's message history. @@ -347,38 +358,34 @@ export class HistoryChatEngine { async chat< T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator<string, void, unknown> : Response, - >(message: any, chatHistory: ChatHistory, streaming?: T): Promise<R> { + >( + message: MessageContent, + chatHistory: ChatHistory, + streaming?: T, + ): Promise<R> { //Streaming option if (streaming) { return this.streamChat(message, chatHistory) as R; } - const context = await this.contextGenerator?.generate(message); - chatHistory.addMessage({ - content: message, - role: "user", - }); - const response = await this.llm.chat( - await chatHistory.requestMessages( - context ? [context.message] : undefined, - ), + const requestMessages = await this.prepareRequestMessages( + message, + chatHistory, ); + const response = await this.llm.chat(requestMessages); chatHistory.addMessage(response.message); return new Response(response.message.content) as R; } protected async *streamChat( - message: any, + message: MessageContent, chatHistory: ChatHistory, ): AsyncGenerator<string, void, unknown> { - const context = await this.contextGenerator?.generate(message); - chatHistory.addMessage({ - content: message, - role: "user", - }); + const requestMessages = await this.prepareRequestMessages( + message, + chatHistory, + ); const response_stream = await this.llm.chat( - await chatHistory.requestMessages( - context ? [context.message] : undefined, - ), + requestMessages, undefined, true, ); @@ -394,4 +401,31 @@ export class HistoryChatEngine { }); return; } + + private async prepareRequestMessages( + message: MessageContent, + chatHistory: ChatHistory, + ) { + chatHistory.addMessage({ + content: message, + role: "user", + }); + let requestMessages; + let context; + if (this.contextGenerator) { + if (Array.isArray(message)) { + // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them + // so we can pass them to the context generator + message = (message as MessageContentDetail[]) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join("\n\n"); + } + context = await this.contextGenerator.generate(message); + } + requestMessages = await chatHistory.requestMessages( + context ? [context.message] : undefined, + ); + return requestMessages; + } }