From ee3eb7d8e2af46f38847d45cc1b17f607d5d6ca9 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Thu, 18 Jan 2024 12:10:37 +0700
Subject: [PATCH] fix: update create-llama examples for new chat engine (#396)

---------

Co-authored-by: thucpn <thucsh2@gmail.com>
---
 .../src/controllers/chat.controller.ts       | 19 +++++++-----
 .../src/controllers/chat.controller.ts       | 24 ++++++++-------
 .../src/controllers/llamaindex-stream.ts     | 10 ++++---
 .../nextjs/app/api/chat/llamaindex-stream.ts | 10 ++++---
 .../streaming/nextjs/app/api/chat/route.ts   | 30 ++++++++++---------
 5 files changed, 52 insertions(+), 41 deletions(-)

diff --git a/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts b/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts
index 46dfa5684..9f9639b72 100644
--- a/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts
+++ b/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts
@@ -2,7 +2,7 @@ import { Request, Response } from "express";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { createChatEngine } from "./engine";

-const getLastMessageContent = (
+const convertMessageContent = (
   textMessage: string,
   imageUrl: string | undefined,
 ): MessageContent => {
@@ -24,8 +24,8 @@ const getLastMessageContent = (
 export const chat = async (req: Request, res: Response) => {
   try {
     const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
-    const lastMessage = messages.pop();
-    if (!messages || !lastMessage || lastMessage.role !== "user") {
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
         error:
           "messages are required in the request body and the last message must be from the user",
@@ -36,17 +36,20 @@ export const chat = async (req: Request, res: Response) => {
       model: process.env.MODEL || "gpt-3.5-turbo",
     });

-    const lastMessageContent = getLastMessageContent(
-      lastMessage.content,
+    // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
+    // Note: The non-streaming template does not need the Vercel/AI format, we're still using it for consistency with the streaming template
+    const userMessageContent = convertMessageContent(
+      userMessage.content,
       data?.imageUrl,
     );

     const chatEngine = await createChatEngine(llm);

-    const response = await chatEngine.chat(
-      lastMessageContent as MessageContent,
+    // Calling LlamaIndex's ChatEngine to get a response
+    const response = await chatEngine.chat({
+      message: userMessageContent,
       messages,
-    );
+    });
     const result: ChatMessage = {
       role: "assistant",
       content: response.response,
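Not part of the patch: a minimal sketch of the new params-object call shape used above. It assumes the template's createChatEngine helper from "./engine" and passes the history under the chatHistory key that the streaming templates below use (the simple template itself passes the messages shorthand):

import { ChatMessage, OpenAI } from "llamaindex";
import { createChatEngine } from "./engine";

// Sketch only: one non-streaming round trip with the object-style params.
async function ask(history: ChatMessage[], question: string): Promise<string> {
  const llm = new OpenAI({ model: process.env.MODEL || "gpt-3.5-turbo" });
  const chatEngine = await createChatEngine(llm);
  const response = await chatEngine.chat({
    message: question,
    chatHistory: history, // assumption: same history key as the streaming templates
  });
  return response.response; // the full answer text, no streaming
}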
diff --git a/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts b/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts
index 4bd1c8da6..e82658016 100644
--- a/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -4,7 +4,7 @@ import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { createChatEngine } from "./engine";
 import { LlamaIndexStream } from "./llamaindex-stream";

-const getLastMessageContent = (
+const convertMessageContent = (
   textMessage: string,
   imageUrl: string | undefined,
 ): MessageContent => {
@@ -26,8 +26,8 @@ const getLastMessageContent = (
 export const chat = async (req: Request, res: Response) => {
   try {
     const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
-    const lastMessage = messages.pop();
-    if (!messages || !lastMessage || lastMessage.role !== "user") {
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
         error:
           "messages are required in the request body and the last message must be from the user",
@@ -40,18 +40,20 @@ export const chat = async (req: Request, res: Response) => {

     const chatEngine = await createChatEngine(llm);

-    const lastMessageContent = getLastMessageContent(
-      lastMessage.content,
+    // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
+    const userMessageContent = convertMessageContent(
+      userMessage.content,
       data?.imageUrl,
     );

-    const response = await chatEngine.chat(
-      lastMessageContent as MessageContent,
-      messages,
-      true,
-    );
+    // Calling LlamaIndex's ChatEngine to get a streamed response
+    const response = await chatEngine.chat({
+      message: userMessageContent,
+      chatHistory: messages,
+      stream: true,
+    });

-    // Transform the response into a readable stream
+    // Return a stream, which can be consumed by the Vercel/AI client
     const stream = LlamaIndexStream(response);

     streamToResponse(stream, res);
diff --git a/packages/create-llama/templates/types/streaming/express/src/controllers/llamaindex-stream.ts b/packages/create-llama/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
index 12328de87..e86c7626f 100644
--- a/packages/create-llama/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
+++ b/packages/create-llama/templates/types/streaming/express/src/controllers/llamaindex-stream.ts
@@ -4,18 +4,20 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
+import { Response } from "llamaindex";

-function createParser(res: AsyncGenerator<any>) {
+function createParser(res: AsyncIterable<Response>) {
+  const it = res[Symbol.asyncIterator]();
   const trimStartOfStream = trimStartOfStreamHelper();
   return new ReadableStream<string>({
     async pull(controller): Promise<void> {
-      const { value, done } = await res.next();
+      const { value, done } = await it.next();
       if (done) {
         controller.close();
         return;
       }

-      const text = trimStartOfStream(value ?? "");
+      const text = trimStartOfStream(value.response ?? "");
       if (text) {
         controller.enqueue(text);
       }
@@ -24,7 +26,7 @@ function createParser(res: AsyncGenerator<any>) {
 }

 export function LlamaIndexStream(
-  res: AsyncGenerator<any>,
+  res: AsyncIterable<Response>,
   callbacks?: AIStreamCallbacksAndOptions,
 ): ReadableStream {
   return createParser(res)
diff --git a/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts b/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
index 5ac376d63..6ddd8eae6 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
+++ b/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/llamaindex-stream.ts
@@ -6,16 +6,18 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
+import { Response } from "llamaindex";

 type ParserOptions = {
   image_url?: string;
 };

 function createParser(
-  res: AsyncGenerator<any>,
+  res: AsyncIterable<Response>,
   data: experimental_StreamData,
   opts?: ParserOptions,
 ) {
+  const it = res[Symbol.asyncIterator]();
   const trimStartOfStream = trimStartOfStreamHelper();
   return new ReadableStream<string>({
     start() {
@@ -33,7 +35,7 @@ function createParser(
       }
     },
     async pull(controller): Promise<void> {
-      const { value, done } = await res.next();
+      const { value, done } = await it.next();
       if (done) {
         controller.close();
         data.append({}); // send an empty image response for the assistant's message
@@ -41,7 +43,7 @@ function createParser(
         return;
       }

-      const text = trimStartOfStream(value ?? "");
+      const text = trimStartOfStream(value.response ?? "");
       if (text) {
         controller.enqueue(text);
       }
@@ -50,7 +52,7 @@ function createParser(
 }

 export function LlamaIndexStream(
-  res: AsyncGenerator<any>,
+  res: AsyncIterable<Response>,
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
     parserOptions?: ParserOptions;
diff --git a/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts b/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts
index a4a9f30b7..ef35bf76e 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -1,4 +1,4 @@
-import { Message, StreamingTextResponse } from "ai";
+import { StreamingTextResponse } from "ai";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { NextRequest, NextResponse } from "next/server";
 import { createChatEngine } from "./engine";
@@ -7,7 +7,7 @@ import { LlamaIndexStream } from "./llamaindex-stream";
 export const runtime = "nodejs";
 export const dynamic = "force-dynamic";

-const getLastMessageContent = (
+const convertMessageContent = (
   textMessage: string,
   imageUrl: string | undefined,
 ): MessageContent => {
@@ -29,9 +29,9 @@ const getLastMessageContent = (
 export async function POST(request: NextRequest) {
   try {
     const body = await request.json();
-    const { messages, data }: { messages: Message[]; data: any } = body;
-    const lastMessage = messages.pop();
-    if (!messages || !lastMessage || lastMessage.role !== "user") {
+    const { messages, data }: { messages: ChatMessage[]; data: any } = body;
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
       return NextResponse.json(
         {
           error:
@@ -48,25 +48,27 @@ export async function POST(request: NextRequest) {

     const chatEngine = await createChatEngine(llm);

-    const lastMessageContent = getLastMessageContent(
-      lastMessage.content,
+    // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
+    const userMessageContent = convertMessageContent(
+      userMessage.content,
       data?.imageUrl,
     );

-    const response = await chatEngine.chat(
-      lastMessageContent as MessageContent,
-      messages as ChatMessage[],
-      true,
-    );
+    // Calling LlamaIndex's ChatEngine to get a streamed response
+    const response = await chatEngine.chat({
+      message: userMessageContent,
+      chatHistory: messages,
+      stream: true,
+    });

-    // Transform the response into a readable stream
+    // Transform LlamaIndex stream to Vercel/AI format
     const { stream, data: streamData } = LlamaIndexStream(response, {
       parserOptions: {
         image_url: data?.imageUrl,
       },
     });

-    // Return a StreamingTextResponse, which can be consumed by the client
+    // Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
     return new StreamingTextResponse(stream, {}, streamData);
   } catch (error) {
     console.error("[LlamaIndex]", error);
--
GitLab
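Not part of the patch: a hypothetical client-side sketch for exercising the patched Next.js route. The request body follows the Vercel/AI chat format with the optional image URL under data, which the route reads as data?.imageUrl; the response is normally parsed by the Vercel/AI client (for example its useChat hook), so this sketch only accumulates the raw streamed bytes, which may include StreamData framing:

// Sketch only: POST to the patched /api/chat route and read back the streamed body.
async function sendChat(question: string, imageUrl?: string): Promise<string> {
  const res = await fetch("/api/chat", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      messages: [{ role: "user", content: question }],
      data: imageUrl ? { imageUrl } : undefined, // surfaced as data?.imageUrl in the route
    }),
  });
  if (!res.ok || !res.body) throw new Error(`chat request failed: ${res.status}`);

  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  let raw = "";
  for (;;) {
    const { value, done } = await reader.read();
    if (done) break;
    raw += decoder.decode(value, { stream: true }); // raw stream; the AI SDK client normally parses this
  }
  return raw;
}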