diff --git a/.changeset/quick-walls-switch.md b/.changeset/quick-walls-switch.md
new file mode 100644
index 0000000000000000000000000000000000000000..31ff2a41a1c550990008ab6e66bb0027e046298c
--- /dev/null
+++ b/.changeset/quick-walls-switch.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Let users change indexes in LlamaCloud projects
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index bb04a75392aa52cf4cd39dd5ccdf0477a26f0f7c..34930cad5b26d2e60fcb62b2410402647af4e45f 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -160,6 +160,17 @@ const getVectorDBEnvs = (
         description:
           "The organization ID for the LlamaCloud project (uses default organization if not specified - Python only)",
       },
+      ...(framework === "nextjs"
+        ? // activate the index selector by default (not needed for non-Next.js backends as it's handled by createFrontendEnvFile)
+          [
+            {
+              name: "NEXT_PUBLIC_USE_LLAMACLOUD",
+              description:
+                "Lets the user change indexes in LlamaCloud projects",
+              value: "true",
+            },
+          ]
+        : []),
     ];
   case "chroma":
     const envs = [
@@ -493,6 +504,7 @@ export const createFrontendEnvFile = async (
   root: string,
   opts: {
     customApiPath?: string;
+    vectorDb?: TemplateVectorDB;
   },
 ) => {
   const defaultFrontendEnvs = [
@@ -503,6 +515,11 @@ export const createFrontendEnvFile = async (
         ? opts.customApiPath
         : "http://localhost:8000/api/chat",
     },
+    {
+      name: "NEXT_PUBLIC_USE_LLAMACLOUD",
+      description: "Lets the user change indexes in LlamaCloud projects",
+      value: opts.vectorDb === "llamacloud" ? "true" : "false",
+    },
   ];
   const content = renderEnvVar(defaultFrontendEnvs);
   await fs.writeFile(path.join(root, ".env"), content);
diff --git a/helpers/index.ts b/helpers/index.ts
index ae83f56cc7f3840f3ebaefd8f8e7f32d98182d14..79f3c6873910614c265ebd83601fb33682e18144 100644
--- a/helpers/index.ts
+++ b/helpers/index.ts
@@ -209,6 +209,7 @@ export const installTemplate = async (
     // this is a frontend for a full-stack app, create .env file with model information
     await createFrontendEnvFile(props.root, {
       customApiPath: props.customApiPath,
+      vectorDb: props.vectorDb,
     });
   }
 };
diff --git a/templates/components/engines/python/agent/__init__.py b/templates/components/engines/python/agent/__init__.py
index 17e36236be01489a90fa9d49faec6937227dc1ad..fb8d410c600aa2167f39cf37f813fc073636f871 100644
--- a/templates/components/engines/python/agent/__init__.py
+++ b/templates/components/engines/python/agent/__init__.py
@@ -6,7 +6,7 @@ from app.engine.tools import ToolFactory
 from app.engine.index import get_index
 
 
-def get_chat_engine(filters=None):
+def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = os.getenv("TOP_K", "3")
     tools = []
diff --git a/templates/components/engines/python/chat/__init__.py b/templates/components/engines/python/chat/__init__.py
index f885ed13220621bde2126de42a69a89500ed63ee..7d8df55ae0074af3f695df0a1cd3ede7ea999053 100644
--- a/templates/components/engines/python/chat/__init__.py
+++ b/templates/components/engines/python/chat/__init__.py
@@ -3,11 +3,11 @@ from app.engine.index import get_index
 from fastapi import HTTPException
 
 
-def get_chat_engine(filters=None):
+def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = os.getenv("TOP_K", 3)
 
-    index = get_index()
+    index = get_index(params)
     if index is None:
         raise HTTPException(
             status_code=500,
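Note: the `NEXT_PUBLIC_USE_LLAMACLOUD` flag is written at project-generation time. For the Next.js full-stack template it is appended to the backend env list above; Express/FastAPI projects with a Next.js frontend get it from `createFrontendEnvFile`, which derives it from the chosen vector DB. A minimal sketch of that decision (the `EnvVar` shape is assumed from this diff's context, not exported by the helper):

```ts
type EnvVar = { name: string; description?: string; value?: string };

// Hypothetical helper mirroring the logic above: the flag is "true"
// only when the generated project uses the LlamaCloud vector DB.
function llamaCloudEnv(vectorDb?: string): EnvVar {
  return {
    name: "NEXT_PUBLIC_USE_LLAMACLOUD",
    description: "Lets the user change indexes in LlamaCloud projects",
    value: vectorDb === "llamacloud" ? "true" : "false",
  };
}
```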
diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts
index f3b43840854f848cdfb81c2a1211096fed70fa76..c9868ff71d1545c7500e41f29a8f45a57eb199c4 100644
--- a/templates/components/engines/typescript/agent/chat.ts
+++ b/templates/components/engines/typescript/agent/chat.ts
@@ -10,12 +10,12 @@ import path from "node:path";
 import { getDataSource } from "./index";
 import { createTools } from "./tools";
 
-export async function createChatEngine(documentIds?: string[]) {
+export async function createChatEngine(documentIds?: string[], params?: any) {
   const tools: BaseToolWithCall[] = [];
 
   // Add a query engine tool if we have a data source
   // Delete this code if you don't have a data source
-  const index = await getDataSource();
+  const index = await getDataSource(params);
   if (index) {
     tools.push(
       new QueryEngineTool({
diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts
index 1144256a59c124e1901ab396bce9850b11484b11..d2badd713c34a73b80c3d3b97279a4984bb7b9ce 100644
--- a/templates/components/engines/typescript/chat/chat.ts
+++ b/templates/components/engines/typescript/chat/chat.ts
@@ -6,8 +6,8 @@ import {
 } from "llamaindex";
 import { getDataSource } from "./index";
 
-export async function createChatEngine(documentIds?: string[]) {
-  const index = await getDataSource();
+export async function createChatEngine(documentIds?: string[], params?: any) {
+  const index = await getDataSource(params);
   if (!index) {
     throw new Error(
       `StorageContext is empty - call 'npm run generate' to generate the storage first`,
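Note: both engine variants now accept an optional second argument and forward it to `getDataSource`. A hypothetical call site (the project/pipeline names are placeholders; omitting the argument keeps the env-based defaults):

```ts
import { createChatEngine } from "./chat";

async function example() {
  // Overrides LLAMA_CLOUD_PROJECT_NAME / LLAMA_CLOUD_INDEX_NAME for this request
  const chatEngine = await createChatEngine(undefined, {
    llamaCloudPipeline: { project: "Default", pipeline: "my-pipeline" },
  });
  return chatEngine;
}
```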
diff --git a/templates/components/llamaindex/typescript/streaming/service.ts b/templates/components/llamaindex/typescript/streaming/service.ts
index 6b6c4206ceff34ac44adf0ffdbcdebaad02bf7e7..91001e9166feba1dd29fc4d8efe1ce27b116e84f 100644
--- a/templates/components/llamaindex/typescript/streaming/service.ts
+++ b/templates/components/llamaindex/typescript/streaming/service.ts
@@ -7,19 +7,51 @@ const LLAMA_CLOUD_OUTPUT_DIR = "output/llamacloud";
 const LLAMA_CLOUD_BASE_URL = "https://cloud.llamaindex.ai/api/v1";
 const FILE_DELIMITER = "$"; // delimiter between pipelineId and filename
 
-interface LlamaCloudFile {
+type LlamaCloudFile = {
   name: string;
   file_id: string;
   project_id: string;
-}
+};
+
+type LLamaCloudProject = {
+  id: string;
+  organization_id: string;
+  name: string;
+  is_default: boolean;
+};
+
+type LLamaCloudPipeline = {
+  id: string;
+  name: string;
+  project_id: string;
+};
 
 export class LLamaCloudFileService {
+  private static readonly headers = {
+    Accept: "application/json",
+    Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
+  };
+
+  public static async getAllProjectsWithPipelines() {
+    try {
+      const projects = await LLamaCloudFileService.getAllProjects();
+      const pipelines = await LLamaCloudFileService.getAllPipelines();
+      return projects.map((project) => ({
+        ...project,
+        pipelines: pipelines.filter((p) => p.project_id === project.id),
+      }));
+    } catch (error) {
+      console.error("Error listing projects and pipelines:", error);
+      return [];
+    }
+  }
+
   public static async downloadFiles(nodes: NodeWithScore<Metadata>[]) {
-    const files = this.nodesToDownloadFiles(nodes);
+    const files = LLamaCloudFileService.nodesToDownloadFiles(nodes);
     if (!files.length) return;
     console.log("Downloading files from LlamaCloud...");
     for (const file of files) {
-      await this.downloadFile(file.pipelineId, file.fileName);
+      await LLamaCloudFileService.downloadFile(file.pipelineId, file.fileName);
     }
   }
 
@@ -59,13 +91,19 @@ export class LLamaCloudFileService {
   private static async downloadFile(pipelineId: string, fileName: string) {
     try {
-      const downloadedName = this.toDownloadedName(pipelineId, fileName);
+      const downloadedName = LLamaCloudFileService.toDownloadedName(
+        pipelineId,
+        fileName,
+      );
       const downloadedPath = path.join(LLAMA_CLOUD_OUTPUT_DIR, downloadedName);
 
       // Check if file already exists
       if (fs.existsSync(downloadedPath)) return;
 
-      const urlToDownload = await this.getFileUrlByName(pipelineId, fileName);
+      const urlToDownload = await LLamaCloudFileService.getFileUrlByName(
+        pipelineId,
+        fileName,
+      );
       if (!urlToDownload) throw new Error("File not found in LlamaCloud");
 
       const file = fs.createWriteStream(downloadedPath);
@@ -93,10 +131,13 @@ export class LLamaCloudFileService {
     pipelineId: string,
     name: string,
   ): Promise<string | null> {
-    const files = await this.getAllFiles(pipelineId);
+    const files = await LLamaCloudFileService.getAllFiles(pipelineId);
     const file = files.find((file) => file.name === name);
     if (!file) return null;
-    return await this.getFileUrlById(file.project_id, file.file_id);
+    return await LLamaCloudFileService.getFileUrlById(
+      file.project_id,
+      file.file_id,
+    );
   }
 
   private static async getFileUrlById(
     projectId: string,
     fileId: string,
   ): Promise<string> {
     const url = `${LLAMA_CLOUD_BASE_URL}/files/${fileId}/content?project_id=${projectId}`;
-    const headers = {
-      Accept: "application/json",
-      Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
-    };
-    const response = await fetch(url, { method: "GET", headers });
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
     const data = (await response.json()) as { url: string };
     return data.url;
   }
 
@@ -117,12 +157,31 @@ export class LLamaCloudFileService {
     pipelineId: string,
   ): Promise<LlamaCloudFile[]> {
     const url = `${LLAMA_CLOUD_BASE_URL}/pipelines/${pipelineId}/files`;
-    const headers = {
-      Accept: "application/json",
-      Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
-    };
-    const response = await fetch(url, { method: "GET", headers });
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
     const data = await response.json();
     return data;
   }
+
+  private static async getAllProjects(): Promise<LLamaCloudProject[]> {
+    const url = `${LLAMA_CLOUD_BASE_URL}/projects`;
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
+    const data = (await response.json()) as LLamaCloudProject[];
+    return data;
+  }
+
+  private static async getAllPipelines(): Promise<LLamaCloudPipeline[]> {
+    const url = `${LLAMA_CLOUD_BASE_URL}/pipelines`;
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
+    const data = (await response.json()) as LLamaCloudPipeline[];
+    return data;
+  }
 }
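Note: `getAllProjectsWithPipelines` joins the `/projects` and `/pipelines` listings by `project_id`. A usage sketch based on the types declared above:

```ts
import { LLamaCloudFileService } from "./service";

async function listProjects() {
  const projects = await LLamaCloudFileService.getAllProjectsWithPipelines();
  for (const project of projects) {
    // e.g. "Default: my-pipeline, docs-pipeline"
    const names = project.pipelines.map((p) => p.name).join(", ");
    console.log(`${project.name}: ${names}`);
  }
}
```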
diff --git a/templates/components/vectordbs/python/llamacloud/index.py b/templates/components/vectordbs/python/llamacloud/index.py
index da73434f25b83b4531fedcdeb208e38a0c1d048e..e54e8ca9df634913a21df411aef29c09d88f22dd 100644
--- a/templates/components/vectordbs/python/llamacloud/index.py
+++ b/templates/components/vectordbs/python/llamacloud/index.py
@@ -5,10 +5,11 @@ from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
 
 logger = logging.getLogger("uvicorn")
 
-
-def get_index():
-    name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
-    project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
+def get_index(params=None):
+    config_params = params or {}
+    pipeline_config = config_params.get("llamaCloudPipeline", {})
+    name = pipeline_config.get("pipeline", os.getenv("LLAMA_CLOUD_INDEX_NAME"))
+    project_name = pipeline_config.get("project", os.getenv("LLAMA_CLOUD_PROJECT_NAME"))
     api_key = os.getenv("LLAMA_CLOUD_API_KEY")
     base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
     organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py
index f7949d66cfa534b75635e520d461b36b56c3f693..65fd5ad5ad806456132d8fbf92952281da97372f 100644
--- a/templates/components/vectordbs/python/none/index.py
+++ b/templates/components/vectordbs/python/none/index.py
@@ -17,7 +17,7 @@ def get_storage_context(persist_dir: str) -> StorageContext:
     return StorageContext.from_defaults(persist_dir=persist_dir)
 
 
-def get_index():
+def get_index(params=None):
     storage_dir = os.getenv("STORAGE_DIR", "storage")
     # check if storage already exists
     if not os.path.exists(storage_dir):
diff --git a/templates/components/vectordbs/typescript/astra/index.ts b/templates/components/vectordbs/typescript/astra/index.ts
index e29ed3531110a027fc31ee3f51d0bc2a09bd19a0..38c5bbbdd469101719ecc481df87c643b6505278 100644
--- a/templates/components/vectordbs/typescript/astra/index.ts
+++ b/templates/components/vectordbs/typescript/astra/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const store = new AstraDBVectorStore();
   await store.connect(process.env.ASTRA_DB_COLLECTION!);
diff --git a/templates/components/vectordbs/typescript/chroma/index.ts b/templates/components/vectordbs/typescript/chroma/index.ts
index 1d36e643bec3ad354fd9faf2ba0977f2dcf5f188..fbc7b4bf27992c5f57246695af2443f426c7aa46 100644
--- a/templates/components/vectordbs/typescript/chroma/index.ts
+++ b/templates/components/vectordbs/typescript/chroma/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
diff --git a/templates/components/vectordbs/typescript/llamacloud/index.ts b/templates/components/vectordbs/typescript/llamacloud/index.ts
index 3f0875ccd6ab24ee5a8c884a4119a35cc1e9af47..413d97b66ae538a5e1389188302ddc8055c4d8e0 100644
--- a/templates/components/vectordbs/typescript/llamacloud/index.ts
+++ b/templates/components/vectordbs/typescript/llamacloud/index.ts
@@ -1,12 +1,26 @@
 import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
-import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
-  checkRequiredEnvVars();
+type LlamaCloudDataSourceParams = {
+  llamaCloudPipeline?: {
+    project: string;
+    pipeline: string;
+  };
+};
+
+export async function getDataSource(params?: LlamaCloudDataSourceParams) {
+  const { project, pipeline } = params?.llamaCloudPipeline ?? {};
+  const projectName = project ?? process.env.LLAMA_CLOUD_PROJECT_NAME;
+  const pipelineName = pipeline ?? process.env.LLAMA_CLOUD_INDEX_NAME;
+  const apiKey = process.env.LLAMA_CLOUD_API_KEY;
+  if (!projectName || !pipelineName || !apiKey) {
+    throw new Error(
+      "Set project, pipeline, and api key in the params or as environment variables.",
+    );
+  }
   const index = new LlamaCloudIndex({
-    name: process.env.LLAMA_CLOUD_INDEX_NAME!,
-    projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
-    apiKey: process.env.LLAMA_CLOUD_API_KEY,
+    name: pipelineName,
+    projectName,
+    apiKey,
     baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
   });
   return index;
diff --git a/templates/components/vectordbs/typescript/milvus/index.ts b/templates/components/vectordbs/typescript/milvus/index.ts
index c290175f06966846da2bd0fb1996ca847398036d..91275b11ef58b540627770afb86c2af8d4059bfb 100644
--- a/templates/components/vectordbs/typescript/milvus/index.ts
+++ b/templates/components/vectordbs/typescript/milvus/index.ts
@@ -2,7 +2,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
 import { checkRequiredEnvVars, getMilvusClient } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const milvusClient = getMilvusClient();
   const store = new MilvusVectorStore({ milvusClient });
diff --git a/templates/components/vectordbs/typescript/mongo/index.ts b/templates/components/vectordbs/typescript/mongo/index.ts
index effb8f92182c9c105adda5eb3178b0a3b411d662..efa35fa53cc06778156c77169507f2f289e8f417 100644
--- a/templates/components/vectordbs/typescript/mongo/index.ts
+++ b/templates/components/vectordbs/typescript/mongo/index.ts
@@ -4,7 +4,7 @@ import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDBAtlasVectorSearch";
 import { MongoClient } from "mongodb";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const client = new MongoClient(process.env.MONGO_URI!);
   const store = new MongoDBAtlasVectorSearch({
diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts
index 64b289750cd978e88b4b17d8194bb468026e8910..fecc76f45637cd22514b0a791f4fdd6be07200c6 100644
--- a/templates/components/vectordbs/typescript/none/index.ts
+++ b/templates/components/vectordbs/typescript/none/index.ts
@@ -2,7 +2,7 @@ import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
 import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
 import { STORAGE_CACHE_DIR } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   const storageContext = await storageContextFromDefaults({
     persistDir: `${STORAGE_CACHE_DIR}`,
   });
diff --git a/templates/components/vectordbs/typescript/pg/index.ts b/templates/components/vectordbs/typescript/pg/index.ts
index 787cae74d6012749577d01786deb77a7ecc7f8f9..75bcd4038be2361fa2a3347f3811a0e3d133703f 100644
--- a/templates/components/vectordbs/typescript/pg/index.ts
+++ b/templates/components/vectordbs/typescript/pg/index.ts
@@ -7,7 +7,7 @@ import {
   checkRequiredEnvVars,
 } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const pgvs = new PGVectorStore({
     connectionString: process.env.PG_CONNECTION_STRING,
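Note: only the LlamaCloud variant of `getDataSource` interprets `params`; the other vector DB variants accept and ignore it so that every template keeps the same call signature. A sketch of a per-request override against the LlamaCloud variant (names are placeholders):

```ts
import { getDataSource } from "./index"; // the LlamaCloud variant above

async function example() {
  // Falls back to LLAMA_CLOUD_PROJECT_NAME / LLAMA_CLOUD_INDEX_NAME when omitted
  const index = await getDataSource({
    llamaCloudPipeline: { project: "Default", pipeline: "my-pipeline" },
  });
  return index;
}
```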
diff --git a/templates/components/vectordbs/typescript/pinecone/index.ts b/templates/components/vectordbs/typescript/pinecone/index.ts
index 15072cfff40959aafb32ac80a273cb5093f59e67..66a22d46e96ee893a4f96bea2cb77fee3798e73d 100644
--- a/templates/components/vectordbs/typescript/pinecone/index.ts
+++ b/templates/components/vectordbs/typescript/pinecone/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const store = new PineconeVectorStore();
   return await VectorStoreIndex.fromVectorStore(store);
diff --git a/templates/components/vectordbs/typescript/qdrant/index.ts b/templates/components/vectordbs/typescript/qdrant/index.ts
index 0233d0882f54286118424d5b565721ae3053b03e..a9d87ab8cc3a057e0751ca0b524dddba18992c55 100644
--- a/templates/components/vectordbs/typescript/qdrant/index.ts
+++ b/templates/components/vectordbs/typescript/qdrant/index.ts
@@ -5,7 +5,7 @@ import { checkRequiredEnvVars, getQdrantClient } from "./shared";
 
 dotenv.config();
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const collectionName = process.env.QDRANT_COLLECTION;
   const store = new QdrantVectorStore({
diff --git a/templates/types/streaming/express/src/controllers/chat-config.controller.ts b/templates/types/streaming/express/src/controllers/chat-config.controller.ts
index 4481e10d133e3c274aaacdedd8ce4534d87a07bf..af843c2c5ac28d6426be441b61781c72e543ac0b 100644
--- a/templates/types/streaming/express/src/controllers/chat-config.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat-config.controller.ts
@@ -1,4 +1,5 @@
 import { Request, Response } from "express";
+import { LLamaCloudFileService } from "./llamaindex/streaming/service";
 
 export const chatConfig = async (_req: Request, res: Response) => {
   let starterQuestions = undefined;
@@ -12,3 +13,14 @@ export const chatConfig = async (_req: Request, res: Response) => {
     starterQuestions,
   });
 };
+
+export const chatLlamaCloudConfig = async (_req: Request, res: Response) => {
+  const config = {
+    projects: await LLamaCloudFileService.getAllProjectsWithPipelines(),
+    pipeline: {
+      pipeline: process.env.LLAMA_CLOUD_INDEX_NAME,
+      project: process.env.LLAMA_CLOUD_PROJECT_NAME,
+    },
+  };
+  return res.status(200).json(config);
+};
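Note: the new `GET /api/chat/config/llamacloud` endpoint returns every project (with its pipelines) plus the env-configured default. A sketch of the response shape, with placeholder IDs:

```ts
const exampleResponse = {
  projects: [
    {
      id: "project-uuid",
      organization_id: "org-uuid",
      name: "Default",
      is_default: true,
      pipelines: [
        { id: "pipeline-uuid", name: "my-pipeline", project_id: "project-uuid" },
      ],
    },
  ],
  // derived from LLAMA_CLOUD_INDEX_NAME / LLAMA_CLOUD_PROJECT_NAME
  pipeline: { pipeline: "my-pipeline", project: "Default" },
};
```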
diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts
index 95228e8d8a31a5aaca47a4db42bd22d5eab39f92..50b70789cc7fc591d4beb113c3ff9b4ee8e0ce09 100644
--- a/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -17,7 +17,7 @@ export const chat = async (req: Request, res: Response) => {
   const vercelStreamData = new StreamData();
   const streamTimeout = createStreamTimeout(vercelStreamData);
   try {
-    const { messages }: { messages: Message[] } = req.body;
+    const { messages, data }: { messages: Message[]; data?: any } = req.body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
@@ -46,7 +46,7 @@ export const chat = async (req: Request, res: Response) => {
       },
     );
     const ids = retrieveDocumentIds(allAnnotations);
-    const chatEngine = await createChatEngine(ids);
+    const chatEngine = await createChatEngine(ids, data);
 
     // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
     const userMessageContent = convertMessageContent(
diff --git a/templates/types/streaming/express/src/routes/chat.route.ts b/templates/types/streaming/express/src/routes/chat.route.ts
index 96711d502d7d30474fa440a427646c626fe53e66..5044efcb4d900e8a17cb6f9ab06f1ff6012ec000 100644
--- a/templates/types/streaming/express/src/routes/chat.route.ts
+++ b/templates/types/streaming/express/src/routes/chat.route.ts
@@ -1,5 +1,8 @@
 import express, { Router } from "express";
-import { chatConfig } from "../controllers/chat-config.controller";
+import {
+  chatConfig,
+  chatLlamaCloudConfig,
+} from "../controllers/chat-config.controller";
 import { chatRequest } from "../controllers/chat-request.controller";
 import { chatUpload } from "../controllers/chat-upload.controller";
 import { chat } from "../controllers/chat.controller";
@@ -11,6 +14,7 @@ initSettings();
 llmRouter.route("/").post(chat);
 llmRouter.route("/request").post(chatRequest);
 llmRouter.route("/config").get(chatConfig);
+llmRouter.route("/config/llamacloud").get(chatLlamaCloudConfig);
 llmRouter.route("/upload").post(chatUpload);
 
 export default llmRouter;
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index e2cffadf74ef0b7a630b51ff2e5e1a4a161a5d04..cb7036d922ec966111edb790ff83c84e63a4a742 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -52,8 +52,9 @@ async def chat(
 
     doc_ids = data.get_chat_document_ids()
     filters = generate_filters(doc_ids)
+    params = data.data or {}
     logger.info("Creating chat engine with filters", filters.dict())
-    chat_engine = get_chat_engine(filters=filters)
+    chat_engine = get_chat_engine(filters=filters, params=params)
 
     event_handler = EventCallbackHandler()
     chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore
@@ -125,3 +126,23 @@ async def chat_config() -> ChatConfig:
     if conversation_starters and conversation_starters.strip():
         starter_questions = conversation_starters.strip().split("\n")
     return ChatConfig(starter_questions=starter_questions)
+
+
+@r.get("/config/llamacloud")
+async def chat_llama_cloud_config():
+    projects = LLamaCloudFileService.get_all_projects_with_pipelines()
+    pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
+    project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
+    pipeline_config = (
+        {"pipeline": pipeline, "project": project}
+        if pipeline and project
+        else None
+    )
+    return {
+        "projects": projects,
+        "pipeline": pipeline_config,
+    }
diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py
index b510e848a6484cb0e7d66e7f0d8d184dc2b655a2..c9ea1adb772c7c61b7d1ecc9f8b7480d69ac55ec 100644
--- a/templates/types/streaming/fastapi/app/api/routers/models.py
+++ b/templates/types/streaming/fastapi/app/api/routers/models.py
@@ -75,6 +75,7 @@ class Message(BaseModel):
 
 class ChatData(BaseModel):
     messages: List[Message]
+    data: Any = None
 
     class Config:
         json_schema_extra = {
@@ -237,7 +238,7 @@ class ChatConfig(BaseModel):
     starter_questions: Optional[List[str]] = Field(
         default=None,
         description="List of starter questions",
-        serialization_alias="starterQuestions"
+        serialization_alias="starterQuestions",
    )
 
     class Config:
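Note: both backends now read an optional top-level `data` field from the chat request body and hand it to the engine factory. A sketch of the POST body once a pipeline is selected (placeholder names; the URL assumes the default backend address used elsewhere in the templates):

```ts
async function sendChat() {
  await fetch("http://localhost:8000/api/chat", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      messages: [{ role: "user", content: "What is in my index?" }],
      // forwarded as `params` to get_chat_engine / createChatEngine
      data: { llamaCloudPipeline: { project: "Default", pipeline: "my-pipeline" } },
    }),
  });
}
```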
diff --git a/templates/types/streaming/fastapi/app/api/services/llama_cloud.py b/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
index ea03e64bb887daaa060f38ed83d4299f8c03908c..852ae7cee693b81ad9cb94b7c34252e7a5153279 100644
--- a/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
+++ b/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
@@ -14,6 +14,32 @@ class LLamaCloudFileService:
 
     DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
 
+    @classmethod
+    def get_all_projects(cls) -> List[Dict[str, Any]]:
+        url = f"{cls.LLAMA_CLOUD_URL}/projects"
+        return cls._make_request(url)
+
+    @classmethod
+    def get_all_pipelines(cls) -> List[Dict[str, Any]]:
+        url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
+        return cls._make_request(url)
+
+    @classmethod
+    def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
+        try:
+            projects = cls.get_all_projects()
+            pipelines = cls.get_all_pipelines()
+            return [
+                {
+                    **project,
+                    "pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
+                }
+                for project in projects
+            ]
+        except Exception as error:
+            logger.error(f"Error listing projects and pipelines: {error}")
+            return []
+
     @classmethod
     def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
         url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
diff --git a/templates/types/streaming/fastapi/app/engine/index.py b/templates/types/streaming/fastapi/app/engine/index.py
index 2dbc589b15d85922d7cee978ea32a9268b2a5734..e1adcb8030ae8af50738f934556f35d0ae212163 100644
--- a/templates/types/streaming/fastapi/app/engine/index.py
+++ b/templates/types/streaming/fastapi/app/engine/index.py
@@ -6,7 +6,7 @@ from app.engine.vectordb import get_vector_store
 logger = logging.getLogger("uvicorn")
 
 
-def get_index():
+def get_index(params=None):
     logger.info("Connecting vector store...")
     store = get_vector_store()
     # Load the index from the vector store
diff --git a/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts b/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
new file mode 100644
index 0000000000000000000000000000000000000000..e40409f174f5cfc5b33a2db40b7eba3b0f3a1e19
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
@@ -0,0 +1,16 @@
+import { NextResponse } from "next/server";
+import { LLamaCloudFileService } from "../../llamaindex/streaming/service";
+
+/**
+ * This API route reads configuration from the backend envs and exposes it to the frontend
+ */
+export async function GET() {
+  const config = {
+    projects: await LLamaCloudFileService.getAllProjectsWithPipelines(),
+    pipeline: {
+      pipeline: process.env.LLAMA_CLOUD_INDEX_NAME,
+      project: process.env.LLAMA_CLOUD_PROJECT_NAME,
+    },
+  };
+  return NextResponse.json(config, { status: 200 });
+}
diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts
index 792ecb7b8ba2483b626977d1f72970c0faf1180f..adfccf13be62b311833f4f0bf127208ea75177cc 100644
--- a/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -27,7 +27,7 @@ export async function POST(request: NextRequest) {
   try {
     const body = await request.json();
-    const { messages }: { messages: Message[] } = body;
+    const { messages, data }: { messages: Message[]; data?: any } = body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return NextResponse.json(
@@ -59,7 +59,7 @@ export async function POST(request: NextRequest) {
       },
     );
     const ids = retrieveDocumentIds(allAnnotations);
-    const chatEngine = await createChatEngine(ids);
+    const chatEngine = await createChatEngine(ids, data);
 
     // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
     const userMessageContent = convertMessageContent(
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
index 01c7c0b1bb317a9d5b99b623b7483ef958cc9a53..4c582966af4e1d4fa0392bdfa251c8169080dfc4 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
@@ -1,4 +1,5 @@
 import { JSONValue } from "ai";
+import { useState } from "react";
 import { Button } from "../button";
 import { DocumentPreview } from "../document-preview";
 import FileUploader from "../file-uploader";
@@ -6,6 +7,7 @@ import { Input } from "../input";
 import UploadImagePreview from "../upload-image-preview";
 import { ChatHandler } from "./chat.interface";
 import { useFile } from "./hooks/use-file";
+import { LlamaCloudSelector } from "./widgets/LlamaCloudSelector";
 
 const ALLOWED_EXTENSIONS = ["png", "jpg", "jpeg", "csv", "pdf", "txt", "docx"];
 
@@ -34,6 +36,7 @@ export default function ChatInput(
     reset,
     getAnnotations,
   } = useFile();
+  const [requestData, setRequestData] = useState<any>();
 
   // default submit function does not handle including annotations in the message
   // so we need to use append function to submit new message with annotations
@@ -42,12 +45,15 @@ export default function ChatInput(
     annotations: JSONValue[] | undefined,
   ) => {
     e.preventDefault();
-    props.append!({
-      content: props.input,
-      role: "user",
-      createdAt: new Date(),
-      annotations,
-    });
+    props.append!(
+      {
+        content: props.input,
+        role: "user",
+        createdAt: new Date(),
+        annotations,
+      },
+      { data: requestData },
+    );
     props.setInput!("");
   };
 
@@ -57,7 +63,7 @@ export default function ChatInput(
       handleSubmitWithAnnotations(e, annotations);
       return reset();
     }
-    props.handleSubmit(e);
+    props.handleSubmit(e, { data: requestData });
   };
 
   const handleUploadFile = async (file: File) => {
@@ -109,6 +115,9 @@ export default function ChatInput(
           disabled: props.isLoading,
         }}
       />
+      {process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && (
+        <LlamaCloudSelector setRequestData={setRequestData} />
+      )}
       <Button type="submit" disabled={props.isLoading || !props.input.trim()}>
         Send message
       </Button>
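Note: `requestData` rides along on every submit via the Vercel AI SDK's request options; the SDK serializes it as the top-level `data` field of the POST body, which is exactly what the route handlers above destructure. A condensed sketch (pipeline names are placeholders):

```tsx
import { useChat } from "ai/react";

function Example() {
  const { append, input } = useChat();
  const send = () =>
    append(
      { content: input, role: "user", createdAt: new Date() },
      // ends up as `data` in the JSON body of POST /api/chat
      { data: { llamaCloudPipeline: { project: "Default", pipeline: "my-pipeline" } } },
    );
  return <button onClick={send}>Send</button>;
}
```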
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
new file mode 100644
index 0000000000000000000000000000000000000000..aa995c91f27a577239f8b5f887d72932654e82ba
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
@@ -0,0 +1,151 @@
+import { Loader2 } from "lucide-react";
+import { useEffect, useState } from "react";
+import {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectTrigger,
+  SelectValue,
+} from "../../select";
+import { useClientConfig } from "../hooks/use-config";
+
+type LLamaCloudPipeline = {
+  id: string;
+  name: string;
+};
+
+type LLamaCloudProject = {
+  id: string;
+  organization_id: string;
+  name: string;
+  is_default: boolean;
+  pipelines: Array<LLamaCloudPipeline>;
+};
+
+type PipelineConfig = {
+  project: string; // project name
+  pipeline: string; // pipeline name
+};
+
+type LlamaCloudConfig = {
+  projects?: LLamaCloudProject[];
+  pipeline?: PipelineConfig;
+};
+
+export interface LlamaCloudSelectorProps {
+  setRequestData: React.Dispatch<any>;
+}
+
+export function LlamaCloudSelector({
+  setRequestData,
+}: LlamaCloudSelectorProps) {
+  const { backend } = useClientConfig();
+  const [config, setConfig] = useState<LlamaCloudConfig>();
+
+  useEffect(() => {
+    if (process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && !config) {
+      fetch(`${backend}/api/chat/config/llamacloud`)
+        .then((response) => response.json())
+        .then((data) => {
+          setConfig(data);
+          setRequestData({
+            llamaCloudPipeline: data.pipeline,
+          });
+        })
+        .catch((error) => console.error("Error fetching config", error));
+    }
+  }, [backend, config, setRequestData]);
+
+  const setPipeline = (pipelineConfig?: PipelineConfig) => {
+    setConfig((prevConfig: any) => ({
+      ...prevConfig,
+      pipeline: pipelineConfig,
+    }));
+    setRequestData((prevData: any) => {
+      if (!prevData) return { llamaCloudPipeline: pipelineConfig };
+      return {
+        ...prevData,
+        llamaCloudPipeline: pipelineConfig,
+      };
+    });
+  };
+
+  const handlePipelineSelect = async (value: string) => {
+    setPipeline(JSON.parse(value) as PipelineConfig);
+  };
+
+  if (!config) {
+    return (
+      <div className="flex justify-center items-center p-3">
+        <Loader2 className="h-4 w-4 animate-spin" />
+      </div>
+    );
+  }
+  if (!isValid(config)) {
+    return (
+      <p className="text-red-500">
+        Invalid LlamaCloud configuration. Check console logs.
+      </p>
+    );
+  }
+  const { projects, pipeline } = config;
+
+  return (
+    <Select
+      onValueChange={handlePipelineSelect}
+      defaultValue={JSON.stringify(pipeline)}
+    >
+      <SelectTrigger className="w-[200px]">
+        <SelectValue placeholder="Select a pipeline" />
+      </SelectTrigger>
+      <SelectContent>
+        {projects!.map((project: LLamaCloudProject) => (
+          <SelectGroup key={project.id}>
+            <SelectLabel className="capitalize">
+              Project: {project.name}
+            </SelectLabel>
+            {project.pipelines.map((pipeline) => (
+              <SelectItem
+                key={pipeline.id}
+                className="last:border-b"
+                value={JSON.stringify({
+                  pipeline: pipeline.name,
+                  project: project.name,
+                })}
+              >
+                <span className="pl-2">{pipeline.name}</span>
+              </SelectItem>
+            ))}
+          </SelectGroup>
+        ))}
+      </SelectContent>
+    </Select>
+  );
+}
+
+function isValid(config: LlamaCloudConfig): boolean {
+  const { projects, pipeline } = config;
+  if (!projects?.length) return false;
+  if (!pipeline) return false;
+  const matchedProject = projects.find(
+    (project: LLamaCloudProject) => project.name === pipeline.project,
+  );
+  if (!matchedProject) {
+    console.error(
+      `LlamaCloud project ${pipeline.project} not found. Check LLAMA_CLOUD_PROJECT_NAME variable`,
+    );
+    return false;
+  }
+  const pipelineExists = matchedProject.pipelines.some(
+    (p) => p.name === pipeline.pipeline,
+  );
+  if (!pipelineExists) {
+    console.error(
+      `LlamaCloud pipeline ${pipeline.pipeline} not found. Check LLAMA_CLOUD_INDEX_NAME variable`,
+    );
+    return false;
+  }
+  return true;
+}
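Note: the selector trusts the backend response but re-validates that the env-configured default actually exists in the fetched project list; the module-internal `isValid` above returns false and logs which variable to fix. A sketch of a config that passes validation only because the names match exactly (placeholder IDs):

```ts
const validConfig = {
  projects: [
    {
      id: "1",
      organization_id: "org",
      name: "Default",
      is_default: true,
      pipelines: [{ id: "p1", name: "my-pipeline" }],
    },
  ],
  pipeline: { project: "Default", pipeline: "my-pipeline" },
};
// Renaming the pipeline to "other" here would make isValid() return false
// and surface "Invalid LlamaCloud configuration." in the UI.
```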
diff --git a/templates/types/streaming/nextjs/app/components/ui/select.tsx b/templates/types/streaming/nextjs/app/components/ui/select.tsx
new file mode 100644
index 0000000000000000000000000000000000000000..c01b068ba4a783711c1c945c076b75f7f299aef4
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/select.tsx
@@ -0,0 +1,159 @@
+"use client";
+
+import * as SelectPrimitive from "@radix-ui/react-select";
+import { Check, ChevronDown, ChevronUp } from "lucide-react";
+import * as React from "react";
+import { cn } from "./lib/utils";
+
+const Select = SelectPrimitive.Root;
+
+const SelectGroup = SelectPrimitive.Group;
+
+const SelectValue = SelectPrimitive.Value;
+
+const SelectTrigger = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Trigger>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
+>(({ className, children, ...props }, ref) => (
+  <SelectPrimitive.Trigger
+    ref={ref}
+    className={cn(
+      "flex h-10 w-full items-center justify-between rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
+      className,
+    )}
+    {...props}
+  >
+    {children}
+    <SelectPrimitive.Icon asChild>
+      <ChevronDown className="h-4 w-4 opacity-50" />
+    </SelectPrimitive.Icon>
+  </SelectPrimitive.Trigger>
+));
+SelectTrigger.displayName = SelectPrimitive.Trigger.displayName;
+
+const SelectScrollUpButton = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.ScrollUpButton
+    ref={ref}
+    className={cn(
+      "flex cursor-default items-center justify-center py-1",
+      className,
+    )}
+    {...props}
+  >
+    <ChevronUp className="h-4 w-4" />
+  </SelectPrimitive.ScrollUpButton>
+));
+SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName;
+
+const SelectScrollDownButton = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.ScrollDownButton
+    ref={ref}
+    className={cn(
+      "flex cursor-default items-center justify-center py-1",
+      className,
+    )}
+    {...props}
+  >
+    <ChevronDown className="h-4 w-4" />
+  </SelectPrimitive.ScrollDownButton>
+));
+SelectScrollDownButton.displayName =
+  SelectPrimitive.ScrollDownButton.displayName;
+
+const SelectContent = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Content>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
+>(({ className, children, position = "popper", ...props }, ref) => (
+  <SelectPrimitive.Portal>
+    <SelectPrimitive.Content
+      ref={ref}
+      className={cn(
+        "relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
+        position === "popper" &&
+          "data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1",
+        className,
+      )}
+      position={position}
+      {...props}
+    >
+      <SelectScrollUpButton />
+      <SelectPrimitive.Viewport
+        className={cn(
+          "p-1",
+          position === "popper" &&
+            "h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]",
+        )}
+      >
+        {children}
+      </SelectPrimitive.Viewport>
+      <SelectScrollDownButton />
+    </SelectPrimitive.Content>
+  </SelectPrimitive.Portal>
+));
+SelectContent.displayName = SelectPrimitive.Content.displayName;
+
+const SelectLabel = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Label>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.Label
+    ref={ref}
+    className={cn("py-1.5 pl-8 pr-2 text-sm font-semibold", className)}
+    {...props}
+  />
+));
+SelectLabel.displayName = SelectPrimitive.Label.displayName;
+
+const SelectItem = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Item>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
+>(({ className, children, ...props }, ref) => (
+  <SelectPrimitive.Item
+    ref={ref}
+    className={cn(
+      "relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
+      className,
+    )}
+    {...props}
+  >
+    <span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
+      <SelectPrimitive.ItemIndicator>
+        <Check className="h-4 w-4" />
+      </SelectPrimitive.ItemIndicator>
+    </span>
+
+    <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
+  </SelectPrimitive.Item>
+));
+SelectItem.displayName = SelectPrimitive.Item.displayName;
+
+const SelectSeparator = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Separator>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.Separator
+    ref={ref}
+    className={cn("-mx-1 my-1 h-px bg-muted", className)}
+    {...props}
+  />
+));
+SelectSeparator.displayName = SelectPrimitive.Separator.displayName;
+
+export {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectScrollDownButton,
+  SelectScrollUpButton,
+  SelectSeparator,
+  SelectTrigger,
+  SelectValue,
+};
diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json
index 5d429a417f7008945de6c47a7e283d941c12ab40..b0b8bd577678fdec7fac3b82a6780a528693f612 100644
--- a/templates/types/streaming/nextjs/package.json
+++ b/templates/types/streaming/nextjs/package.json
@@ -15,6 +15,7 @@
     "@llamaindex/pdf-viewer": "^1.1.3",
     "@radix-ui/react-collapsible": "^1.0.3",
     "@radix-ui/react-hover-card": "^1.0.7",
+    "@radix-ui/react-select": "^2.1.1",
     "@radix-ui/react-slot": "^1.0.2",
     "ai": "^3.0.21",
     "ajv": "^8.12.0",
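Note: the new `select.tsx` follows the shadcn/ui wrapper pattern over `@radix-ui/react-select` (hence the added dependency in `package.json`). A minimal standalone usage sketch of the exported components (the relative import path is an assumption):

```tsx
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "./select";

export function Demo() {
  return (
    <Select defaultValue="a" onValueChange={(value) => console.log(value)}>
      <SelectTrigger className="w-[200px]">
        <SelectValue placeholder="Pick an option" />
      </SelectTrigger>
      <SelectContent>
        <SelectItem value="a">Option A</SelectItem>
        <SelectItem value="b">Option B</SelectItem>
      </SelectContent>
    </Select>
  );
}
```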