From 8890e27a1410c206950372b15d7e887847d1bc32 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Mon, 5 Aug 2024 22:18:20 +0700
Subject: [PATCH] feat: implement index selector for LlamaCloud (#200)

---------

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/quick-walls-switch.md              |   5 +
 helpers/env-variables.ts                      |  17 ++
 helpers/index.ts                              |   1 +
 .../engines/python/agent/__init__.py          |   2 +-
 .../engines/python/chat/__init__.py           |   4 +-
 .../engines/typescript/agent/chat.ts          |   4 +-
 .../engines/typescript/chat/chat.ts           |   4 +-
 .../typescript/streaming/service.ts           |  95 +++++++++--
 .../vectordbs/python/llamacloud/index.py      |   9 +-
 .../components/vectordbs/python/none/index.py |   2 +-
 .../vectordbs/typescript/astra/index.ts       |   2 +-
 .../vectordbs/typescript/chroma/index.ts      |   2 +-
 .../vectordbs/typescript/llamacloud/index.ts  |  26 ++-
 .../vectordbs/typescript/milvus/index.ts      |   2 +-
 .../vectordbs/typescript/mongo/index.ts       |   2 +-
 .../vectordbs/typescript/none/index.ts        |   2 +-
 .../vectordbs/typescript/pg/index.ts          |   2 +-
 .../vectordbs/typescript/pinecone/index.ts    |   2 +-
 .../vectordbs/typescript/qdrant/index.ts      |   2 +-
 .../src/controllers/chat-config.controller.ts |  12 ++
 .../src/controllers/chat.controller.ts        |   4 +-
 .../express/src/routes/chat.route.ts          |   6 +-
 .../streaming/fastapi/app/api/routers/chat.py |  23 ++-
 .../fastapi/app/api/routers/models.py         |   3 +-
 .../fastapi/app/api/services/llama_cloud.py   |  26 +++
 .../streaming/fastapi/app/engine/index.py     |   2 +-
 .../app/api/chat/config/llamacloud/route.ts   |  16 ++
 .../streaming/nextjs/app/api/chat/route.ts    |   4 +-
 .../app/components/ui/chat/chat-input.tsx     |  23 ++-
 .../ui/chat/widgets/LlamaCloudSelector.tsx    | 151 +++++++++++++++++
 .../nextjs/app/components/ui/select.tsx       | 159 ++++++++++++++++++
 templates/types/streaming/nextjs/package.json |   1 +
 32 files changed, 556 insertions(+), 59 deletions(-)
 create mode 100644 .changeset/quick-walls-switch.md
 create mode 100644 templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
 create mode 100644 templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
 create mode 100644 templates/types/streaming/nextjs/app/components/ui/select.tsx

diff --git a/.changeset/quick-walls-switch.md b/.changeset/quick-walls-switch.md
new file mode 100644
index 00000000..31ff2a41
--- /dev/null
+++ b/.changeset/quick-walls-switch.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Let user change indexes in LlamaCloud projects
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index bb04a753..34930cad 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -160,6 +160,17 @@ const getVectorDBEnvs = (
         description:
           "The organization ID for the LlamaCloud project (uses default organization if not specified - Python only)",
       },
+      ...(framework === "nextjs"
+        ? // activate the index selector by default (not needed for non-Next.js backends as it's handled by createFrontendEnvFile)
+          [
+            {
+              name: "NEXT_PUBLIC_USE_LLAMACLOUD",
+              description:
+                "Lets the user change indexes in LlamaCloud projects",
+              value: "true",
+            },
+          ]
+        : []),
     ];
   case "chroma":
     const envs = [
@@ -493,6 +504,7 @@ export const createFrontendEnvFile = async (
   root: string,
   opts: {
     customApiPath?: string;
+    vectorDb?: TemplateVectorDB;
   },
 ) => {
   const defaultFrontendEnvs = [
@@ -503,6 +515,11 @@
         ? opts.customApiPath
         : "http://localhost:8000/api/chat",
     },
+    {
+      name: "NEXT_PUBLIC_USE_LLAMACLOUD",
+      description: "Lets the user change indexes in LlamaCloud projects",
+      value: opts.vectorDb === "llamacloud" ? "true" : "false",
+    },
   ];
   const content = renderEnvVar(defaultFrontendEnvs);
   await fs.writeFile(path.join(root, ".env"), content);
diff --git a/helpers/index.ts b/helpers/index.ts
index ae83f56c..79f3c687 100644
--- a/helpers/index.ts
+++ b/helpers/index.ts
@@ -209,6 +209,7 @@ export const installTemplate = async (
     // this is a frontend for a full-stack app, create .env file with model information
     await createFrontendEnvFile(props.root, {
       customApiPath: props.customApiPath,
+      vectorDb: props.vectorDb,
     });
   }
 };
diff --git a/templates/components/engines/python/agent/__init__.py b/templates/components/engines/python/agent/__init__.py
index 17e36236..fb8d410c 100644
--- a/templates/components/engines/python/agent/__init__.py
+++ b/templates/components/engines/python/agent/__init__.py
@@ -6,7 +6,7 @@ from app.engine.tools import ToolFactory
 from app.engine.index import get_index
 
 
-def get_chat_engine(filters=None):
+def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = os.getenv("TOP_K", "3")
     tools = []
diff --git a/templates/components/engines/python/chat/__init__.py b/templates/components/engines/python/chat/__init__.py
index f885ed13..7d8df55a 100644
--- a/templates/components/engines/python/chat/__init__.py
+++ b/templates/components/engines/python/chat/__init__.py
@@ -3,11 +3,11 @@ from app.engine.index import get_index
 from fastapi import HTTPException
 
 
-def get_chat_engine(filters=None):
+def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = os.getenv("TOP_K", 3)
 
-    index = get_index()
+    index = get_index(params)
     if index is None:
         raise HTTPException(
             status_code=500,
diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts
index f3b43840..c9868ff7 100644
--- a/templates/components/engines/typescript/agent/chat.ts
+++ b/templates/components/engines/typescript/agent/chat.ts
@@ -10,12 +10,12 @@ import path from "node:path";
 import { getDataSource } from "./index";
 import { createTools } from "./tools";
 
-export async function createChatEngine(documentIds?: string[]) {
+export async function createChatEngine(documentIds?: string[], params?: any) {
   const tools: BaseToolWithCall[] = [];
 
   // Add a query engine tool if we have a data source
   // Delete this code if you don't have a data source
-  const index = await getDataSource();
+  const index = await getDataSource(params);
   if (index) {
     tools.push(
       new QueryEngineTool({
diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts
index 1144256a..d2badd71 100644
--- a/templates/components/engines/typescript/chat/chat.ts
+++ b/templates/components/engines/typescript/chat/chat.ts
@@ -6,8 +6,8 @@
 } from "llamaindex";
 import { getDataSource } from "./index";
 
-export async function createChatEngine(documentIds?: string[]) {
-  const index = await getDataSource();
+export async function createChatEngine(documentIds?: string[], params?: any) {
+  const index = await getDataSource(params);
   if (!index) {
     throw new Error(
       `StorageContext is empty - call 'npm run generate' to generate the storage first`,
diff --git a/templates/components/llamaindex/typescript/streaming/service.ts b/templates/components/llamaindex/typescript/streaming/service.ts
index 6b6c4206..91001e91 100644
--- a/templates/components/llamaindex/typescript/streaming/service.ts
+++ b/templates/components/llamaindex/typescript/streaming/service.ts
@@ -7,19 +7,51 @@ const LLAMA_CLOUD_OUTPUT_DIR = "output/llamacloud";
 const LLAMA_CLOUD_BASE_URL = "https://cloud.llamaindex.ai/api/v1";
 const FILE_DELIMITER = "$"; // delimiter between pipelineId and filename
 
-interface LlamaCloudFile {
+type LlamaCloudFile = {
   name: string;
   file_id: string;
   project_id: string;
-}
+};
+
+type LLamaCloudProject = {
+  id: string;
+  organization_id: string;
+  name: string;
+  is_default: boolean;
+};
+
+type LLamaCloudPipeline = {
+  id: string;
+  name: string;
+  project_id: string;
+};
 
 export class LLamaCloudFileService {
+  private static readonly headers = {
+    Accept: "application/json",
+    Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
+  };
+
+  public static async getAllProjectsWithPipelines() {
+    try {
+      const projects = await LLamaCloudFileService.getAllProjects();
+      const pipelines = await LLamaCloudFileService.getAllPipelines();
+      return projects.map((project) => ({
+        ...project,
+        pipelines: pipelines.filter((p) => p.project_id === project.id),
+      }));
+    } catch (error) {
+      console.error("Error listing projects and pipelines:", error);
+      return [];
+    }
+  }
+
   public static async downloadFiles(nodes: NodeWithScore<Metadata>[]) {
-    const files = this.nodesToDownloadFiles(nodes);
+    const files = LLamaCloudFileService.nodesToDownloadFiles(nodes);
     if (!files.length) return;
     console.log("Downloading files from LlamaCloud...");
     for (const file of files) {
-      await this.downloadFile(file.pipelineId, file.fileName);
+      await LLamaCloudFileService.downloadFile(file.pipelineId, file.fileName);
     }
   }
 
@@ -59,13 +91,19 @@
 
   private static async downloadFile(pipelineId: string, fileName: string) {
     try {
-      const downloadedName = this.toDownloadedName(pipelineId, fileName);
+      const downloadedName = LLamaCloudFileService.toDownloadedName(
+        pipelineId,
+        fileName,
+      );
       const downloadedPath = path.join(LLAMA_CLOUD_OUTPUT_DIR, downloadedName);
 
       // Check if file already exists
       if (fs.existsSync(downloadedPath)) return;
 
-      const urlToDownload = await this.getFileUrlByName(pipelineId, fileName);
+      const urlToDownload = await LLamaCloudFileService.getFileUrlByName(
+        pipelineId,
+        fileName,
+      );
       if (!urlToDownload) throw new Error("File not found in LlamaCloud");
 
       const file = fs.createWriteStream(downloadedPath);
@@ -93,10 +131,13 @@
     pipelineId: string,
     name: string,
   ): Promise<string | null> {
-    const files = await this.getAllFiles(pipelineId);
+    const files = await LLamaCloudFileService.getAllFiles(pipelineId);
     const file = files.find((file) => file.name === name);
     if (!file) return null;
-    return await this.getFileUrlById(file.project_id, file.file_id);
+    return await LLamaCloudFileService.getFileUrlById(
+      file.project_id,
+      file.file_id,
+    );
   }
 
   private static async getFileUrlById(
@@ -104,11 +145,10 @@
     fileId: string,
   ): Promise<string> {
     const url = `${LLAMA_CLOUD_BASE_URL}/files/${fileId}/content?project_id=${projectId}`;
-    const headers = {
-      Accept: "application/json",
-      Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
-    };
-    const response = await fetch(url, { method: "GET", headers });
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
     const data = (await response.json()) as { url: string };
     return data.url;
   }
@@ -117,12 +157,31 @@
     pipelineId: string,
   ): Promise<LlamaCloudFile[]> {
     const url = `${LLAMA_CLOUD_BASE_URL}/pipelines/${pipelineId}/files`;
-    const headers = {
-      Accept: "application/json",
-      Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
-    };
-    const response = await fetch(url, { method: "GET", headers });
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
     const data = await response.json();
     return data;
   }
+
+  private static async getAllProjects(): Promise<LLamaCloudProject[]> {
+    const url = `${LLAMA_CLOUD_BASE_URL}/projects`;
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
+    const data = (await response.json()) as LLamaCloudProject[];
+    return data;
+  }
+
+  private static async getAllPipelines(): Promise<LLamaCloudPipeline[]> {
+    const url = `${LLAMA_CLOUD_BASE_URL}/pipelines`;
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
+    const data = (await response.json()) as LLamaCloudPipeline[];
+    return data;
+  }
 }
diff --git a/templates/components/vectordbs/python/llamacloud/index.py b/templates/components/vectordbs/python/llamacloud/index.py
index da73434f..e54e8ca9 100644
--- a/templates/components/vectordbs/python/llamacloud/index.py
+++ b/templates/components/vectordbs/python/llamacloud/index.py
@@ -5,10 +5,11 @@ from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
 
 logger = logging.getLogger("uvicorn")
 
-
-def get_index():
-    name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
-    project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
+def get_index(params=None):
+    configParams = params or {}
+    pipelineConfig = configParams.get("llamaCloudPipeline", {})
+    name = pipelineConfig.get("pipeline", os.getenv("LLAMA_CLOUD_INDEX_NAME"))
+    project_name = pipelineConfig.get("project", os.getenv("LLAMA_CLOUD_PROJECT_NAME"))
     api_key = os.getenv("LLAMA_CLOUD_API_KEY")
     base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
     organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py
index f7949d66..65fd5ad5 100644
--- a/templates/components/vectordbs/python/none/index.py
+++ b/templates/components/vectordbs/python/none/index.py
@@ -17,7 +17,7 @@ def get_storage_context(persist_dir: str) -> StorageContext:
     return StorageContext.from_defaults(persist_dir=persist_dir)
 
 
-def get_index():
+def get_index(params=None):
     storage_dir = os.getenv("STORAGE_DIR", "storage")
     # check if storage already exists
     if not os.path.exists(storage_dir):
diff --git a/templates/components/vectordbs/typescript/astra/index.ts b/templates/components/vectordbs/typescript/astra/index.ts
index e29ed353..38c5bbbd 100644
--- a/templates/components/vectordbs/typescript/astra/index.ts
+++ b/templates/components/vectordbs/typescript/astra/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const store = new AstraDBVectorStore();
   await store.connect(process.env.ASTRA_DB_COLLECTION!);
diff --git a/templates/components/vectordbs/typescript/chroma/index.ts b/templates/components/vectordbs/typescript/chroma/index.ts
index 1d36e643..fbc7b4bf 100644
--- a/templates/components/vectordbs/typescript/chroma/index.ts
+++ b/templates/components/vectordbs/typescript/chroma/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
 
diff --git a/templates/components/vectordbs/typescript/llamacloud/index.ts b/templates/components/vectordbs/typescript/llamacloud/index.ts
index 3f0875cc..413d97b6 100644
--- a/templates/components/vectordbs/typescript/llamacloud/index.ts
+++ b/templates/components/vectordbs/typescript/llamacloud/index.ts
@@ -1,12 +1,26 @@
 import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
-import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
-  checkRequiredEnvVars();
+type LlamaCloudDataSourceParams = {
+  llamaCloudPipeline?: {
+    project: string;
+    pipeline: string;
+  };
+};
+
+export async function getDataSource(params?: LlamaCloudDataSourceParams) {
+  const { project, pipeline } = params?.llamaCloudPipeline ?? {};
+  const projectName = project ?? process.env.LLAMA_CLOUD_PROJECT_NAME;
+  const pipelineName = pipeline ?? process.env.LLAMA_CLOUD_INDEX_NAME;
+  const apiKey = process.env.LLAMA_CLOUD_API_KEY;
+  if (!projectName || !pipelineName || !apiKey) {
+    throw new Error(
+      "Set project, pipeline, and API key in the params or as environment variables.",
+    );
+  }
   const index = new LlamaCloudIndex({
-    name: process.env.LLAMA_CLOUD_INDEX_NAME!,
-    projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
-    apiKey: process.env.LLAMA_CLOUD_API_KEY,
+    name: pipelineName,
+    projectName,
+    apiKey,
     baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
   });
   return index;
diff --git a/templates/components/vectordbs/typescript/milvus/index.ts b/templates/components/vectordbs/typescript/milvus/index.ts
index c290175f..91275b11 100644
--- a/templates/components/vectordbs/typescript/milvus/index.ts
+++ b/templates/components/vectordbs/typescript/milvus/index.ts
@@ -2,7 +2,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
 import { checkRequiredEnvVars, getMilvusClient } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const milvusClient = getMilvusClient();
   const store = new MilvusVectorStore({ milvusClient });
diff --git a/templates/components/vectordbs/typescript/mongo/index.ts b/templates/components/vectordbs/typescript/mongo/index.ts
index effb8f92..efa35fa5 100644
--- a/templates/components/vectordbs/typescript/mongo/index.ts
+++ b/templates/components/vectordbs/typescript/mongo/index.ts
@@ -4,7 +4,7 @@ import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDB
 import { MongoClient } from "mongodb";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const client = new MongoClient(process.env.MONGO_URI!);
   const store = new MongoDBAtlasVectorSearch({
diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts
index 64b28975..fecc76f4 100644
--- a/templates/components/vectordbs/typescript/none/index.ts
+++ b/templates/components/vectordbs/typescript/none/index.ts
@@ -2,7 +2,7 @@ import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
 import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
 import { STORAGE_CACHE_DIR } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   const storageContext = await storageContextFromDefaults({
     persistDir: `${STORAGE_CACHE_DIR}`,
   });
diff --git a/templates/components/vectordbs/typescript/pg/index.ts b/templates/components/vectordbs/typescript/pg/index.ts
index 787cae74..75bcd403 100644
--- a/templates/components/vectordbs/typescript/pg/index.ts
+++ b/templates/components/vectordbs/typescript/pg/index.ts
@@ -7,7 +7,7 @@ import {
   checkRequiredEnvVars,
 } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const pgvs = new PGVectorStore({
     connectionString: process.env.PG_CONNECTION_STRING,
diff --git a/templates/components/vectordbs/typescript/pinecone/index.ts b/templates/components/vectordbs/typescript/pinecone/index.ts
index 15072cff..66a22d46 100644
--- a/templates/components/vectordbs/typescript/pinecone/index.ts
+++ b/templates/components/vectordbs/typescript/pinecone/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const store = new PineconeVectorStore();
   return await VectorStoreIndex.fromVectorStore(store);
diff --git a/templates/components/vectordbs/typescript/qdrant/index.ts b/templates/components/vectordbs/typescript/qdrant/index.ts
index 0233d088..a9d87ab8 100644
--- a/templates/components/vectordbs/typescript/qdrant/index.ts
+++ b/templates/components/vectordbs/typescript/qdrant/index.ts
@@ -5,7 +5,7 @@ import { checkRequiredEnvVars, getQdrantClient } from "./shared";
 
 dotenv.config();
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const collectionName = process.env.QDRANT_COLLECTION;
   const store = new QdrantVectorStore({
diff --git a/templates/types/streaming/express/src/controllers/chat-config.controller.ts b/templates/types/streaming/express/src/controllers/chat-config.controller.ts
index 4481e10d..af843c2c 100644
--- a/templates/types/streaming/express/src/controllers/chat-config.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat-config.controller.ts
@@ -1,4 +1,5 @@
 import { Request, Response } from "express";
+import { LLamaCloudFileService } from "./llamaindex/streaming/service";
 
 export const chatConfig = async (_req: Request, res: Response) => {
   let starterQuestions = undefined;
@@ -12,3 +13,14 @@ export const chatConfig = async (_req: Request, res: Response) => {
     starterQuestions,
   });
 };
+
+export const chatLlamaCloudConfig = async (_req: Request, res: Response) => {
+  const config = {
+    projects: await LLamaCloudFileService.getAllProjectsWithPipelines(),
+    pipeline: {
+      pipeline: process.env.LLAMA_CLOUD_INDEX_NAME,
+      project: process.env.LLAMA_CLOUD_PROJECT_NAME,
+    },
+  };
+  return res.status(200).json(config);
+};
diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts
index 95228e8d..50b70789 100644
--- a/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -17,7 +17,7 @@ export const chat = async (req: Request, res: Response) => {
   const vercelStreamData = new StreamData();
   const streamTimeout = createStreamTimeout(vercelStreamData);
   try {
-    const { messages }: { messages: Message[] } = req.body;
+    const { messages, data }: { messages: Message[]; data?: any } = req.body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
@@ -46,7 +46,7 @@ export const chat = async (req: Request, res: Response) => {
       },
     );
     const ids = retrieveDocumentIds(allAnnotations);
-    const chatEngine = await createChatEngine(ids);
+    const chatEngine = await createChatEngine(ids, data);
 
     // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
     const userMessageContent = convertMessageContent(
diff --git a/templates/types/streaming/express/src/routes/chat.route.ts b/templates/types/streaming/express/src/routes/chat.route.ts
index 96711d50..5044efcb 100644
--- a/templates/types/streaming/express/src/routes/chat.route.ts
+++ b/templates/types/streaming/express/src/routes/chat.route.ts
@@ -1,5 +1,8 @@
 import express, { Router } from "express";
-import { chatConfig } from "../controllers/chat-config.controller";
+import {
+  chatConfig,
+  chatLlamaCloudConfig,
+} from "../controllers/chat-config.controller";
 import { chatRequest } from "../controllers/chat-request.controller";
 import { chatUpload } from "../controllers/chat-upload.controller";
 import { chat } from "../controllers/chat.controller";
@@ -11,6 +14,7 @@ initSettings();
 llmRouter.route("/").post(chat);
 llmRouter.route("/request").post(chatRequest);
 llmRouter.route("/config").get(chatConfig);
+llmRouter.route("/config/llamacloud").get(chatLlamaCloudConfig);
 llmRouter.route("/upload").post(chatUpload);
 
 export default llmRouter;
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index e2cffadf..cb7036d9 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -52,8 +52,9 @@ async def chat(
 
         doc_ids = data.get_chat_document_ids()
         filters = generate_filters(doc_ids)
+        params = data.data or {}
         logger.info("Creating chat engine with filters", filters.dict())
-        chat_engine = get_chat_engine(filters=filters)
+        chat_engine = get_chat_engine(filters=filters, params=params)
 
         event_handler = EventCallbackHandler()
         chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore
@@ -125,3 +126,23 @@ async def chat_config() -> ChatConfig:
     if conversation_starters and conversation_starters.strip():
         starter_questions = conversation_starters.strip().split("\n")
     return ChatConfig(starter_questions=starter_questions)
+
+
+@r.get("/config/llamacloud")
+async def chat_llama_cloud_config():
+    projects = LLamaCloudFileService.get_all_projects_with_pipelines()
+    pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
+    project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
+    pipeline_config = (
+        pipeline
+        and project
+        and {
+            "pipeline": pipeline,
+            "project": project,
+        }
+        or None
+    )
+    return {
+        "projects": projects,
+        "pipeline": pipeline_config,
+    }
diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py
index b510e848..c9ea1adb 100644
--- a/templates/types/streaming/fastapi/app/api/routers/models.py
+++ b/templates/types/streaming/fastapi/app/api/routers/models.py
@@ -75,6 +75,7 @@ class Message(BaseModel):
 
 class ChatData(BaseModel):
     messages: List[Message]
+    data: Any = None
 
     class Config:
         json_schema_extra = {
@@ -237,7 +238,7 @@ class ChatConfig(BaseModel):
     starter_questions: Optional[List[str]] = Field(
         default=None,
         description="List of starter questions",
-        serialization_alias="starterQuestions"
+        serialization_alias="starterQuestions",
     )
 
     class Config:
diff --git a/templates/types/streaming/fastapi/app/api/services/llama_cloud.py b/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
index ea03e64b..852ae7ce 100644
--- a/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
+++ b/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
@@ -14,6 +14,32 @@ class LLamaCloudFileService:
 
     DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
 
+    @classmethod
+    def get_all_projects(cls) -> List[Dict[str, Any]]:
+        url = f"{cls.LLAMA_CLOUD_URL}/projects"
+        return cls._make_request(url)
+
+    @classmethod
+    def get_all_pipelines(cls) -> List[Dict[str, Any]]:
+        url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
+        return cls._make_request(url)
+
+    @classmethod
+    def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
+        try:
+            projects = cls.get_all_projects()
+            pipelines = cls.get_all_pipelines()
+            return [
+                {
+                    **project,
+                    "pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
+                }
+                for project in projects
+            ]
+        except Exception as error:
+            logger.error(f"Error listing projects and pipelines: {error}")
+            return []
+
     @classmethod
     def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
         url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
diff --git a/templates/types/streaming/fastapi/app/engine/index.py b/templates/types/streaming/fastapi/app/engine/index.py
index 2dbc589b..e1adcb80 100644
--- a/templates/types/streaming/fastapi/app/engine/index.py
+++ b/templates/types/streaming/fastapi/app/engine/index.py
@@ -6,7 +6,7 @@ from app.engine.vectordb import get_vector_store
 logger = logging.getLogger("uvicorn")
 
 
-def get_index():
+def get_index(params=None):
     logger.info("Connecting vector store...")
     store = get_vector_store()
     # Load the index from the vector store
diff --git a/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts b/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
new file mode 100644
index 00000000..e40409f1
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
@@ -0,0 +1,16 @@
+import { NextResponse } from "next/server";
+import { LLamaCloudFileService } from "../../llamaindex/streaming/service";
+
+/**
+ * This API exposes config from the backend environment variables to the frontend.
+ */
+export async function GET() {
+  const config = {
+    projects: await LLamaCloudFileService.getAllProjectsWithPipelines(),
+    pipeline: {
+      pipeline: process.env.LLAMA_CLOUD_INDEX_NAME,
+      project: process.env.LLAMA_CLOUD_PROJECT_NAME,
+    },
+  };
+  return NextResponse.json(config, { status: 200 });
+}
diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts
index 792ecb7b..adfccf13 100644
--- a/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -27,7 +27,7 @@ export async function POST(request: NextRequest) {
 
   try {
     const body = await request.json();
-    const { messages }: { messages: Message[] } = body;
+    const { messages, data }: { messages: Message[]; data?: any } = body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return NextResponse.json(
@@ -59,7 +59,7 @@ export async function POST(request: NextRequest) {
       },
     );
     const ids = retrieveDocumentIds(allAnnotations);
-    const chatEngine = await createChatEngine(ids);
+    const chatEngine = await createChatEngine(ids, data);
 
     // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
     const userMessageContent = convertMessageContent(
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
index 01c7c0b1..4c582966 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
@@ -1,4 +1,5 @@
 import { JSONValue } from "ai";
+import { useState } from "react";
 import { Button } from "../button";
 import { DocumentPreview } from "../document-preview";
 import FileUploader from "../file-uploader";
@@ -6,6 +7,7 @@ import { Input } from "../input";
 import UploadImagePreview from "../upload-image-preview";
 import { ChatHandler } from "./chat.interface";
 import { useFile } from "./hooks/use-file";
+import { LlamaCloudSelector } from "./widgets/LlamaCloudSelector";
 
 const ALLOWED_EXTENSIONS = ["png", "jpg", "jpeg", "csv", "pdf", "txt", "docx"];
 
@@ -34,6 +36,7 @@ export default function ChatInput(
     reset,
     getAnnotations,
   } = useFile();
+  const [requestData, setRequestData] = useState<any>();
 
   // default submit function does not handle including annotations in the message
   // so we need to use append function to submit new message with annotations
@@ -42,12 +45,15 @@ export default function ChatInput(
     annotations: JSONValue[] | undefined,
   ) => {
     e.preventDefault();
-    props.append!({
-      content: props.input,
-      role: "user",
-      createdAt: new Date(),
-      annotations,
-    });
+    props.append!(
+      {
+        content: props.input,
+        role: "user",
+        createdAt: new Date(),
+        annotations,
+      },
+      { data: requestData },
+    );
     props.setInput!("");
   };
 
@@ -57,7 +63,7 @@ export default function ChatInput(
       handleSubmitWithAnnotations(e, annotations);
       return reset();
     }
-    props.handleSubmit(e);
+    props.handleSubmit(e, { data: requestData });
   };
 
   const handleUploadFile = async (file: File) => {
@@ -109,6 +115,9 @@ export default function ChatInput(
           disabled: props.isLoading,
         }}
       />
+      {process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && (
+        <LlamaCloudSelector setRequestData={setRequestData} />
+      )}
       <Button type="submit" disabled={props.isLoading || !props.input.trim()}>
         Send message
       </Button>
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
new file mode 100644
index 00000000..aa995c91
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
@@ -0,0 +1,151 @@
+import { Loader2 } from "lucide-react";
+import { useEffect, useState } from "react";
+import {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectTrigger,
+  SelectValue,
+} from "../../select";
"../hooks/use-config"; + +type LLamaCloudPipeline = { + id: string; + name: string; +}; + +type LLamaCloudProject = { + id: string; + organization_id: string; + name: string; + is_default: boolean; + pipelines: Array<LLamaCloudPipeline>; +}; + +type PipelineConfig = { + project: string; // project name + pipeline: string; // pipeline name +}; + +type LlamaCloudConfig = { + projects?: LLamaCloudProject[]; + pipeline?: PipelineConfig; +}; + +export interface LlamaCloudSelectorProps { + setRequestData: React.Dispatch<any>; +} + +export function LlamaCloudSelector({ + setRequestData, +}: LlamaCloudSelectorProps) { + const { backend } = useClientConfig(); + const [config, setConfig] = useState<LlamaCloudConfig>(); + + useEffect(() => { + if (process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && !config) { + fetch(`${backend}/api/chat/config/llamacloud`) + .then((response) => response.json()) + .then((data) => { + setConfig(data); + setRequestData({ + llamaCloudPipeline: data.pipeline, + }); + }) + .catch((error) => console.error("Error fetching config", error)); + } + }, [backend, config, setRequestData]); + + const setPipeline = (pipelineConfig?: PipelineConfig) => { + setConfig((prevConfig: any) => ({ + ...prevConfig, + pipeline: pipelineConfig, + })); + setRequestData((prevData: any) => { + if (!prevData) return { llamaCloudPipeline: pipelineConfig }; + return { + ...prevData, + llamaCloudPipeline: pipelineConfig, + }; + }); + }; + + const handlePipelineSelect = async (value: string) => { + setPipeline(JSON.parse(value) as PipelineConfig); + }; + + if (!config) { + return ( + <div className="flex justify-center items-center p-3"> + <Loader2 className="h-4 w-4 animate-spin" /> + </div> + ); + } + if (!isValid(config)) { + return ( + <p className="text-red-500"> + Invalid LlamaCloud configuration. Check console logs. + </p> + ); + } + const { projects, pipeline } = config; + + return ( + <Select + onValueChange={handlePipelineSelect} + defaultValue={JSON.stringify(pipeline)} + > + <SelectTrigger className="w-[200px]"> + <SelectValue placeholder="Select a pipeline" /> + </SelectTrigger> + <SelectContent> + {projects!.map((project: LLamaCloudProject) => ( + <SelectGroup key={project.id}> + <SelectLabel className="capitalize"> + Project: {project.name} + </SelectLabel> + {project.pipelines.map((pipeline) => ( + <SelectItem + key={pipeline.id} + className="last:border-b" + value={JSON.stringify({ + pipeline: pipeline.name, + project: project.name, + })} + > + <span className="pl-2">{pipeline.name}</span> + </SelectItem> + ))} + </SelectGroup> + ))} + </SelectContent> + </Select> + ); +} + +function isValid(config: LlamaCloudConfig): boolean { + const { projects, pipeline } = config; + if (!projects?.length) return false; + if (!pipeline) return false; + const matchedProject = projects.find( + (project: LLamaCloudProject) => project.name === pipeline.project, + ); + if (!matchedProject) { + console.error( + `LlamaCloud project ${pipeline.project} not found. Check LLAMA_CLOUD_PROJECT_NAME variable`, + ); + return false; + } + const pipelineExists = matchedProject.pipelines.some( + (p) => p.name === pipeline.pipeline, + ); + if (!pipelineExists) { + console.error( + `LlamaCloud pipeline ${pipeline.pipeline} not found. 
+      `LlamaCloud pipeline ${pipeline.pipeline} not found. Check LLAMA_CLOUD_INDEX_NAME variable`,
+    );
+    return false;
+  }
+  return true;
+}
diff --git a/templates/types/streaming/nextjs/app/components/ui/select.tsx b/templates/types/streaming/nextjs/app/components/ui/select.tsx
new file mode 100644
index 00000000..c01b068b
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/select.tsx
@@ -0,0 +1,159 @@
+"use client";
+
+import * as SelectPrimitive from "@radix-ui/react-select";
+import { Check, ChevronDown, ChevronUp } from "lucide-react";
+import * as React from "react";
+import { cn } from "./lib/utils";
+
+const Select = SelectPrimitive.Root;
+
+const SelectGroup = SelectPrimitive.Group;
+
+const SelectValue = SelectPrimitive.Value;
+
+const SelectTrigger = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Trigger>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
+>(({ className, children, ...props }, ref) => (
+  <SelectPrimitive.Trigger
+    ref={ref}
+    className={cn(
+      "flex h-10 w-full items-center justify-between rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
+      className,
+    )}
+    {...props}
+  >
+    {children}
+    <SelectPrimitive.Icon asChild>
+      <ChevronDown className="h-4 w-4 opacity-50" />
+    </SelectPrimitive.Icon>
+  </SelectPrimitive.Trigger>
+));
+SelectTrigger.displayName = SelectPrimitive.Trigger.displayName;
+
+const SelectScrollUpButton = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.ScrollUpButton
+    ref={ref}
+    className={cn(
+      "flex cursor-default items-center justify-center py-1",
+      className,
+    )}
+    {...props}
+  >
+    <ChevronUp className="h-4 w-4" />
+  </SelectPrimitive.ScrollUpButton>
+));
+SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName;
+
+const SelectScrollDownButton = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.ScrollDownButton
+    ref={ref}
+    className={cn(
+      "flex cursor-default items-center justify-center py-1",
+      className,
+    )}
+    {...props}
+  >
+    <ChevronDown className="h-4 w-4" />
+  </SelectPrimitive.ScrollDownButton>
+));
+SelectScrollDownButton.displayName =
+  SelectPrimitive.ScrollDownButton.displayName;
+
+const SelectContent = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Content>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
+>(({ className, children, position = "popper", ...props }, ref) => (
+  <SelectPrimitive.Portal>
+    <SelectPrimitive.Content
+      ref={ref}
+      className={cn(
+        "relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
+        position === "popper" &&
+          "data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1",
+        className,
+      )}
+      position={position}
+      {...props}
+    >
+      <SelectScrollUpButton />
+      <SelectPrimitive.Viewport
+        className={cn(
+          "p-1",
+          position === "popper" &&
+            "h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]",
+        )}
+      >
+        {children}
+      </SelectPrimitive.Viewport>
+      <SelectScrollDownButton />
+    </SelectPrimitive.Content>
+  </SelectPrimitive.Portal>
+));
+SelectContent.displayName = SelectPrimitive.Content.displayName;
+
+const SelectLabel = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Label>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.Label
+    ref={ref}
+    className={cn("py-1.5 pl-8 pr-2 text-sm font-semibold", className)}
+    {...props}
+  />
+));
+SelectLabel.displayName = SelectPrimitive.Label.displayName;
+
+const SelectItem = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Item>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
+>(({ className, children, ...props }, ref) => (
+  <SelectPrimitive.Item
+    ref={ref}
+    className={cn(
+      "relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
+      className,
+    )}
+    {...props}
+  >
+    <span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
+      <SelectPrimitive.ItemIndicator>
+        <Check className="h-4 w-4" />
+      </SelectPrimitive.ItemIndicator>
+    </span>
+
+    <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
+  </SelectPrimitive.Item>
+));
+SelectItem.displayName = SelectPrimitive.Item.displayName;
+
+const SelectSeparator = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Separator>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.Separator
+    ref={ref}
+    className={cn("-mx-1 my-1 h-px bg-muted", className)}
+    {...props}
+  />
+));
+SelectSeparator.displayName = SelectPrimitive.Separator.displayName;
+
+export {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectScrollDownButton,
+  SelectScrollUpButton,
+  SelectSeparator,
+  SelectTrigger,
+  SelectValue,
+};
diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json
index 5d429a41..b0b8bd57 100644
--- a/templates/types/streaming/nextjs/package.json
+++ b/templates/types/streaming/nextjs/package.json
@@ -15,6 +15,7 @@
     "@llamaindex/pdf-viewer": "^1.1.3",
     "@radix-ui/react-collapsible": "^1.0.3",
     "@radix-ui/react-hover-card": "^1.0.7",
+    "@radix-ui/react-select": "^2.1.1",
    "@radix-ui/react-slot": "^1.0.2",
     "ai": "^3.0.21",
     "ajv": "^8.12.0",
-- 
GitLab
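
P.S. A minimal sketch (not part of the patch) of how a client could exercise the new endpoints: fetch the LlamaCloud config, then pin the selected pipeline on a chat request via the `data` field. The base URL and the helper name are illustrative assumptions; the paths and payload shapes follow the routes added above.

// chat-with-pipeline.ts — hypothetical client-side sketch, assuming the
// Next.js template is running at http://localhost:3000 (adjust as needed).
type PipelineConfig = { project: string; pipeline: string };

async function chatWithSelectedPipeline(message: string): Promise<Response> {
  // GET /api/chat/config/llamacloud returns { projects, pipeline } (route added above)
  const configRes = await fetch(
    "http://localhost:3000/api/chat/config/llamacloud",
  );
  const config = (await configRes.json()) as { pipeline?: PipelineConfig };

  // POST /api/chat: the `data` field is forwarded to createChatEngine(ids, data),
  // which passes it to getDataSource(params) where params.llamaCloudPipeline is read
  return fetch("http://localhost:3000/api/chat", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      messages: [{ role: "user", content: message }],
      data: { llamaCloudPipeline: config.pipeline },
    }),
  });
}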