Skip to content
Snippets Groups Projects
Commit 27eef246 authored by Marcus Schiesser's avatar Marcus Schiesser
Browse files

feat: use context-generator for multi-modal messages

parent 1dabdbf7
No related branches found
No related tags found
No related merge requests found
......@@ -328,6 +328,17 @@ export class ContextChatEngine implements ChatEngine {
}
}
export interface MessageContentDetail {
type: "text" | "image_url";
text: string;
image_url: { url: string };
}
/**
* Extended type for the content of a message that allows for multi-modal messages.
*/
export type MessageContent = string | MessageContentDetail[];
/**
* HistoryChatEngine is a ChatEngine that uses a `ChatHistory` object
* to keeps track of chat's message history.
......@@ -347,38 +358,34 @@ export class HistoryChatEngine {
async chat<
T extends boolean | undefined = undefined,
R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
>(message: any, chatHistory: ChatHistory, streaming?: T): Promise<R> {
>(
message: MessageContent,
chatHistory: ChatHistory,
streaming?: T,
): Promise<R> {
//Streaming option
if (streaming) {
return this.streamChat(message, chatHistory) as R;
}
const context = await this.contextGenerator?.generate(message);
chatHistory.addMessage({
content: message,
role: "user",
});
const response = await this.llm.chat(
await chatHistory.requestMessages(
context ? [context.message] : undefined,
),
const requestMessages = await this.prepareRequestMessages(
message,
chatHistory,
);
const response = await this.llm.chat(requestMessages);
chatHistory.addMessage(response.message);
return new Response(response.message.content) as R;
}
protected async *streamChat(
message: any,
message: MessageContent,
chatHistory: ChatHistory,
): AsyncGenerator<string, void, unknown> {
const context = await this.contextGenerator?.generate(message);
chatHistory.addMessage({
content: message,
role: "user",
});
const requestMessages = await this.prepareRequestMessages(
message,
chatHistory,
);
const response_stream = await this.llm.chat(
await chatHistory.requestMessages(
context ? [context.message] : undefined,
),
requestMessages,
undefined,
true,
);
......@@ -394,4 +401,31 @@ export class HistoryChatEngine {
});
return;
}
private async prepareRequestMessages(
message: MessageContent,
chatHistory: ChatHistory,
) {
chatHistory.addMessage({
content: message,
role: "user",
});
let requestMessages;
let context;
if (this.contextGenerator) {
if (Array.isArray(message)) {
// message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
// so we can pass them to the context generator
message = (message as MessageContentDetail[])
.filter((c) => c.type === "text")
.map((c) => c.text)
.join("\n\n");
}
context = await this.contextGenerator.generate(message);
}
requestMessages = await chatHistory.requestMessages(
context ? [context.message] : undefined,
);
return requestMessages;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment