Unverified commit ee3eb7d8, authored by Marcus Schiesser, committed by GitHub

fix: update create-llama examples for new chat engine (#396)



---------

Co-authored-by: thucpn <thucsh2@gmail.com>
parent 75f94eea
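At a glance, the commit migrates the create-llama templates from the chat engine's old positional signature, `chat(message, chatHistory, streaming)`, to a single params object. A rough sketch of the call shape the updated templates rely on (reconstructed from the diffs below, not copied from LlamaIndex's type declarations):

```ts
import { ChatMessage, MessageContent } from "llamaindex";

// Approximate shape of the params object passed to chatEngine.chat()
// in the updated templates; the library's actual types may differ.
type ChatParams = {
  message: MessageContent; // the new user message
  chatHistory?: ChatMessage[]; // previously the second positional argument
  stream?: boolean; // previously a bare positional `true`
};
```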
Express chat controller (non-streaming template):

```diff
@@ -2,7 +2,7 @@ import { Request, Response } from "express";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { createChatEngine } from "./engine";
 
-const getLastMessageContent = (
+const convertMessageContent = (
   textMessage: string,
   imageUrl: string | undefined,
 ): MessageContent => {
@@ -24,8 +24,8 @@ const getLastMessageContent = (
 export const chat = async (req: Request, res: Response) => {
   try {
     const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
-    const lastMessage = messages.pop();
-    if (!messages || !lastMessage || lastMessage.role !== "user") {
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
         error:
           "messages are required in the request body and the last message must be from the user",
@@ -36,17 +36,20 @@ export const chat = async (req: Request, res: Response) => {
       model: process.env.MODEL || "gpt-3.5-turbo",
     });
 
-    const lastMessageContent = getLastMessageContent(
-      lastMessage.content,
+    // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
+    // Note: The non-streaming template does not need the Vercel/AI format, we're still using it for consistency with the streaming template
+    const userMessageContent = convertMessageContent(
+      userMessage.content,
       data?.imageUrl,
     );
 
     const chatEngine = await createChatEngine(llm);
 
-    const response = await chatEngine.chat(
-      lastMessageContent as MessageContent,
+    // Calling LlamaIndex's ChatEngine to get a response
+    const response = await chatEngine.chat({
+      message: userMessageContent,
       messages,
-    );
+    });
     const result: ChatMessage = {
       role: "assistant",
       content: response.response,
```
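For orientation, a request against this non-streaming controller could look like the sketch below; the `/api/chat` path and the exact response shape are assumptions based on the handler above, not part of the commit:

```ts
// Hypothetical client call; the endpoint path is an assumption.
const res = await fetch("/api/chat", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    messages: [{ role: "user", content: "What does this repo do?" }],
  }),
});
// The controller builds a ChatMessage from response.response, so the JSON
// body presumably carries { role: "assistant", content: ... }.
const result = await res.json();
console.log(result.content);
```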
Express chat controller (streaming template):

```diff
@@ -4,7 +4,7 @@ import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { createChatEngine } from "./engine";
 import { LlamaIndexStream } from "./llamaindex-stream";
 
-const getLastMessageContent = (
+const convertMessageContent = (
   textMessage: string,
   imageUrl: string | undefined,
 ): MessageContent => {
@@ -26,8 +26,8 @@ const getLastMessageContent = (
 export const chat = async (req: Request, res: Response) => {
   try {
     const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
-    const lastMessage = messages.pop();
-    if (!messages || !lastMessage || lastMessage.role !== "user") {
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
         error:
           "messages are required in the request body and the last message must be from the user",
@@ -40,18 +40,20 @@ export const chat = async (req: Request, res: Response) => {
     const chatEngine = await createChatEngine(llm);
 
-    const lastMessageContent = getLastMessageContent(
-      lastMessage.content,
+    // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
+    const userMessageContent = convertMessageContent(
+      userMessage.content,
       data?.imageUrl,
     );
 
-    const response = await chatEngine.chat(
-      lastMessageContent as MessageContent,
-      messages,
-      true,
-    );
+    // Calling LlamaIndex's ChatEngine to get a streamed response
+    const response = await chatEngine.chat({
+      message: userMessageContent,
+      chatHistory: messages,
+      stream: true,
+    });
 
-    // Transform the response into a readable stream
+    // Return a stream, which can be consumed by the Vercel/AI client
     const stream = LlamaIndexStream(response);
     streamToResponse(stream, res);
```
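Since `streamToResponse` writes the `ReadableStream` straight into the Express response, a client can consume it with the plain Fetch API. A minimal sketch (endpoint path assumed):

```ts
// Hypothetical consumer of the streaming endpoint above.
const res = await fetch("/api/chat", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    messages: [{ role: "user", content: "Tell me a story" }],
  }),
});
const reader = res.body!.getReader();
const decoder = new TextDecoder();
let answer = "";
while (true) {
  const { value, done } = await reader.read();
  if (done) break;
  answer += decoder.decode(value, { stream: true }); // append each text chunk
}
console.log(answer);
```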
llamaindex-stream.ts (streaming template):

```diff
@@ -4,18 +4,20 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
+import { Response } from "llamaindex";
 
-function createParser(res: AsyncGenerator<any>) {
+function createParser(res: AsyncIterable<Response>) {
+  const it = res[Symbol.asyncIterator]();
   const trimStartOfStream = trimStartOfStreamHelper();
   return new ReadableStream<string>({
     async pull(controller): Promise<void> {
-      const { value, done } = await res.next();
+      const { value, done } = await it.next();
       if (done) {
         controller.close();
         return;
       }
-      const text = trimStartOfStream(value ?? "");
+      const text = trimStartOfStream(value.response ?? "");
       if (text) {
         controller.enqueue(text);
       }
@@ -24,7 +26,7 @@ function createParser(res: AsyncGenerator<any>) {
 }
 
 export function LlamaIndexStream(
-  res: AsyncGenerator<any>,
+  res: AsyncIterable<Response>,
   callbacks?: AIStreamCallbacksAndOptions,
 ): ReadableStream {
   return createParser(res)
```
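This change is the core of the fix: the new chat engine yields an `AsyncIterable<Response>` rather than an `AsyncGenerator`, so the parser grabs an iterator once via `Symbol.asyncIterator` and reads `value.response` off each chunk. The same adapter pattern, written generically (illustrative names, not part of the commit):

```ts
// Convert any AsyncIterable into a ReadableStream<string>, pulling one
// item per pull() and mapping it to text; close when the source is done.
function iterableToStream<T>(
  iterable: AsyncIterable<T>,
  toText: (item: T) => string,
): ReadableStream<string> {
  const it = iterable[Symbol.asyncIterator]();
  return new ReadableStream<string>({
    async pull(controller) {
      const { value, done } = await it.next();
      if (done) {
        controller.close();
        return;
      }
      controller.enqueue(toText(value));
    },
  });
}
```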
llamaindex-stream.ts (multimodal template, with the StreamData side channel):

```diff
@@ -6,16 +6,18 @@ import {
   trimStartOfStreamHelper,
   type AIStreamCallbacksAndOptions,
 } from "ai";
+import { Response } from "llamaindex";
 
 type ParserOptions = {
   image_url?: string;
 };
 
 function createParser(
-  res: AsyncGenerator<any>,
+  res: AsyncIterable<Response>,
   data: experimental_StreamData,
   opts?: ParserOptions,
 ) {
+  const it = res[Symbol.asyncIterator]();
   const trimStartOfStream = trimStartOfStreamHelper();
   return new ReadableStream<string>({
     start() {
@@ -33,7 +35,7 @@ function createParser(
       }
     },
     async pull(controller): Promise<void> {
-      const { value, done } = await res.next();
+      const { value, done } = await it.next();
       if (done) {
         controller.close();
         data.append({}); // send an empty image response for the assistant's message
@@ -41,7 +43,7 @@ function createParser(
         return;
       }
-      const text = trimStartOfStream(value ?? "");
+      const text = trimStartOfStream(value.response ?? "");
       if (text) {
         controller.enqueue(text);
       }
@@ -50,7 +52,7 @@ function createParser(
 }
 
 export function LlamaIndexStream(
-  res: AsyncGenerator<any>,
+  res: AsyncIterable<Response>,
   opts?: {
     callbacks?: AIStreamCallbacksAndOptions;
     parserOptions?: ParserOptions;
```
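The multimodal variant gets the same `AsyncIterable<Response>` treatment, plus the `experimental_StreamData` side channel from the Vercel AI SDK. In outline, that channel is driven like the sketch below; the payload appended in `start()` is cut off in this diff, so the shape shown is an assumption based on `ParserOptions`:

```ts
import { experimental_StreamData } from "ai";

const data = new experimental_StreamData();
// Assumed payload shape: pair the user's image with the streamed text.
data.append({ type: "image_url", image_url: { url: "https://example.com/cat.png" } });
// ...text chunks are enqueued on the ReadableStream in the meantime...
data.append({}); // empty image response for the assistant's message (as above)
data.close(); // the side channel must be closed once the stream finishes
```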
Next.js route handler (streaming template):

```diff
@@ -1,4 +1,4 @@
-import { Message, StreamingTextResponse } from "ai";
+import { StreamingTextResponse } from "ai";
 import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { NextRequest, NextResponse } from "next/server";
 import { createChatEngine } from "./engine";
@@ -7,7 +7,7 @@ import { LlamaIndexStream } from "./llamaindex-stream";
 export const runtime = "nodejs";
 export const dynamic = "force-dynamic";
 
-const getLastMessageContent = (
+const convertMessageContent = (
   textMessage: string,
   imageUrl: string | undefined,
 ): MessageContent => {
@@ -29,9 +29,9 @@ const getLastMessageContent = (
 export async function POST(request: NextRequest) {
   try {
     const body = await request.json();
-    const { messages, data }: { messages: Message[]; data: any } = body;
-    const lastMessage = messages.pop();
-    if (!messages || !lastMessage || lastMessage.role !== "user") {
+    const { messages, data }: { messages: ChatMessage[]; data: any } = body;
+    const userMessage = messages.pop();
+    if (!messages || !userMessage || userMessage.role !== "user") {
       return NextResponse.json(
         {
           error:
@@ -48,25 +48,27 @@ export async function POST(request: NextRequest) {
     const chatEngine = await createChatEngine(llm);
 
-    const lastMessageContent = getLastMessageContent(
-      lastMessage.content,
+    // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
+    const userMessageContent = convertMessageContent(
+      userMessage.content,
      data?.imageUrl,
     );
 
-    const response = await chatEngine.chat(
-      lastMessageContent as MessageContent,
-      messages as ChatMessage[],
-      true,
-    );
+    // Calling LlamaIndex's ChatEngine to get a streamed response
+    const response = await chatEngine.chat({
+      message: userMessageContent,
+      chatHistory: messages,
+      stream: true,
+    });
 
-    // Transform the response into a readable stream
+    // Transform LlamaIndex stream to Vercel/AI format
     const { stream, data: streamData } = LlamaIndexStream(response, {
       parserOptions: {
         image_url: data?.imageUrl,
       },
     });
 
-    // Return a StreamingTextResponse, which can be consumed by the client
+    // Return a StreamingTextResponse, which can be consumed by the Vercel/AI client
    return new StreamingTextResponse(stream, {}, streamData);
   } catch (error) {
     console.error("[LlamaIndex]", error);
```
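On the client side, the returned `StreamingTextResponse` (with its attached `streamData`) is what the Vercel AI SDK's `useChat` hook expects. A minimal sketch of such a consumer (not part of this commit; the `api` path is an assumption):

```tsx
"use client";
import { useChat } from "ai/react";

export default function ChatUI() {
  // useChat posts to the route handler above and streams the reply
  const { messages, input, handleInputChange, handleSubmit } = useChat({
    api: "/api/chat", // assumed route for the handler above
  });
  return (
    <form onSubmit={handleSubmit}>
      {messages.map((m) => (
        <p key={m.id}>
          {m.role}: {m.content}
        </p>
      ))}
      <input value={input} onChange={handleInputChange} />
    </form>
  );
}
```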