diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 6089a4ac4c4bb052b22e80113bc4801c6085094a..5311ffa8a639288443b9097a92d781f4ed34f2c4 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -30,7 +30,7 @@ export interface ChatEngine {
     T extends boolean | undefined = undefined,
     R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
   >(
-    message: string,
+    message: MessageContent,
     chatHistory?: ChatMessage[],
     streaming?: T,
   ): Promise<R>;
@@ -56,7 +56,11 @@ export class SimpleChatEngine implements ChatEngine {
   async chat<
     T extends boolean | undefined = undefined,
     R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
-  >(message: string, chatHistory?: ChatMessage[], streaming?: T): Promise<R> {
+  >(
+    message: MessageContent,
+    chatHistory?: ChatMessage[],
+    streaming?: T,
+  ): Promise<R> {
     //Streaming option
     if (streaming) {
       return this.streamChat(message, chatHistory) as R;
@@ -72,7 +76,7 @@ export class SimpleChatEngine implements ChatEngine {
   }
 
   protected async *streamChat(
-    message: string,
+    message: MessageContent,
     chatHistory?: ChatMessage[],
   ): AsyncGenerator<string, void, unknown> {
     chatHistory = chatHistory ?? this.chatHistory;
@@ -144,14 +148,14 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     T extends boolean | undefined = undefined,
     R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
   >(
-    message: string,
+    message: MessageContent,
     chatHistory?: ChatMessage[] | undefined,
     streaming?: T,
   ): Promise<R> {
     chatHistory = chatHistory ?? this.chatHistory;
 
     const condensedQuestion = (
-      await this.condenseQuestion(chatHistory, message)
+      await this.condenseQuestion(chatHistory, extractText(message))
     ).message.content;
 
     const response = await this.queryEngine.query(condensedQuestion);
@@ -256,7 +260,7 @@ export class ContextChatEngine implements ChatEngine {
     T extends boolean | undefined = undefined,
     R = T extends true ? AsyncGenerator<string, void, unknown> : Response,
   >(
-    message: string,
+    message: MessageContent,
     chatHistory?: ChatMessage[] | undefined,
     streaming?: T,
   ): Promise<R> {
@@ -272,7 +276,10 @@ export class ContextChatEngine implements ChatEngine {
       type: "wrapper",
       tags: ["final"],
     };
-    const context = await this.contextGenerator.generate(message, parentEvent);
+    const context = await this.contextGenerator.generate(
+      extractText(message),
+      parentEvent,
+    );
 
     chatHistory.push({ content: message, role: "user" });
 
@@ -291,7 +298,7 @@ export class ContextChatEngine implements ChatEngine {
   }
 
   protected async *streamChat(
-    message: string,
+    message: MessageContent,
     chatHistory?: ChatMessage[] | undefined,
   ): AsyncGenerator<string, void, unknown> {
     chatHistory = chatHistory ?? this.chatHistory;
@@ -301,7 +308,10 @@ export class ContextChatEngine implements ChatEngine {
       type: "wrapper",
       tags: ["final"],
     };
-    const context = await this.contextGenerator.generate(message, parentEvent);
+    const context = await this.contextGenerator.generate(
+      extractText(message),
+      parentEvent,
+    );
 
     chatHistory.push({ content: message, role: "user" });
 
@@ -330,8 +340,8 @@ export class ContextChatEngine implements ChatEngine {
 
 export interface MessageContentDetail {
   type: "text" | "image_url";
-  text: string;
-  image_url: { url: string };
+  text?: string;
+  image_url?: { url: string };
 }
 
 /**
@@ -339,6 +349,24 @@ export interface MessageContentDetail {
  */
 export type MessageContent = string | MessageContentDetail[];
 
+/**
+ * Extracts just the text from a multi-modal message or the message itself if it's just text.
+ *
+ * @param message The message to extract text from.
+ * @returns The extracted text
+ */
+function extractText(message: MessageContent): string {
+  if (Array.isArray(message)) {
+    // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
+    // so we can pass them to the context generator
+    return (message as MessageContentDetail[])
+      .filter((c) => c.type === "text")
+      .map((c) => c.text)
+      .join("\n\n");
+  }
+  return message;
+}
+
 /**
  * HistoryChatEngine is a ChatEngine that uses a `ChatHistory` object
  * to keeps track of chat's message history.
@@ -413,15 +441,8 @@ export class HistoryChatEngine {
     let requestMessages;
     let context;
     if (this.contextGenerator) {
-      if (Array.isArray(message)) {
-        // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
-        // so we can pass them to the context generator
-        message = (message as MessageContentDetail[])
-          .filter((c) => c.type === "text")
-          .map((c) => c.text)
-          .join("\n\n");
-      }
-      context = await this.contextGenerator.generate(message);
+      const textOnly = extractText(message);
+      context = await this.contextGenerator.generate(textOnly);
     }
     requestMessages = await chatHistory.requestMessages(
       context ? [context.message] : undefined,
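Note on the core API change above: `chat` and `streamChat` now accept `MessageContent` (a plain string or an array of multi-modal parts), and the new `extractText` helper flattens the parts back to a string wherever plain text is required (question condensing, context generation). A minimal sketch of the two call shapes, using the types from this patch (the data URL is a placeholder):

    import { MessageContent } from "llamaindex";

    // Text-only calls keep working unchanged:
    //   await chatEngine.chat("Summarize the document.", history);

    // Multi-modal calls pass an array of parts; extractText() forwards only
    // the text parts (joined with blank lines) to the context generator:
    const multiModal: MessageContent = [
      { type: "text", text: "Describe this picture." },
      { type: "image_url", image_url: { url: "data:image/png;base64,..." } },
    ];
    // extractText(multiModal) === "Describe this picture."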
diff --git a/packages/create-llama/create-app.ts b/packages/create-llama/create-app.ts
index 25c47dbd0bcd72a01625c2fadd0b04fda7da22be..a6be4dd31de74132a9a9926260ca374bdb896d5f 100644
--- a/packages/create-llama/create-app.ts
+++ b/packages/create-llama/create-app.ts
@@ -22,6 +22,7 @@ export async function createApp({
   eslint,
   frontend,
   openAIKey,
+  model,
 }: Omit<
   InstallTemplateArgs,
   "appName" | "root" | "isOnline" | "customApiPath"
@@ -65,6 +66,7 @@ export async function createApp({
     isOnline,
     eslint,
     openAIKey,
+    model,
   };
 
   if (frontend) {
diff --git a/packages/create-llama/index.ts b/packages/create-llama/index.ts
index 14571b74b651992c937279a5a0cff365d33de6c3..e051f595f84687d1b870450590ece16e5cf16aa4 100644
--- a/packages/create-llama/index.ts
+++ b/packages/create-llama/index.ts
@@ -277,6 +277,32 @@ async function run(): Promise<void> {
     }
   }
 
+  if (program.framework === "nextjs") {
+    if (!program.model) {
+      if (ciInfo.isCI) {
+        program.model = getPrefOrDefault("model");
+      } else {
+        const { model } = await prompts(
+          {
+            type: "select",
+            name: "model",
+            message: "Which model would you like to use?",
+            choices: [
+              { title: "gpt-3.5-turbo", value: "gpt-3.5-turbo" },
+              { title: "gpt-4", value: "gpt-4" },
+              { title: "gpt-4-1106-preview", value: "gpt-4-1106-preview" },
+              { title: "gpt-4-vision-preview", value: "gpt-4-vision-preview" },
+            ],
+            initial: 0,
+          },
+          handlers,
+        );
+        program.model = model;
+        preferences.model = model;
+      }
+    }
+  }
+
   if (program.framework === "express" || program.framework === "nextjs") {
     if (!program.engine) {
       if (ciInfo.isCI) {
@@ -350,6 +376,7 @@ async function run(): Promise<void> {
     eslint: program.eslint,
     frontend: program.frontend,
     openAIKey: program.openAIKey,
+    model: program.model,
   });
   conf.set("preferences", preferences);
 }
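The model prompt above runs only for the Next.js template, and CI runs skip the interactive prompt in favor of the stored preference. A condensed sketch of that selection flow (a hypothetical standalone helper, not the actual CLI code; `prompts` is the same package the CLI already uses):

    import prompts from "prompts";

    async function chooseModel(isCI: boolean, stored?: string): Promise<string> {
      if (isCI) return stored ?? "gpt-3.5-turbo"; // no TTY in CI, never block
      const { model } = await prompts({
        type: "select",
        name: "model",
        message: "Which model would you like to use?",
        choices: [
          { title: "gpt-3.5-turbo", value: "gpt-3.5-turbo" },
          { title: "gpt-4", value: "gpt-4" },
          { title: "gpt-4-1106-preview", value: "gpt-4-1106-preview" },
          { title: "gpt-4-vision-preview", value: "gpt-4-vision-preview" },
        ],
        initial: 0,
      });
      return model;
    }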
diff --git a/packages/create-llama/templates/components/ui/shadcn/chat/chat-input.tsx b/packages/create-llama/templates/components/ui/shadcn/chat/chat-input.tsx
index 1a0cc3e0cc6d92178bafce5e8d573586e024e3f6..435637e5ec94fdb9fe03faa3c3e1791a0be584bb 100644
--- a/packages/create-llama/templates/components/ui/shadcn/chat/chat-input.tsx
+++ b/packages/create-llama/templates/components/ui/shadcn/chat/chat-input.tsx
@@ -1,29 +1,84 @@
+import { useState } from "react";
 import { Button } from "../button";
+import FileUploader from "../file-uploader";
 import { Input } from "../input";
+import UploadImagePreview from "../upload-image-preview";
 import { ChatHandler } from "./chat.interface";
 
 export default function ChatInput(
   props: Pick<
     ChatHandler,
-    "isLoading" | "handleSubmit" | "handleInputChange" | "input"
-  >,
+    | "isLoading"
+    | "input"
+    | "onFileUpload"
+    | "onFileError"
+    | "handleSubmit"
+    | "handleInputChange"
+  > & {
+    multiModal?: boolean;
+  },
 ) {
+  const [imageUrl, setImageUrl] = useState<string | null>(null);
+
+  const onSubmit = (e: React.FormEvent<HTMLFormElement>) => {
+    if (imageUrl) {
+      props.handleSubmit(e, {
+        data: { imageUrl: imageUrl },
+      });
+      setImageUrl(null);
+      return;
+    }
+    props.handleSubmit(e);
+  };
+
+  const onRemovePreviewImage = () => setImageUrl(null);
+
+  const handleUploadImageFile = async (file: File) => {
+    const base64 = await new Promise<string>((resolve, reject) => {
+      const reader = new FileReader();
+      reader.readAsDataURL(file);
+      reader.onload = () => resolve(reader.result as string);
+      reader.onerror = (error) => reject(error);
+    });
+    setImageUrl(base64);
+  };
+
+  const handleUploadFile = async (file: File) => {
+    try {
+      if (props.multiModal && file.type.startsWith("image/")) {
+        return await handleUploadImageFile(file);
+      }
+      props.onFileUpload?.(file);
+    } catch (error: any) {
+      props.onFileError?.(error.message);
+    }
+  };
+
   return (
     <form
-      onSubmit={props.handleSubmit}
-      className="flex w-full items-start justify-between gap-4 rounded-xl bg-white p-4 shadow-xl"
+      onSubmit={onSubmit}
+      className="rounded-xl bg-white p-4 shadow-xl space-y-4"
     >
-      <Input
-        autoFocus
-        name="message"
-        placeholder="Type a message"
-        className="flex-1"
-        value={props.input}
-        onChange={props.handleInputChange}
-      />
-      <Button type="submit" disabled={props.isLoading}>
-        Send message
-      </Button>
+      {imageUrl && (
+        <UploadImagePreview url={imageUrl} onRemove={onRemovePreviewImage} />
+      )}
+      <div className="flex w-full items-start justify-between gap-4 ">
+        <Input
+          autoFocus
+          name="message"
+          placeholder="Type a message"
+          className="flex-1"
+          value={props.input}
+          onChange={props.handleInputChange}
+        />
+        <FileUploader
+          onFileUpload={handleUploadFile}
+          onFileError={props.onFileError}
+        />
+        <Button type="submit" disabled={props.isLoading}>
+          Send message
+        </Button>
+      </div>
     </form>
   );
 }
diff --git a/packages/create-llama/templates/components/ui/shadcn/chat/chat.interface.ts b/packages/create-llama/templates/components/ui/shadcn/chat/chat.interface.ts
index 3256f7f031b42114f192f3375632654dc21f78d8..584a63f7333b86531148dd7849c3233c82fdbf16 100644
--- a/packages/create-llama/templates/components/ui/shadcn/chat/chat.interface.ts
+++ b/packages/create-llama/templates/components/ui/shadcn/chat/chat.interface.ts
@@ -8,8 +8,15 @@ export interface ChatHandler {
   messages: Message[];
   input: string;
   isLoading: boolean;
-  handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
+  handleSubmit: (
+    e: React.FormEvent<HTMLFormElement>,
+    ops?: {
+      data?: any;
+    },
+  ) => void;
   handleInputChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
   reload?: () => void;
   stop?: () => void;
+  onFileUpload?: (file: File) => Promise<void>;
+  onFileError?: (errMsg: string) => void;
 }
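The `handleUploadImageFile` helper above wraps the callback-based FileReader API in a promise so the image can be held as a base64 data URL until submit. The same pattern in isolation (browser-only sketch; `fileToDataUrl` is an illustrative name, not part of the patch):

    function fileToDataUrl(file: File): Promise<string> {
      return new Promise<string>((resolve, reject) => {
        const reader = new FileReader();
        reader.onload = () => resolve(reader.result as string); // "data:<mime>;base64,..."
        reader.onerror = () => reject(reader.error);
        reader.readAsDataURL(file);
      });
    }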
diff --git a/packages/create-llama/templates/components/ui/shadcn/file-uploader.tsx b/packages/create-llama/templates/components/ui/shadcn/file-uploader.tsx
new file mode 100644
index 0000000000000000000000000000000000000000..e42a267d18cbe76391d1decad0acca8fdf4dc295
--- /dev/null
+++ b/packages/create-llama/templates/components/ui/shadcn/file-uploader.tsx
@@ -0,0 +1,105 @@
+"use client";
+
+import { Loader2, Paperclip } from "lucide-react";
+import { ChangeEvent, useState } from "react";
+import { buttonVariants } from "./button";
+import { cn } from "./lib/utils";
+
+export interface FileUploaderProps {
+  config?: {
+    inputId?: string;
+    fileSizeLimit?: number;
+    allowedExtensions?: string[];
+    checkExtension?: (extension: string) => string | null;
+    disabled: boolean;
+  };
+  onFileUpload: (file: File) => Promise<void>;
+  onFileError?: (errMsg: string) => void;
+}
+
+const DEFAULT_INPUT_ID = "fileInput";
+const DEFAULT_FILE_SIZE_LIMIT = 1024 * 1024 * 50; // 50 MB
+
+export default function FileUploader({
+  config,
+  onFileUpload,
+  onFileError,
+}: FileUploaderProps) {
+  const [uploading, setUploading] = useState(false);
+
+  const inputId = config?.inputId || DEFAULT_INPUT_ID;
+  const fileSizeLimit = config?.fileSizeLimit || DEFAULT_FILE_SIZE_LIMIT;
+  const allowedExtensions = config?.allowedExtensions;
+  const defaultCheckExtension = (extension: string) => {
+    if (allowedExtensions && !allowedExtensions.includes(extension)) {
+      return `Invalid file type. Please select a file with one of these formats: ${allowedExtensions!.join(
+        ",",
+      )}`;
+    }
+    return null;
+  };
+  const checkExtension = config?.checkExtension ?? defaultCheckExtension;
+
+  const isFileSizeExceeded = (file: File) => {
+    return file.size > fileSizeLimit;
+  };
+
+  const resetInput = () => {
+    const fileInput = document.getElementById(inputId) as HTMLInputElement;
+    fileInput.value = "";
+  };
+
+  const onFileChange = async (e: ChangeEvent<HTMLInputElement>) => {
+    const file = e.target.files?.[0];
+    if (!file) return;
+
+    setUploading(true);
+    await handleUpload(file);
+    resetInput();
+    setUploading(false);
+  };
+
+  const handleUpload = async (file: File) => {
+    const onFileUploadError = onFileError || window.alert;
+    const fileExtension = file.name.split(".").pop() || "";
+    const extensionFileError = checkExtension(fileExtension);
+    if (extensionFileError) {
+      return onFileUploadError(extensionFileError);
+    }
+
+    if (isFileSizeExceeded(file)) {
+      return onFileUploadError(
+        `File size exceeded. Limit is ${fileSizeLimit / 1024 / 1024} MB`,
+      );
+    }
+
+    await onFileUpload(file);
+  };
+
+  return (
+    <div className="self-stretch">
+      <input
+        type="file"
+        id={inputId}
+        style={{ display: "none" }}
+        onChange={onFileChange}
+        accept={allowedExtensions?.join(",")}
+        disabled={config?.disabled || uploading}
+      />
+      <label
+        htmlFor={inputId}
+        className={cn(
+          buttonVariants({ variant: "secondary", size: "icon" }),
+          "cursor-pointer",
+          uploading && "opacity-50",
+        )}
+      >
+        {uploading ? (
+          <Loader2 className="h-4 w-4 animate-spin" />
+        ) : (
+          <Paperclip className="-rotate-45 w-4 h-4" />
+        )}
+      </label>
+    </div>
+  );
+}
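For reference, a hypothetical usage of the new FileUploader component; the extension list and handlers here are illustrative, not defaults from the patch:

    import FileUploader from "./file-uploader";

    export default function UploadExample() {
      return (
        <FileUploader
          // config.disabled is required whenever config is provided
          config={{ allowedExtensions: ["png", "jpg", "jpeg"], disabled: false }}
          onFileUpload={async (file) => console.log("selected:", file.name)}
          onFileError={(msg) => window.alert(msg)}
        />
      );
    }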
diff --git a/packages/create-llama/templates/components/ui/shadcn/upload-image-preview.tsx b/packages/create-llama/templates/components/ui/shadcn/upload-image-preview.tsx
new file mode 100644
index 0000000000000000000000000000000000000000..55ef6e9c2793ef4eb935422a9eedbfdb611a2304
--- /dev/null
+++ b/packages/create-llama/templates/components/ui/shadcn/upload-image-preview.tsx
@@ -0,0 +1,32 @@
+import { XCircleIcon } from "lucide-react";
+import Image from "next/image";
+import { cn } from "./lib/utils";
+
+export default function UploadImagePreview({
+  url,
+  onRemove,
+}: {
+  url: string;
+  onRemove: () => void;
+}) {
+  return (
+    <div className="relative w-20 h-20 group">
+      <Image
+        src={url}
+        alt="Uploaded image"
+        fill
+        className="object-cover w-full h-full rounded-xl hover:brightness-75"
+      />
+      <div
+        className={cn(
+          "absolute -top-2 -right-2 w-6 h-6 z-10 bg-gray-500 text-white rounded-full hidden group-hover:block",
+        )}
+      >
+        <XCircleIcon
+          className="w-6 h-6 bg-gray-500 text-white rounded-full"
+          onClick={onRemove}
+        />
+      </div>
+    </div>
+  );
+}
diff --git a/packages/create-llama/templates/index.ts b/packages/create-llama/templates/index.ts
index 2c1f8737706d9f58afb7a269d15c743f2f8f2bc3..fd5377d87115361e4e7c609569928ac4da7686fd 100644
--- a/packages/create-llama/templates/index.ts
+++ b/packages/create-llama/templates/index.ts
@@ -101,6 +101,7 @@ const installTSTemplate = async ({
   eslint,
   customApiPath,
   forBackend,
+  model,
 }: InstallTemplateArgs) => {
   console.log(bold(`Using ${packageManager}.`));
 
@@ -173,6 +174,14 @@ const installTSTemplate = async ({
     });
   }
 
+  if (framework === "nextjs") {
+    await fs.writeFile(
+      path.join(root, "constants.ts"),
+      `export const MODEL = "${model || "gpt-3.5-turbo"}";\n`,
+    );
+    console.log("\nUsing OpenAI model: ", model || "gpt-3.5-turbo", "\n");
+  }
+
   /**
    * Update the package.json scripts.
    */
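For example, with `gpt-4` selected, the `fs.writeFile` call above emits a one-line `constants.ts` in the generated project root:

    export const MODEL = "gpt-4";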
diff --git a/packages/create-llama/templates/types.ts b/packages/create-llama/templates/types.ts
index 926dddd5d97899e0886f6c2fc788906d168363e4..f6af4de02e21cbd018b2b82cf7473ce54e1caa6e 100644
--- a/packages/create-llama/templates/types.ts
+++ b/packages/create-llama/templates/types.ts
@@ -18,4 +18,5 @@ export interface InstallTemplateArgs {
   customApiPath?: string;
   openAIKey?: string;
   forBackend?: string;
+  model: string;
 }
diff --git a/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts b/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts
index 989a5fec484fae3b34060ed20b4c1d5f1e42773f..850ab55f615320b1d43af2d6f98896b4a17b4ef0 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/packages/create-llama/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -1,5 +1,6 @@
+import { MODEL } from "@/constants";
 import { Message, StreamingTextResponse } from "ai";
-import { OpenAI } from "llamaindex";
+import { MessageContent, OpenAI } from "llamaindex";
 import { NextRequest, NextResponse } from "next/server";
 import { createChatEngine } from "./engine";
 import { LlamaIndexStream } from "./llamaindex-stream";
@@ -7,10 +8,29 @@ import { LlamaIndexStream } from "./llamaindex-stream";
 export const runtime = "nodejs";
 export const dynamic = "force-dynamic";
 
+const getLastMessageContent = (
+  textMessage: string,
+  imageUrl: string | undefined,
+): MessageContent => {
+  if (!imageUrl) return textMessage;
+  return [
+    {
+      type: "text",
+      text: textMessage,
+    },
+    {
+      type: "image_url",
+      image_url: {
+        url: imageUrl,
+      },
+    },
+  ];
+};
+
 export async function POST(request: NextRequest) {
   try {
     const body = await request.json();
-    const { messages }: { messages: Message[] } = body;
+    const { messages, data }: { messages: Message[]; data: any } = body;
     const lastMessage = messages.pop();
     if (!messages || !lastMessage || lastMessage.role !== "user") {
       return NextResponse.json(
@@ -23,12 +43,21 @@ export async function POST(request: NextRequest) {
     }
 
     const llm = new OpenAI({
-      model: "gpt-3.5-turbo",
+      model: MODEL,
    });
 
     const chatEngine = await createChatEngine(llm);
 
-    const response = await chatEngine.chat(lastMessage.content, messages, true);
+    const lastMessageContent = getLastMessageContent(
+      lastMessage.content,
+      data?.imageUrl,
+    );
+
+    const response = await chatEngine.chat(
+      lastMessageContent as MessageContent,
+      messages,
+      true,
+    );
 
     // Transform the response into a readable stream
     const stream = LlamaIndexStream(response);
diff --git a/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx b/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx
index 04098fcdf41a4e686459f0ce0d6f3ca311895a9e..791b223f40eb4fe716369efdb4d712dfa415fd8c 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx
+++ b/packages/create-llama/templates/types/streaming/nextjs/app/components/chat-section.tsx
@@ -1,5 +1,6 @@
 "use client";
 
+import { MODEL } from "@/constants";
 import { useChat } from "ai/react";
 import { ChatInput, ChatMessages } from "./ui/chat";
 
@@ -27,6 +28,7 @@ export default function ChatSection() {
         handleSubmit={handleSubmit}
         handleInputChange={handleInputChange}
         isLoading={isLoading}
+        multiModal={MODEL === "gpt-4-vision-preview"}
       />
     </div>
   );
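To illustrate `getLastMessageContent` above (placeholder values): without an image it returns the plain string; with one it returns the two-part multi-modal array.

    getLastMessageContent("hello", undefined);
    // => "hello"

    getLastMessageContent("What is in this image?", "data:image/png;base64,AAAA");
    // => [
    //      { type: "text", text: "What is in this image?" },
    //      { type: "image_url", image_url: { url: "data:image/png;base64,AAAA" } },
    //    ]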
diff --git a/packages/create-llama/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/packages/create-llama/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
index 3eb979b02735943f3f11290c78b84f0e37709438..7c3e87280b03ed571e8fc081a38c15a8d36df1ab 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
+++ b/packages/create-llama/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
@@ -12,6 +12,7 @@ export interface ChatInputProps {
   /** Form submission handler to automatically reset input and append a user message */
   handleSubmit: (e: React.FormEvent<HTMLFormElement>) => void;
   isLoading: boolean;
+  multiModal?: boolean;
 }
 
 export default function ChatInput(props: ChatInputProps) {
diff --git a/packages/create-llama/templates/types/streaming/nextjs/constants.ts b/packages/create-llama/templates/types/streaming/nextjs/constants.ts
new file mode 100644
index 0000000000000000000000000000000000000000..0959a5f6f4f9301f2559cf5b04c36de0ef3afe4a
--- /dev/null
+++ b/packages/create-llama/templates/types/streaming/nextjs/constants.ts
@@ -0,0 +1 @@
+export const MODEL = "gpt-4-vision-preview";
diff --git a/packages/create-llama/templates/types/streaming/nextjs/package.json b/packages/create-llama/templates/types/streaming/nextjs/package.json
index fbac43957dfce816506d1db7c4431352109ca61d..ea58702ea4375ab11d5a717ceef261f876f237bf 100644
--- a/packages/create-llama/templates/types/streaming/nextjs/package.json
+++ b/packages/create-llama/templates/types/streaming/nextjs/package.json
@@ -8,7 +8,7 @@
     "lint": "next lint"
   },
   "dependencies": {
-    "ai": "^2.2.5",
+    "ai": "^2.2.25",
     "llamaindex": "0.0.31",
     "dotenv": "^16.3.1",
     "next": "^13",
@@ -26,4 +26,4 @@
     "tailwindcss": "^3.3",
     "typescript": "^5"
   }
-}
+}
\ No newline at end of file
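End to end: `ChatInput` attaches the preview image as `data.imageUrl` when submitting, the `ai` package forwards it in the request body, and the route above widens the last user message into `MessageContent`. A sketch of the JSON the route consumes (illustrative values; the SDK may send additional fields the route ignores):

    await fetch("/api/chat", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        messages: [{ role: "user", content: "Describe this image." }],
        data: { imageUrl: "data:image/png;base64,AAAA" },
      }),
    });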