diff --git a/packages/create-llama/templates/types/simple/express/index.ts b/packages/create-llama/templates/types/simple/express/index.ts
index 830c549f51dc4e0d540673a9abd1913ea1b44444..721c4ec9dd1922a36756c1b78142cab739cb1e85 100644
--- a/packages/create-llama/templates/types/simple/express/index.ts
+++ b/packages/create-llama/templates/types/simple/express/index.ts
@@ -11,6 +11,8 @@ const env = process.env["NODE_ENV"];
 const isDevelopment = !env || env === "development";
 const prodCorsOrigin = process.env["PROD_CORS_ORIGIN"];
 
+app.use(express.json());
+
 if (isDevelopment) {
   console.warn("Running in development mode - allowing CORS for all origins");
   app.use(cors());
diff --git a/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts b/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts
index 6612971ac5c445bce4dee0d961e8b04e74633597..8aa08613f4601f2b142c065028a22e9ef2d4ea74 100644
--- a/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts
+++ b/packages/create-llama/templates/types/simple/express/src/controllers/chat.controller.ts
@@ -1,11 +1,30 @@
-import { NextFunction, Request, Response } from "express";
-import { ChatMessage, OpenAI } from "llamaindex";
+import { Request, Response } from "express";
+import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { MODEL } from "../../constants";
 import { createChatEngine } from "./engine";
 
-export const chat = async (req: Request, res: Response, next: NextFunction) => {
+const getLastMessageContent = (
+  textMessage: string,
+  imageUrl: string | undefined,
+): MessageContent => {
+  if (!imageUrl) return textMessage;
+  return [
+    {
+      type: "text",
+      text: textMessage,
+    },
+    {
+      type: "image_url",
+      image_url: {
+        url: imageUrl,
+      },
+    },
+  ];
+};
+
+export const chat = async (req: Request, res: Response) => {
   try {
-    const { messages }: { messages: ChatMessage[] } = JSON.parse(req.body);
+    const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
     const lastMessage = messages.pop();
     if (!messages || !lastMessage || lastMessage.role !== "user") {
       return res.status(400).json({
@@ -18,9 +37,17 @@ export const chat = async (req: Request, res: Response, next: NextFunction) => {
       model: MODEL,
     });
 
+    const lastMessageContent = getLastMessageContent(
+      lastMessage.content,
+      data?.imageUrl,
+    );
+
     const chatEngine = await createChatEngine(llm);
 
-    const response = await chatEngine.chat(lastMessage.content, messages);
+    const response = await chatEngine.chat(
+      lastMessageContent as MessageContent,
+      messages,
+    );
     const result: ChatMessage = {
       role: "assistant",
       content: response.response,
diff --git a/packages/create-llama/templates/types/streaming/express/index.ts b/packages/create-llama/templates/types/streaming/express/index.ts
index 830c549f51dc4e0d540673a9abd1913ea1b44444..721c4ec9dd1922a36756c1b78142cab739cb1e85 100644
--- a/packages/create-llama/templates/types/streaming/express/index.ts
+++ b/packages/create-llama/templates/types/streaming/express/index.ts
@@ -11,6 +11,8 @@ const env = process.env["NODE_ENV"];
 const isDevelopment = !env || env === "development";
 const prodCorsOrigin = process.env["PROD_CORS_ORIGIN"];
 
+app.use(express.json());
+
 if (isDevelopment) {
   console.warn("Running in development mode - allowing CORS for all origins");
   app.use(cors());
diff --git a/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts b/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts
index 76b8fafdb14adbf79099dd302fe78ccc4cdd119a..1dbd85d45f157bf5e98741e512826f27e420a15a 100644
--- a/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/packages/create-llama/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -1,13 +1,32 @@
 import { streamToResponse } from "ai";
-import { NextFunction, Request, Response } from "express";
-import { ChatMessage, OpenAI } from "llamaindex";
+import { Request, Response } from "express";
+import { ChatMessage, MessageContent, OpenAI } from "llamaindex";
 import { MODEL } from "../../constants";
 import { createChatEngine } from "./engine";
 import { LlamaIndexStream } from "./llamaindex-stream";
 
-export const chat = async (req: Request, res: Response, next: NextFunction) => {
+const getLastMessageContent = (
+  textMessage: string,
+  imageUrl: string | undefined,
+): MessageContent => {
+  if (!imageUrl) return textMessage;
+  return [
+    {
+      type: "text",
+      text: textMessage,
+    },
+    {
+      type: "image_url",
+      image_url: {
+        url: imageUrl,
+      },
+    },
+  ];
+};
+
+export const chat = async (req: Request, res: Response) => {
   try {
-    const { messages }: { messages: ChatMessage[] } = JSON.parse(req.body);
+    const { messages, data }: { messages: ChatMessage[]; data: any } = req.body;
     const lastMessage = messages.pop();
     if (!messages || !lastMessage || lastMessage.role !== "user") {
       return res.status(400).json({
@@ -22,7 +41,16 @@ export const chat = async (req: Request, res: Response, next: NextFunction) => {
 
     const chatEngine = await createChatEngine(llm);
 
-    const response = await chatEngine.chat(lastMessage.content, messages, true);
+    const lastMessageContent = getLastMessageContent(
+      lastMessage.content,
+      data?.imageUrl,
+    );
+
+    const response = await chatEngine.chat(
+      lastMessageContent as MessageContent,
+      messages,
+      true,
+    );
 
     // Transform the response into a readable stream
     const stream = LlamaIndexStream(response);
diff --git a/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx b/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx
index 791b223f40eb4fe716369efdb4d712dfa415fd8c..b42edb2731f33a4394027569416b26f214c90c13 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx
+++ b/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx
@@ -13,7 +13,12 @@ export default function ChatSection() {
     handleInputChange,
     reload,
     stop,
-  } = useChat({ api: process.env.NEXT_PUBLIC_CHAT_API });
+  } = useChat({
+    api: process.env.NEXT_PUBLIC_CHAT_API,
+    headers: {
+      "Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26
+    },
+  });
 
   return (
     <div className="space-y-4 max-w-5xl w-full">
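
Reviewer note: since the controllers now read `req.body` directly instead of calling `JSON.parse(req.body)`, requests must arrive already parsed, which is what the new `app.use(express.json())` middleware provides. A minimal sketch of a client call exercising the updated endpoint — the port, route, and image URL are illustrative assumptions, while the body shape mirrors the `{ messages, data }` destructuring in chat.controller.ts:

```ts
// Hypothetical request against the updated chat endpoint. The URL below is
// a placeholder; the JSON body matches what chat.controller.ts destructures.
const res = await fetch("http://localhost:8000/api/chat", {
  method: "POST",
  // express.json() only parses bodies tagged as application/json.
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    messages: [{ role: "user", content: "What is in this image?" }],
    // getLastMessageContent folds data.imageUrl into the message content.
    data: { imageUrl: "https://example.com/photo.jpg" }, // placeholder URL
  }),
});
console.log(await res.json());
```

When `data.imageUrl` is absent, `getLastMessageContent` returns the plain text string, so text-only chats behave exactly as before; when it is present, the engine receives an OpenAI-style multimodal content array.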
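On the Next.js side, the `Content-Type` header pin works around how vercel/ai 2.2.26 serializes the request body, and that version's per-request `data` record is presumably the channel by which `imageUrl` reaches the controller. A hedged sketch, assuming `handleSubmit` is destructured from the same `useChat` call and that the installed vercel/ai version supports `data` on its chat request options:

```tsx
// Sketch: forwarding an image URL so it arrives as `data.imageUrl` in
// req.body on the Express side. The URL is a placeholder, and the `data`
// option is an assumption tied to the vercel/ai 2.2.26 note above.
const onSubmit = (e: React.FormEvent<HTMLFormElement>) => {
  handleSubmit(e, {
    data: { imageUrl: "https://example.com/photo.jpg" }, // placeholder
  });
};
```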