diff --git a/.changeset/poor-knives-smoke.md b/.changeset/poor-knives-smoke.md new file mode 100644 index 0000000000000000000000000000000000000000..a09c627c24b344a7088f8e7929ca244ad21f8e2c --- /dev/null +++ b/.changeset/poor-knives-smoke.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Fix blocked event streaming diff --git a/.changeset/wet-tips-judge.md b/.changeset/wet-tips-judge.md new file mode 100644 index 0000000000000000000000000000000000000000..478106884de3189902d006a7b4a4ee38540f85fc --- /dev/null +++ b/.changeset/wet-tips-judge.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Add file upload to the sandbox (artifact and code interpreter) diff --git a/helpers/tools.ts b/helpers/tools.ts index 0684a780daba9f35f11d1e570920aac1314b966b..262e71b1d225a8cfb36a54362f214b163f35b5a9 100644 --- a/helpers/tools.ts +++ b/helpers/tools.ts @@ -139,7 +139,7 @@ For better results, you can specify the region parameter to get results from a s dependencies: [ { name: "e2b_code_interpreter", - version: "0.0.10", + version: "0.0.11b38", }, ], supportedFrameworks: ["fastapi", "express", "nextjs"], diff --git a/questions/simple.ts b/questions/simple.ts index 6a3477eeeb1610f5cbdfb6c27bd46ca5e2ace674..29930c1aefb244bd773c843eed7a73891558213c 100644 --- a/questions/simple.ts +++ b/questions/simple.ts @@ -5,14 +5,19 @@ import { getTools } from "../helpers/tools"; import { ModelConfig, TemplateFramework } from "../helpers/types"; import { PureQuestionArgs, QuestionResults } from "./types"; import { askPostInstallAction, questionHandlers } from "./utils"; -type AppType = "rag" | "code_artifact" | "multiagent" | "extractor"; + +type AppType = + | "rag" + | "code_artifact" + | "multiagent" + | "extractor" + | "data_scientist"; type SimpleAnswers = { appType: AppType; language: TemplateFramework; useLlamaCloud: boolean; llamaCloudKey?: string; - modelConfig: ModelConfig; }; export const askSimpleQuestions = async ( @@ -25,6 +30,7 @@ export const askSimpleQuestions = async ( message: "What app do you want to build?", choices: [ { title: "Agentic RAG", value: "rag" }, + { title: "Data Scientist", value: "data_scientist" }, { title: "Code Artifact Agent", value: "code_artifact" }, { title: "Multi-Agent Report Gen", value: "multiagent" }, { title: "Structured extraction", value: "extractor" }, @@ -80,28 +86,36 @@ export const askSimpleQuestions = async ( } } - const modelConfig = await askModelConfig({ - openAiKey: args.openAiKey, - askModels: args.askModels ?? 
false, - framework: language, - }); - - const results = convertAnswers({ + const results = await convertAnswers(args, { appType, language, useLlamaCloud, llamaCloudKey, - modelConfig, }); results.postInstallAction = await askPostInstallAction(results); return results; }; -const convertAnswers = (answers: SimpleAnswers): QuestionResults => { +const convertAnswers = async ( + args: PureQuestionArgs, + answers: SimpleAnswers, +): Promise<QuestionResults> => { + const MODEL_GPT4o: ModelConfig = { + provider: "openai", + apiKey: args.openAiKey, + model: "gpt-4o", + embeddingModel: "text-embedding-3-large", + dimensions: 1536, + isConfigured(): boolean { + return !!args.openAiKey; + }, + }; const lookup: Record< AppType, - Pick<QuestionResults, "template" | "tools" | "frontend" | "dataSources"> + Pick<QuestionResults, "template" | "tools" | "frontend" | "dataSources"> & { + modelConfig?: ModelConfig; + } > = { rag: { template: "streaming", @@ -109,11 +123,19 @@ const convertAnswers = (answers: SimpleAnswers): QuestionResults => { frontend: true, dataSources: [EXAMPLE_FILE], }, + data_scientist: { + template: "streaming", + tools: getTools(["interpreter", "document_generator"]), + frontend: true, + dataSources: [], + modelConfig: MODEL_GPT4o, + }, code_artifact: { template: "streaming", tools: getTools(["artifact"]), frontend: true, dataSources: [], + modelConfig: MODEL_GPT4o, }, multiagent: { template: "multiagent", @@ -140,11 +162,16 @@ const convertAnswers = (answers: SimpleAnswers): QuestionResults => { llamaCloudKey: answers.llamaCloudKey, useLlamaParse: answers.useLlamaCloud, llamapack: "", - postInstallAction: "none", vectorDb: answers.useLlamaCloud ? "llamacloud" : "none", - modelConfig: answers.modelConfig, observability: "none", ...results, + modelConfig: + results.modelConfig ?? + (await askModelConfig({ + openAiKey: args.openAiKey, + askModels: args.askModels ?? false, + framework: answers.language, + })), frontend: answers.language === "nextjs" ? false : results.frontend, }; }; diff --git a/templates/components/engines/python/agent/tools/artifact.py b/templates/components/engines/python/agent/tools/artifact.py index 4c877b2fd65fdd9a706b5e6f88edec32fdd995b0..5506113f912812e08ac579c50cf55a061a086296 100644 --- a/templates/components/engines/python/agent/tools/artifact.py +++ b/templates/components/engines/python/agent/tools/artifact.py @@ -66,21 +66,29 @@ class CodeGeneratorTool: def __init__(self): pass - def artifact(self, query: str, old_code: Optional[str] = None) -> Dict: - """Generate a code artifact based on the input. + def artifact( + self, + query: str, + sandbox_files: Optional[List[str]] = None, + old_code: Optional[str] = None, + ) -> Dict: + """Generate a code artifact based on the provided input. Args: - query (str): The description of the application you want to build. + query (str): A description of the application you want to build. + sandbox_files (Optional[List[str]], optional): A list of sandbox file paths. Defaults to None. Include these files if the code requires them. old_code (Optional[str], optional): The existing code to be modified. Defaults to None. Returns: - Dict: A dictionary containing the generated artifact information. + Dict: A dictionary containing information about the generated artifact. 
""" if old_code: user_message = f"{query}\n\nThe existing code is: \n```\n{old_code}\n```" else: user_message = query + if sandbox_files: + user_message += f"\n\nThe provided files are: \n{str(sandbox_files)}" messages: List[ChatMessage] = [ ChatMessage(role="system", content=CODE_GENERATION_PROMPT), @@ -90,7 +98,10 @@ class CodeGeneratorTool: sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) # type: ignore response = sllm.chat(messages) data: CodeArtifact = response.raw - return data.model_dump() + data_dict = data.model_dump() + if sandbox_files: + data_dict["files"] = sandbox_files + return data_dict except Exception as e: logger.error(f"Failed to generate artifact: {str(e)}") raise e diff --git a/templates/components/engines/python/agent/tools/interpreter.py b/templates/components/engines/python/agent/tools/interpreter.py index 0f4c10b95e51ec4d7722ed79f5edb9010f4349af..9d19ea883136231984fdbb7cff84db4ca827881b 100644 --- a/templates/components/engines/python/agent/tools/interpreter.py +++ b/templates/components/engines/python/agent/tools/interpreter.py @@ -2,14 +2,15 @@ import base64 import logging import os import uuid -from typing import Dict, List, Optional +from typing import List, Optional +from app.engine.utils.file_helper import FileMetadata, save_file from e2b_code_interpreter import CodeInterpreter from e2b_code_interpreter.models import Logs from llama_index.core.tools import FunctionTool from pydantic import BaseModel -logger = logging.getLogger(__name__) +logger = logging.getLogger("uvicorn") class InterpreterExtraResult(BaseModel): @@ -22,11 +23,14 @@ class InterpreterExtraResult(BaseModel): class E2BToolOutput(BaseModel): is_error: bool logs: Logs + error_message: Optional[str] = None results: List[InterpreterExtraResult] = [] + retry_count: int = 0 class E2BCodeInterpreter: output_dir = "output/tools" + uploaded_files_dir = "output/uploaded" def __init__(self, api_key: str = None): if api_key is None: @@ -42,40 +46,43 @@ class E2BCodeInterpreter: ) self.filesever_url_prefix = filesever_url_prefix - self.interpreter = CodeInterpreter(api_key=api_key) + self.interpreter = None + self.api_key = api_key def __del__(self): - self.interpreter.close() - - def get_output_path(self, filename: str) -> str: - # if output directory doesn't exist, create it - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir, exist_ok=True) - return os.path.join(self.output_dir, filename) + """ + Kill the interpreter when the tool is no longer in use + """ + if self.interpreter is not None: + self.interpreter.kill() - def save_to_disk(self, base64_data: str, ext: str) -> Dict: - filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename + def _init_interpreter(self, sandbox_files: List[str] = []): + """ + Lazily initialize the interpreter. 
+ """ + logger.info(f"Initializing interpreter with {len(sandbox_files)} files") + self.interpreter = CodeInterpreter(api_key=self.api_key) + if len(sandbox_files) > 0: + for file_path in sandbox_files: + file_name = os.path.basename(file_path) + local_file_path = os.path.join(self.uploaded_files_dir, file_name) + with open(local_file_path, "rb") as f: + content = f.read() + if self.interpreter and self.interpreter.files: + self.interpreter.files.write(file_path, content) + logger.info(f"Uploaded {len(sandbox_files)} files to sandbox") + + def _save_to_disk(self, base64_data: str, ext: str) -> FileMetadata: buffer = base64.b64decode(base64_data) - output_path = self.get_output_path(filename) - - try: - with open(output_path, "wb") as file: - file.write(buffer) - except IOError as e: - logger.error(f"Failed to write to file {output_path}: {str(e)}") - raise e - logger.info(f"Saved file to {output_path}") + filename = f"{uuid.uuid4()}.{ext}" # generate a unique filename + output_path = os.path.join(self.output_dir, filename) - return { - "outputPath": output_path, - "filename": filename, - } + file_metadata = save_file(buffer, file_path=output_path) - def get_file_url(self, filename: str) -> str: - return f"{self.filesever_url_prefix}/{self.output_dir}/{filename}" + return file_metadata - def parse_result(self, result) -> List[InterpreterExtraResult]: + def _parse_result(self, result) -> List[InterpreterExtraResult]: """ The result could include multiple formats (e.g. png, svg, etc.) but encoded in base64 We save each result to disk and return saved file metadata (extension, filename, url) @@ -92,16 +99,20 @@ class E2BCodeInterpreter: for ext, data in zip(formats, results): match ext: case "png" | "svg" | "jpeg" | "pdf": - result = self.save_to_disk(data, ext) - filename = result["filename"] + file_metadata = self._save_to_disk(data, ext) output.append( InterpreterExtraResult( type=ext, - filename=filename, - url=self.get_file_url(filename), + filename=file_metadata.name, + url=file_metadata.url, ) ) case _: + # Try serialize data to string + try: + data = str(data) + except Exception as e: + data = f"Error when serializing data: {e}" output.append( InterpreterExtraResult( type=ext, @@ -114,28 +125,75 @@ class E2BCodeInterpreter: return output - def interpret(self, code: str) -> E2BToolOutput: + def interpret( + self, + code: str, + sandbox_files: List[str] = [], + retry_count: int = 0, + ) -> E2BToolOutput: """ - Execute python code in a Jupyter notebook cell, the toll will return result, stdout, stderr, display_data, and error. + Execute Python code in a Jupyter notebook cell. The tool will return the result, stdout, stderr, display_data, and error. + If the code needs to use a file, ALWAYS pass the file path in the sandbox_files argument. + You have a maximum of 3 retries to get the code to run successfully. Parameters: - code (str): The python code to be executed in a single cell. + code (str): The Python code to be executed in a single cell. + sandbox_files (List[str]): List of local file paths to be used by the code. The tool will throw an error if a file is not found. + retry_count (int): Number of times the tool has been retried. 
""" - logger.info( - f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}" - ) - exec = self.interpreter.notebook.exec_cell(code) - - if exec.error: - logger.error("Error when executing code", exec.error) - output = E2BToolOutput(is_error=True, logs=exec.logs, results=[]) - else: - if len(exec.results) == 0: - output = E2BToolOutput(is_error=False, logs=exec.logs, results=[]) + if retry_count > 2: + return E2BToolOutput( + is_error=True, + logs=Logs( + stdout="", + stderr="", + display_data="", + error="", + ), + error_message="Failed to execute the code after 3 retries. Explain the error to the user and suggest a fix.", + retry_count=retry_count, + ) + + if self.interpreter is None: + self._init_interpreter(sandbox_files) + + if self.interpreter and self.interpreter.notebook: + logger.info( + f"\n{'='*50}\n> Running following AI-generated code:\n{code}\n{'='*50}" + ) + exec = self.interpreter.notebook.exec_cell(code) + + if exec.error: + error_message = f"The code failed to execute successfully. Error: {exec.error}. Try to fix the code and run again." + logger.error(error_message) + # Calling the generated code caused an error. Kill the interpreter and return the error to the LLM so it can try to fix the error + try: + self.interpreter.kill() # type: ignore + except Exception: + pass + finally: + self.interpreter = None + output = E2BToolOutput( + is_error=True, + logs=exec.logs, + results=[], + error_message=error_message, + retry_count=retry_count + 1, + ) else: - results = self.parse_result(exec.results[0]) - output = E2BToolOutput(is_error=False, logs=exec.logs, results=results) - return output + if len(exec.results) == 0: + output = E2BToolOutput(is_error=False, logs=exec.logs, results=[]) + else: + results = self._parse_result(exec.results[0]) + output = E2BToolOutput( + is_error=False, + logs=exec.logs, + results=results, + retry_count=retry_count + 1, + ) + return output + else: + raise ValueError("Interpreter is not initialized.") def get_tools(**kwargs): diff --git a/templates/components/engines/typescript/agent/tools/code-generator.ts b/templates/components/engines/typescript/agent/tools/code-generator.ts index eedcfa515134f89b53ce4477c7178256468968c8..ec6f10e832564d30a072c7b315d44b82146205ba 100644 --- a/templates/components/engines/typescript/agent/tools/code-generator.ts +++ b/templates/components/engines/typescript/agent/tools/code-generator.ts @@ -48,11 +48,13 @@ export type CodeArtifact = { port: number | null; file_path: string; code: string; + files?: string[]; }; export type CodeGeneratorParameter = { requirement: string; oldCode?: string; + sandboxFiles?: string[]; }; export type CodeGeneratorToolParams = { @@ -75,6 +77,15 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<CodeGeneratorParameter>> = description: "The existing code to be modified", nullable: true, }, + sandboxFiles: { + type: "array", + description: + "A list of sandbox file paths. 
Include these files if the code requires them.", + items: { + type: "string", + }, + nullable: true, + }, }, required: ["requirement"], }, @@ -93,6 +104,9 @@ export class CodeGeneratorTool implements BaseTool<CodeGeneratorParameter> { input.requirement, input.oldCode, ); + if (input.sandboxFiles) { + artifact.files = input.sandboxFiles; + } return artifact as JSONValue; } catch (error) { return { isError: true }; diff --git a/templates/components/engines/typescript/agent/tools/interpreter.ts b/templates/components/engines/typescript/agent/tools/interpreter.ts index 24573c2051d28fe2b64cfe9591af1d1bbfcf0045..ae386a13f7ed1d9fa93623525aadb5f049e56f3b 100644 --- a/templates/components/engines/typescript/agent/tools/interpreter.ts +++ b/templates/components/engines/typescript/agent/tools/interpreter.ts @@ -7,6 +7,8 @@ import path from "node:path"; export type InterpreterParameter = { code: string; + sandboxFiles?: string[]; + retryCount?: number; }; export type InterpreterToolParams = { @@ -18,7 +20,9 @@ export type InterpreterToolParams = { export type InterpreterToolOutput = { isError: boolean; logs: Logs; + text?: string; extraResult: InterpreterExtraResult[]; + retryCount?: number; }; type InterpreterExtraType = @@ -41,8 +45,10 @@ export type InterpreterExtraResult = { const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = { name: "interpreter", - description: - "Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error.", + description: `Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error. +If the code needs to use a file, ALWAYS pass the file path in the sandbox_files argument. +You have a maximum of 3 retries to get the code to run successfully. +`, parameters: { type: "object", properties: { @@ -50,6 +56,21 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = { type: "string", description: "The python code to execute in a single cell.", }, + sandboxFiles: { + type: "array", + description: + "List of local file paths to be used by the code. 
The tool will throw an error if a file is not found.", + items: { + type: "string", + }, + nullable: true, + }, + retryCount: { + type: "number", + description: "The number of times the tool has been retried", + default: 0, + nullable: true, + }, }, required: ["code"], }, @@ -57,6 +78,7 @@ const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<InterpreterParameter>> = { export class InterpreterTool implements BaseTool<InterpreterParameter> { private readonly outputDir = "output/tools"; + private readonly uploadedFilesDir = "output/uploaded"; private apiKey?: string; private fileServerURLPrefix?: string; metadata: ToolMetadata<JSONSchemaType<InterpreterParameter>>; @@ -80,33 +102,64 @@ export class InterpreterTool implements BaseTool<InterpreterParameter> { } } - public async initInterpreter() { + public async initInterpreter(input: InterpreterParameter) { if (!this.codeInterpreter) { this.codeInterpreter = await CodeInterpreter.create({ apiKey: this.apiKey, }); } + // upload files to sandbox + if (input.sandboxFiles) { + console.log(`Uploading ${input.sandboxFiles.length} files to sandbox`); + for (const filePath of input.sandboxFiles) { + const fileName = path.basename(filePath); + const localFilePath = path.join(this.uploadedFilesDir, fileName); + const content = fs.readFileSync(localFilePath); + await this.codeInterpreter?.files.write(filePath, content); + } + console.log(`Uploaded ${input.sandboxFiles.length} files to sandbox`); + } return this.codeInterpreter; } - public async codeInterpret(code: string): Promise<InterpreterToolOutput> { + public async codeInterpret( + input: InterpreterParameter, + ): Promise<InterpreterToolOutput> { + console.log( + `Sandbox files: ${input.sandboxFiles}. Retry count: ${input.retryCount}`, + ); + + if (input.retryCount && input.retryCount >= 3) { + return { + isError: true, + logs: { + stdout: [], + stderr: [], + }, + text: "Max retries reached", + extraResult: [], + }; + } + console.log( - `\n${"=".repeat(50)}\n> Running following AI-generated code:\n${code}\n${"=".repeat(50)}`, + `\n${"=".repeat(50)}\n> Running following AI-generated code:\n${input.code}\n${"=".repeat(50)}`, ); - const interpreter = await this.initInterpreter(); - const exec = await interpreter.notebook.execCell(code); + const interpreter = await this.initInterpreter(input); + const exec = await interpreter.notebook.execCell(input.code); if (exec.error) console.error("[Code Interpreter error]", exec.error); const extraResult = await this.getExtraResult(exec.results[0]); const result: InterpreterToolOutput = { isError: !!exec.error, logs: exec.logs, + text: exec.text, extraResult, + retryCount: input.retryCount ? 
input.retryCount + 1 : 1, }; return result; } async call(input: InterpreterParameter): Promise<InterpreterToolOutput> { - const result = await this.codeInterpret(input.code); + const result = await this.codeInterpret(input); return result; } diff --git a/templates/components/llamaindex/typescript/documents/helper.ts b/templates/components/llamaindex/typescript/documents/helper.ts index bfe7452286851741701d1592312081ed2dded933..52cc5d94326bba80254683c917b7ddf416d03a2b 100644 --- a/templates/components/llamaindex/typescript/documents/helper.ts +++ b/templates/components/llamaindex/typescript/documents/helper.ts @@ -1,3 +1,5 @@ +import { Document } from "llamaindex"; +import crypto from "node:crypto"; import fs from "node:fs"; import path from "node:path"; import { getExtractors } from "../../engine/loader"; @@ -5,23 +7,58 @@ import { getExtractors } from "../../engine/loader"; const MIME_TYPE_TO_EXT: Record<string, string> = { "application/pdf": "pdf", "text/plain": "txt", + "text/csv": "csv", "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", }; const UPLOADED_FOLDER = "output/uploaded"; +export type FileMetadata = { + id: string; + name: string; + url: string; + refs: string[]; +}; + export async function storeAndParseFile( filename: string, fileBuffer: Buffer, mimeType: string, +): Promise<FileMetadata> { + const fileMetadata = await storeFile(filename, fileBuffer, mimeType); + const documents: Document[] = await parseFile(fileBuffer, filename, mimeType); + // Update document IDs in the file metadata + fileMetadata.refs = documents.map((document) => document.id_ as string); + return fileMetadata; +} + +export async function storeFile( + filename: string, + fileBuffer: Buffer, + mimeType: string, ) { const fileExt = MIME_TYPE_TO_EXT[mimeType]; if (!fileExt) throw new Error(`Unsupported document type: ${mimeType}`); + const fileId = crypto.randomUUID(); + const newFilename = `${fileId}_${sanitizeFileName(filename)}`; + const filepath = path.join(UPLOADED_FOLDER, newFilename); + const fileUrl = await saveDocument(filepath, fileBuffer); + return { + id: fileId, + name: newFilename, + url: fileUrl, + refs: [] as string[], + } as FileMetadata; +} + +export async function parseFile( + fileBuffer: Buffer, + filename: string, + mimeType: string, +) { const documents = await loadDocuments(fileBuffer, mimeType); - const filepath = path.join(UPLOADED_FOLDER, filename); - await saveDocument(filepath, fileBuffer); for (const document of documents) { document.metadata = { ...document.metadata, @@ -48,12 +85,6 @@ export async function saveDocument(filepath: string, content: string | Buffer) { if (path.isAbsolute(filepath)) { throw new Error("Absolute file paths are not allowed."); } - const fileName = path.basename(filepath); - if (!/^[a-zA-Z0-9_.-]+$/.test(fileName)) { - throw new Error( - "File name is not allowed to contain any special characters.", - ); - } if (!process.env.FILESERVER_URL_PREFIX) { throw new Error("FILESERVER_URL_PREFIX environment variable is not set."); } @@ -71,3 +102,7 @@ export async function saveDocument(filepath: string, content: string | Buffer) { console.log(`Saved document to ${filepath}. 
Reachable at URL: ${fileurl}`); return fileurl; } + +function sanitizeFileName(fileName: string) { + return fileName.replace(/[^a-zA-Z0-9_.-]/g, "_"); +} diff --git a/templates/components/llamaindex/typescript/documents/pipeline.ts b/templates/components/llamaindex/typescript/documents/pipeline.ts index 6f9589cd2d4b962b86135ddfaa363749a615793d..01b52fd5d732d98a6569bcf418d624bd3cbe40f1 100644 --- a/templates/components/llamaindex/typescript/documents/pipeline.ts +++ b/templates/components/llamaindex/typescript/documents/pipeline.ts @@ -7,7 +7,7 @@ import { } from "llamaindex"; export async function runPipeline( - currentIndex: VectorStoreIndex, + currentIndex: VectorStoreIndex | null, documents: Document[], ) { // Use ingestion pipeline to process the documents into nodes and add them to the vector store @@ -21,8 +21,18 @@ export async function runPipeline( ], }); const nodes = await pipeline.run({ documents }); - await currentIndex.insertNodes(nodes); - currentIndex.storageContext.docStore.persist(); - console.log("Added nodes to the vector store."); - return documents.map((document) => document.id_); + if (currentIndex) { + await currentIndex.insertNodes(nodes); + currentIndex.storageContext.docStore.persist(); + console.log("Added nodes to the vector store."); + return documents.map((document) => document.id_); + } else { + // Initialize a new index with the documents + const newIndex = await VectorStoreIndex.fromDocuments(documents); + newIndex.storageContext.docStore.persist(); + console.log( + "Got empty index, created new index with the uploaded documents", + ); + return documents.map((document) => document.id_); + } } diff --git a/templates/components/llamaindex/typescript/documents/upload.ts b/templates/components/llamaindex/typescript/documents/upload.ts index a5a817e772a44ae3c90e46fd4ecae1216b2be831..158b05a1ac8decfa512c166ae93dedac20d3a724 100644 --- a/templates/components/llamaindex/typescript/documents/upload.ts +++ b/templates/components/llamaindex/typescript/documents/upload.ts @@ -1,30 +1,41 @@ -import { LLamaCloudFileService, VectorStoreIndex } from "llamaindex"; +import { Document, LLamaCloudFileService, VectorStoreIndex } from "llamaindex"; import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex"; -import { storeAndParseFile } from "./helper"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { FileMetadata, parseFile, storeFile } from "./helper"; import { runPipeline } from "./pipeline"; export async function uploadDocument( - index: VectorStoreIndex | LlamaCloudIndex, + index: VectorStoreIndex | LlamaCloudIndex | null, filename: string, raw: string, -): Promise<string[]> { +): Promise<FileMetadata> { const [header, content] = raw.split(","); const mimeType = header.replace("data:", "").replace(";base64", ""); const fileBuffer = Buffer.from(content, "base64"); + // Store file + const fileMetadata = await storeFile(filename, fileBuffer, mimeType); + + // If the file is csv and has codeExecutorTool, we don't need to index the file. 
+ if (mimeType === "text/csv" && (await hasCodeExecutorTool())) { + return fileMetadata; + } + if (index instanceof LlamaCloudIndex) { // trigger LlamaCloudIndex API to upload the file and run the pipeline const projectId = await index.getProjectId(); const pipelineId = await index.getPipelineId(); try { - return [ - await LLamaCloudFileService.addFileToPipeline( - projectId, - pipelineId, - new File([fileBuffer], filename, { type: mimeType }), - { private: "true" }, - ), - ]; + const documentId = await LLamaCloudFileService.addFileToPipeline( + projectId, + pipelineId, + new File([fileBuffer], filename, { type: mimeType }), + { private: "true" }, + ); + // Update file metadata with document IDs + fileMetadata.refs = [documentId]; + return fileMetadata; } catch (error) { if ( error instanceof ReferenceError && @@ -39,6 +50,21 @@ export async function uploadDocument( } // run the pipeline for other vector store indexes - const documents = await storeAndParseFile(filename, fileBuffer, mimeType); - return runPipeline(index, documents); + const documents: Document[] = await parseFile(fileBuffer, filename, mimeType); + // Update file metadata with document IDs + fileMetadata.refs = documents.map((document) => document.id_ as string); + // Run the pipeline + await runPipeline(index, documents); + return fileMetadata; } + +const hasCodeExecutorTool = async () => { + const codeExecutorTools = ["interpreter", "artifact"]; + + const configFile = path.join("config", "tools.json"); + const toolConfig = JSON.parse(await fs.readFile(configFile, "utf8")); + + const localTools = toolConfig.local || {}; + // Check if local tools contains codeExecutorTools + return codeExecutorTools.some((tool) => localTools[tool] !== undefined); +}; diff --git a/templates/components/llamaindex/typescript/streaming/annotations.ts b/templates/components/llamaindex/typescript/streaming/annotations.ts index 10e6f52c4755fa7e82d53b2fcee5193bb5d6dfca..f8de88f8b611669aa7b1470ef66539a1de78f9b4 100644 --- a/templates/components/llamaindex/typescript/streaming/annotations.ts +++ b/templates/components/llamaindex/typescript/streaming/annotations.ts @@ -3,17 +3,17 @@ import { MessageContent, MessageContentDetail } from "llamaindex"; export type DocumentFileType = "csv" | "pdf" | "txt" | "docx"; -export type DocumentFileContent = { - type: "ref" | "text"; - value: string[] | string; +export type UploadedFileMeta = { + id: string; + name: string; + url?: string; + refs?: string[]; }; export type DocumentFile = { - id: string; - filename: string; - filesize: number; - filetype: DocumentFileType; - content: DocumentFileContent; + type: DocumentFileType; + url: string; + metadata: UploadedFileMeta; }; type Annotation = { @@ -29,28 +29,25 @@ export function isValidMessages(messages: Message[]): boolean { export function retrieveDocumentIds(messages: Message[]): string[] { // retrieve document Ids from the annotations of all messages (if any) + const documentFiles = retrieveDocumentFiles(messages); + return documentFiles.map((file) => file.metadata?.refs || []).flat(); +} + +export function retrieveDocumentFiles(messages: Message[]): DocumentFile[] { const annotations = getAllAnnotations(messages); if (annotations.length === 0) return []; - const ids: string[] = []; - + const files: DocumentFile[] = []; for (const { type, data } of annotations) { if ( type === "document_file" && "files" in data && Array.isArray(data.files) ) { - const files = data.files as DocumentFile[]; - for (const file of files) { - if (Array.isArray(file.content.value)) 
{ - // it's an array, so it's an array of doc IDs - ids.push(...file.content.value); - } - } + files.push(...data.files); } } - - return ids; + return files; } export function retrieveMessageContent(messages: Message[]): MessageContent { @@ -65,6 +62,36 @@ export function retrieveMessageContent(messages: Message[]): MessageContent { ]; } +function getFileContent(file: DocumentFile): string { + const fileMetadata = file.metadata; + let defaultContent = `=====File: ${fileMetadata.name}=====\n`; + // Include file URL if it's available + const urlPrefix = process.env.FILESERVER_URL_PREFIX; + let urlContent = ""; + if (urlPrefix) { + if (fileMetadata.url) { + urlContent = `File URL: ${fileMetadata.url}\n`; + } else { + urlContent = `File URL (instruction: do not update this file URL yourself): ${urlPrefix}/output/uploaded/${fileMetadata.name}\n`; + } + } else { + console.warn( + "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server", + ); + } + defaultContent += urlContent; + + // Include document IDs if it's available + if (fileMetadata.refs) { + defaultContent += `Document IDs: ${fileMetadata.refs}\n`; + } + // Include sandbox file paths + const sandboxFilePath = `/tmp/${fileMetadata.name}`; + defaultContent += `Sandbox file path (instruction: only use sandbox path for artifact or code interpreter tool): ${sandboxFilePath}\n`; + + return defaultContent; +} + function getAllAnnotations(messages: Message[]): Annotation[] { return messages.flatMap((message) => (message.annotations ?? []).map((annotation) => @@ -131,25 +158,11 @@ function convertAnnotations(messages: Message[]): MessageContentDetail[] { "files" in data && Array.isArray(data.files) ) { - // get all CSV files and convert their whole content to one text message - // currently CSV files are the only files where we send the whole content - we don't use an index - const csvFiles: DocumentFile[] = data.files.filter( - (file: DocumentFile) => file.filetype === "csv", - ); - if (csvFiles && csvFiles.length > 0) { - const csvContents = csvFiles.map((file: DocumentFile) => { - const fileContent = Array.isArray(file.content.value) - ? 
file.content.value.join("\n") - : file.content.value; - return "```csv\n" + fileContent + "\n```"; - }); - const text = - "Use the following CSV content:\n" + csvContents.join("\n\n"); - content.push({ - type: "text", - text, - }); - } + const fileContent = data.files.map(getFileContent).join("\n"); + content.push({ + type: "text", + text: fileContent, + }); } }); diff --git a/templates/components/routers/python/sandbox.py b/templates/components/routers/python/sandbox.py index c5a2a367034ebd305f667088ce794f066b2b1072..9efe146fd9c3158a99c5774ed4adb57812871340 100644 --- a/templates/components/routers/python/sandbox.py +++ b/templates/components/routers/python/sandbox.py @@ -16,7 +16,8 @@ import base64 import logging import os import uuid -from typing import Dict, List, Optional, Union +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Union from app.engine.tools.artifact import CodeArtifact from app.engine.utils.file_helper import save_file @@ -36,7 +37,7 @@ class ExecutionResult(BaseModel): template: str stdout: List[str] stderr: List[str] - runtime_error: Optional[Dict[str, Union[str, List[str]]]] = None + runtime_error: Optional[Dict[str, Any]] = None output_urls: List[Dict[str, str]] url: Optional[str] @@ -54,15 +55,27 @@ class ExecutionResult(BaseModel): } +class FileUpload(BaseModel): + id: str + name: str + + @sandbox_router.post("") async def create_sandbox(request: Request): request_data = await request.json() + artifact_data = request_data.get("artifact", None) + sandbox_files = artifact_data.get("files", []) + + if not artifact_data: + raise HTTPException( + status_code=400, detail="Could not create artifact from the request data" + ) try: - artifact = CodeArtifact(**request_data["artifact"]) + artifact = CodeArtifact(**artifact_data) except Exception: logger.error(f"Could not create artifact from request data: {request_data}") - return HTTPException( + raise HTTPException( status_code=400, detail="Could not create artifact from the request data" ) @@ -94,6 +107,10 @@ async def create_sandbox(request: Request): f"Installed dependencies: {', '.join(artifact.additional_dependencies)} in sandbox {sbx}" ) + # Copy files + if len(sandbox_files) > 0: + _upload_files(sbx, sandbox_files) + # Copy code to disk if isinstance(artifact.code, list): for file in artifact.code: @@ -107,11 +124,12 @@ async def create_sandbox(request: Request): if artifact.template == "code-interpreter-multilang": result = sbx.notebook.exec_cell(artifact.code or "") output_urls = _download_cell_results(result.results) + runtime_error = asdict(result.error) if result.error else None return ExecutionResult( template=artifact.template, stdout=result.logs.stdout, stderr=result.logs.stderr, - runtime_error=result.error, + runtime_error=runtime_error, output_urls=output_urls, url=None, ).to_response() @@ -126,6 +144,19 @@ async def create_sandbox(request: Request): ).to_response() +def _upload_files( + sandbox: Union[CodeInterpreter, Sandbox], + sandbox_files: List[str] = [], +) -> None: + for file_path in sandbox_files: + file_name = os.path.basename(file_path) + local_file_path = os.path.join("output", "uploaded", file_name) + with open(local_file_path, "rb") as f: + content = f.read() + sandbox.files.write(file_path, content) + return None + + def _download_cell_results(cell_results: Optional[List]) -> List[Dict[str, str]]: """ To pull results from code interpreter cell and save them to disk for serving @@ -141,14 +172,14 @@ def _download_cell_results(cell_results: Optional[List]) -> 
List[Dict[str, str]] data = result[ext] if ext in ["png", "svg", "jpeg", "pdf"]: - file_path = f"output/tools/{uuid.uuid4()}.{ext}" + file_path = os.path.join("output", "tools", f"{uuid.uuid4()}.{ext}") base64_data = data buffer = base64.b64decode(base64_data) file_meta = save_file(content=buffer, file_path=file_path) output.append( { "type": ext, - "filename": file_meta.filename, + "filename": file_meta.name, "url": file_meta.url, } ) diff --git a/templates/components/services/python/file.py b/templates/components/services/python/file.py index d526718189f81aa1f5a20a4cb759f9a7daaab481..3e9ad3e64ae90dc5188bde3f5d702aa062b3a828 100644 --- a/templates/components/services/python/file.py +++ b/templates/components/services/python/file.py @@ -1,17 +1,21 @@ import base64 import mimetypes import os +import re +import uuid from io import BytesIO from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from app.engine.index import IndexConfig, get_index +from app.engine.utils.file_helper import FileMetadata, save_file from llama_index.core import VectorStoreIndex from llama_index.core.ingestion import IngestionPipeline from llama_index.core.readers.file.base import ( _try_loading_included_file_formats as get_file_loaders_map, ) from llama_index.core.schema import Document +from llama_index.core.tools.function_tool import FunctionTool from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex from llama_index.readers.file import FlatReader @@ -31,14 +35,19 @@ def get_llamaparse_parser(): def default_file_loaders_map(): default_loaders = get_file_loaders_map() default_loaders[".txt"] = FlatReader + default_loaders[".csv"] = FlatReader return default_loaders class PrivateFileService: + """ + To store the files uploaded by the user and add them to the index. 
+ """ + PRIVATE_STORE_PATH = "output/uploaded" @staticmethod - def preprocess_base64_file(base64_content: str) -> Tuple[bytes, str | None]: + def _preprocess_base64_file(base64_content: str) -> Tuple[bytes, str | None]: header, data = base64_content.split(",", 1) mime_type = header.split(";")[0].split(":", 1)[1] extension = mimetypes.guess_extension(mime_type) @@ -46,79 +55,144 @@ class PrivateFileService: return base64.b64decode(data), extension @staticmethod - def store_and_parse_file(file_name, file_data, extension) -> List[Document]: + def _store_file(file_name, file_data) -> FileMetadata: + """ + Store the file to the private directory and return the file metadata + """ # Store file to the private directory os.makedirs(PrivateFileService.PRIVATE_STORE_PATH, exist_ok=True) file_path = Path(os.path.join(PrivateFileService.PRIVATE_STORE_PATH, file_name)) - # write file - with open(file_path, "wb") as f: - f.write(file_data) + return save_file(file_data, file_path=str(file_path)) + + @staticmethod + def _load_file_to_documents(file_metadata: FileMetadata) -> List[Document]: + """ + Load the file from the private directory and return the documents + """ + _, extension = os.path.splitext(file_metadata.name) + extension = extension.lstrip(".") # Load file to documents # If LlamaParse is enabled, use it to parse the file # Otherwise, use the default file loaders reader = get_llamaparse_parser() if reader is None: - reader_cls = default_file_loaders_map().get(extension) + reader_cls = default_file_loaders_map().get(f".{extension}") if reader_cls is None: raise ValueError(f"File extension {extension} is not supported") reader = reader_cls() - documents = reader.load_data(file_path) + documents = reader.load_data(Path(file_metadata.path)) # Add custom metadata for doc in documents: - doc.metadata["file_name"] = file_name + doc.metadata["file_name"] = file_metadata.name doc.metadata["private"] = "true" return documents @staticmethod + def _add_documents_to_vector_store_index( + documents: List[Document], index: VectorStoreIndex + ) -> None: + """ + Add the documents to the vector store index + """ + pipeline = IngestionPipeline() + nodes = pipeline.run(documents=documents) + + # Add the nodes to the index and persist it + if index is None: + index = VectorStoreIndex(nodes=nodes) + else: + index.insert_nodes(nodes=nodes) + index.storage_context.persist( + persist_dir=os.environ.get("STORAGE_DIR", "storage") + ) + + @staticmethod + def _add_file_to_llama_cloud_index( + index: LlamaCloudIndex, + file_name: str, + file_data: bytes, + ) -> str: + """ + Add the file to the LlamaCloud index. + LlamaCloudIndex is a managed index so we can directly use the files. 
+ """ + try: + from app.engine.service import LLamaCloudFileService + except ImportError: + raise ValueError("LlamaCloudFileService is not found") + + project_id = index._get_project_id() + pipeline_id = index._get_pipeline_id() + # LlamaCloudIndex is a managed index so we can directly use the files + upload_file = (file_name, BytesIO(file_data)) + doc_id = LLamaCloudFileService.add_file_to_pipeline( + project_id, + pipeline_id, + upload_file, + custom_metadata={}, + ) + return doc_id + + @staticmethod + def _sanitize_file_name(file_name: str) -> str: + file_name, extension = os.path.splitext(file_name) + return re.sub(r"[^a-zA-Z0-9]", "_", file_name) + extension + + @classmethod def process_file( - file_name: str, base64_content: str, params: Optional[dict] = None - ) -> List[str]: + cls, + file_name: str, + base64_content: str, + params: Optional[dict] = None, + ) -> FileMetadata: if params is None: params = {} - file_data, extension = PrivateFileService.preprocess_base64_file(base64_content) - # Add the nodes to the index and persist it index_config = IndexConfig(**params) - current_index = get_index(index_config) + index = get_index(index_config) - # Insert the documents into the index - if isinstance(current_index, LlamaCloudIndex): - from app.engine.service import LLamaCloudFileService + # Generate a new file name if the same file is uploaded multiple times + file_id = str(uuid.uuid4()) + new_file_name = f"{file_id}_{cls._sanitize_file_name(file_name)}" - project_id = current_index._get_project_id() - pipeline_id = current_index._get_pipeline_id() - # LlamaCloudIndex is a managed index so we can directly use the files - upload_file = (file_name, BytesIO(file_data)) - return [ - LLamaCloudFileService.add_file_to_pipeline( - project_id, - pipeline_id, - upload_file, - custom_metadata={ - # Set private=true to mark the document as private user docs (required for filtering) - "private": "true", - }, - ) - ] + # Preprocess and store the file + file_data, extension = cls._preprocess_base64_file(base64_content) + file_metadata = cls._store_file(new_file_name, file_data) + + tools = cls._get_available_tools() + code_executor_tools = ["interpreter", "artifact"] + # If the file is CSV and there is a code executor tool, we don't need to index. 
+ if extension == ".csv" and any(tool in tools for tool in code_executor_tools): + return file_metadata else: - # First process documents into nodes - documents = PrivateFileService.store_and_parse_file( - file_name, file_data, extension - ) - pipeline = IngestionPipeline() - nodes = pipeline.run(documents=documents) - - # Add the nodes to the index and persist it - if current_index is None: - current_index = VectorStoreIndex(nodes=nodes) + # Insert the file into the index and update document ids to the file metadata + if isinstance(index, LlamaCloudIndex): + doc_id = cls._add_file_to_llama_cloud_index( + index, new_file_name, file_data + ) + # Add document ids to the file metadata + file_metadata.refs = [doc_id] else: - current_index.insert_nodes(nodes=nodes) - current_index.storage_context.persist( - persist_dir=os.environ.get("STORAGE_DIR", "storage") - ) + documents = cls._load_file_to_documents(file_metadata) + cls._add_documents_to_vector_store_index(documents, index) + # Add document ids to the file metadata + file_metadata.refs = [doc.doc_id for doc in documents] - # Return the document ids - return [doc.doc_id for doc in documents] + # Return the file metadata + return file_metadata + + @staticmethod + def _get_available_tools() -> Dict[str, List[FunctionTool]]: + try: + from app.engine.tools import ToolFactory + + tools = ToolFactory.from_env(map_result=True) + return tools + except ImportError: + # There is no tool code + return {} + except Exception as e: + raise ValueError(f"Failed to get available tools: {e}") from e diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index 7e96c9274ea4ec1cfd4e3bfdb91c8f12368bfc20..c024dad02ae73f2aa38f0187a0949c99f4d55bf6 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -1,8 +1,6 @@ import logging -from typing import List from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status -from llama_index.core.chat_engine.types import NodeWithScore from llama_index.core.llms import MessageRole from app.api.routers.events import EventCallbackHandler @@ -42,10 +40,11 @@ async def chat( chat_engine = get_chat_engine( filters=filters, params=params, event_handlers=[event_handler] ) - response = await chat_engine.astream_chat(last_message_content, messages) - process_response_nodes(response.source_nodes, background_tasks) + response = chat_engine.astream_chat(last_message_content, messages) - return VercelStreamResponse(request, event_handler, response, data) + return VercelStreamResponse( + request, event_handler, response, data, background_tasks + ) except Exception as e: logger.exception("Error in chat engine", exc_info=True) raise HTTPException( @@ -76,17 +75,3 @@ async def chat_request( result=Message(role=MessageRole.ASSISTANT, content=response.response), nodes=SourceNodes.from_source_nodes(response.source_nodes), ) - - -def process_response_nodes( - nodes: List[NodeWithScore], - background_tasks: BackgroundTasks, -): - try: - # Start background tasks to download documents from LlamaCloud if needed - from app.engine.service import LLamaCloudFileService - - LLamaCloudFileService.download_files_from_nodes(nodes, background_tasks) - except ImportError: - logger.debug("LlamaCloud is not configured. 
Skipping post processing of nodes") - pass diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py index 17c63e59c368da0a225bcd637b365059cc65684f..3bbe7b6e7a1af4d5cb5bb62030d13ce322814c7a 100644 --- a/templates/types/streaming/fastapi/app/api/routers/models.py +++ b/templates/types/streaming/fastapi/app/api/routers/models.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Optional from llama_index.core.llms import ChatMessage, MessageRole from llama_index.core.schema import NodeWithScore @@ -12,19 +12,52 @@ from app.config import DATA_DIR logger = logging.getLogger("uvicorn") -class FileContent(BaseModel): - type: Literal["text", "ref"] - # If the file is pure text then the value is be a string - # otherwise, it's a list of document IDs - value: str | List[str] +class FileMetadata(BaseModel): + id: str + name: str + url: Optional[str] = None + refs: Optional[List[str]] = None + + def _get_url_llm_content(self) -> Optional[str]: + url_prefix = os.getenv("FILESERVER_URL_PREFIX") + if url_prefix: + if self.url is not None: + return f"File URL: {self.url}\n" + else: + # Construct url from file name + return f"File URL (instruction: do not update this file URL yourself): {url_prefix}/output/uploaded/{self.name}\n" + else: + logger.warning( + "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server" + ) + return None + + def to_llm_content(self) -> str: + """ + Construct content for LLM from the file metadata + """ + default_content = f"=====File: {self.name}=====\n" + # Include file URL if it's available + url_content = self._get_url_llm_content() + if url_content: + default_content += url_content + # Include document IDs if it's available + if self.refs is not None: + default_content += f"Document IDs: {self.refs}\n" + # Include sandbox file path + sandbox_file_path = f"/tmp/{self.name}" + default_content += f"Sandbox file path (instruction: only use sandbox path for artifact or code interpreter tool): {sandbox_file_path}\n" + return default_content class File(BaseModel): - id: str - content: FileContent - filename: str - filesize: int filetype: str + metadata: FileMetadata + + def _load_file_content(self) -> str: + file_path = f"output/uploaded/{self.metadata.name}" + with open(file_path, "r") as file: + return file.read() class AnnotationFileData(BaseModel): @@ -62,24 +95,18 @@ class ArtifactAnnotation(BaseModel): class Annotation(BaseModel): type: str - data: Union[AnnotationFileData, List[str], AgentAnnotation, ArtifactAnnotation] + data: AnnotationFileData | List[str] | AgentAnnotation | ArtifactAnnotation def to_content(self) -> Optional[str]: - if self.type == "document_file": - if isinstance(self.data, AnnotationFileData): - # We only support generating context content for CSV files for now - csv_files = [file for file in self.data.files if file.filetype == "csv"] - if len(csv_files) > 0: - return "Use data from following CSV raw content\n" + "\n".join( - [ - f"```csv\n{csv_file.content.value}\n```" - for csv_file in csv_files - ] - ) - else: - logger.warning( - f"Unexpected data type for document_file annotation: {type(self.data)}" + if self.type == "document_file" and isinstance(self.data, AnnotationFileData): + # iterate through all files and construct content for LLM + file_contents = [file.metadata.to_llm_content() for file in self.data.files] + if len(file_contents) > 0: + return 
"Use data from following files content\n" + "\n".join( + file_contents ) + elif self.type == "image": + raise NotImplementedError("Use image file is not supported yet!") else: logger.warning( f"The annotation {self.type} is not supported for generating context content" @@ -175,7 +202,11 @@ class ChatData(BaseModel): ): tool_output = annotation.data.toolOutput if tool_output and not tool_output.get("isError", False): - return tool_output.get("output", {}).get("code", None) + output = tool_output.get("output", {}) + if isinstance(output, dict) and output.get("code"): + return output.get("code") + else: + return None return None def get_history_messages( @@ -216,18 +247,26 @@ class ChatData(BaseModel): Get the document IDs from the chat messages """ document_ids: List[str] = [] + uploaded_files = self.get_uploaded_files() + for _file in uploaded_files: + refs = _file.metadata.refs + if refs is not None: + document_ids.extend(refs) + return list(set(document_ids)) + + def get_uploaded_files(self) -> List[File]: + """ + Get the uploaded files from the chat data + """ + uploaded_files = [] for message in self.messages: if message.role == MessageRole.USER and message.annotations is not None: for annotation in message.annotations: - if ( - annotation.type == "document_file" - and isinstance(annotation.data, AnnotationFileData) - and annotation.data.files is not None + if annotation.type == "document_file" and isinstance( + annotation.data, AnnotationFileData ): - for fi in annotation.data.files: - if fi.content.type == "ref": - document_ids += fi.content.value - return list(set(document_ids)) + uploaded_files.extend(annotation.data.files) + return uploaded_files class SourceNodes(BaseModel): diff --git a/templates/types/streaming/fastapi/app/api/routers/upload.py b/templates/types/streaming/fastapi/app/api/routers/upload.py index ccc03004b4cb6955b77b97b67e458ad174600cfb..78aff33cd1df817adb37cc622d07ac635e06bd2b 100644 --- a/templates/types/streaming/fastapi/app/api/routers/upload.py +++ b/templates/types/streaming/fastapi/app/api/routers/upload.py @@ -1,5 +1,5 @@ import logging -from typing import List, Any +from typing import Any, Dict from fastapi import APIRouter, HTTPException from pydantic import BaseModel @@ -18,12 +18,18 @@ class FileUploadRequest(BaseModel): @r.post("") -def upload_file(request: FileUploadRequest) -> List[str]: +def upload_file(request: FileUploadRequest) -> Dict[str, Any]: + """ + To upload a private file from the chat UI. + Returns: + The metadata of the uploaded file. 
+ """ try: - logger.info("Processing file") - return PrivateFileService.process_file( + logger.info(f"Processing file: {request.filename}") + file_meta = PrivateFileService.process_file( request.filename, request.base64, request.params ) + return file_meta.to_upload_response() except Exception as e: logger.error(f"Error processing file: {e}", exc_info=True) raise HTTPException(status_code=500, detail="Error processing file") diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py index 924c60ce5f6f25ebc21204f93108e6c73dcf563a..fc5f03e03a39f6490821f04fc47200dda8af265f 100644 --- a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py +++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py @@ -1,15 +1,19 @@ import json -from typing import List +import logging +from typing import Awaitable, List from aiostream import stream -from fastapi import Request +from fastapi import BackgroundTasks, Request from fastapi.responses import StreamingResponse from llama_index.core.chat_engine.types import StreamingAgentChatResponse +from llama_index.core.schema import NodeWithScore from app.api.routers.events import EventCallbackHandler from app.api.routers.models import ChatData, Message, SourceNodes from app.api.services.suggestion import NextQuestionSuggestion +logger = logging.getLogger("uvicorn") + class VercelStreamResponse(StreamingResponse): """ @@ -19,26 +23,16 @@ class VercelStreamResponse(StreamingResponse): TEXT_PREFIX = "0:" DATA_PREFIX = "8:" - @classmethod - def convert_text(cls, token: str): - # Escape newlines and double quotes to avoid breaking the stream - token = json.dumps(token) - return f"{cls.TEXT_PREFIX}{token}\n" - - @classmethod - def convert_data(cls, data: dict): - data_str = json.dumps(data) - return f"{cls.DATA_PREFIX}[{data_str}]\n" - def __init__( self, request: Request, event_handler: EventCallbackHandler, - response: StreamingAgentChatResponse, + response: Awaitable[StreamingAgentChatResponse], chat_data: ChatData, + background_tasks: BackgroundTasks, ): content = VercelStreamResponse.content_generator( - request, event_handler, response, chat_data + request, event_handler, response, chat_data, background_tasks ) super().__init__(content=content) @@ -47,53 +41,23 @@ class VercelStreamResponse(StreamingResponse): cls, request: Request, event_handler: EventCallbackHandler, - response: StreamingAgentChatResponse, + response: Awaitable[StreamingAgentChatResponse], chat_data: ChatData, + background_tasks: BackgroundTasks, ): - # Yield the text response - async def _chat_response_generator(): - final_response = "" - async for token in response.async_response_gen(): - final_response += token - yield cls.convert_text(token) - - # Generate next questions if next question prompt is configured - question_data = await cls._generate_next_questions( - chat_data.messages, final_response - ) - if question_data: - yield cls.convert_data(question_data) - - # the text_generator is the leading stream, once it's finished, also finish the event stream - event_handler.is_done = True - - # Yield the source nodes - yield cls.convert_data( - { - "type": "sources", - "data": { - "nodes": [ - SourceNodes.from_source_node(node).model_dump() - for node in response.source_nodes - ] - }, - } - ) - - # Yield the events from the event handler - async def _event_generator(): - async for event in event_handler.async_event_gen(): - event_response = event.to_response() - if 
event_response is not None: - yield cls.convert_data(event_response) + chat_response_generator = cls._chat_response_generator( + response, background_tasks, event_handler, chat_data + ) + event_generator = cls._event_generator(event_handler) - combine = stream.merge(_chat_response_generator(), _event_generator()) + # Merge the chat response generator and the event generator + combine = stream.merge(chat_response_generator, event_generator) is_stream_started = False async with combine.stream() as streamer: async for output in streamer: if not is_stream_started: is_stream_started = True - # Stream a blank message to start the stream + # Stream a blank message to start displaying the response in the UI yield cls.convert_text("") yield output @@ -101,6 +65,90 @@ class VercelStreamResponse(StreamingResponse): if await request.is_disconnected(): break + @classmethod + async def _event_generator(cls, event_handler: EventCallbackHandler): + """ + Yield the events from the event handler + """ + async for event in event_handler.async_event_gen(): + event_response = event.to_response() + if event_response is not None: + yield cls.convert_data(event_response) + + @classmethod + async def _chat_response_generator( + cls, + response: Awaitable[StreamingAgentChatResponse], + background_tasks: BackgroundTasks, + event_handler: EventCallbackHandler, + chat_data: ChatData, + ): + """ + Yield the text response and source nodes from the chat engine + """ + # Wait for the response from the chat engine + result = await response + + # Once we got a source node, start a background task to download the files (if needed) + cls._process_response_nodes(result.source_nodes, background_tasks) + + # Yield the source nodes + yield cls.convert_data( + { + "type": "sources", + "data": { + "nodes": [ + SourceNodes.from_source_node(node).model_dump() + for node in result.source_nodes + ] + }, + } + ) + + final_response = "" + async for token in result.async_response_gen(): + final_response += token + yield cls.convert_text(token) + + # Generate next questions if next question prompt is configured + question_data = await cls._generate_next_questions( + chat_data.messages, final_response + ) + if question_data: + yield cls.convert_data(question_data) + + # the text_generator is the leading stream, once it's finished, also finish the event stream + event_handler.is_done = True + + @classmethod + def convert_text(cls, token: str): + # Escape newlines and double quotes to avoid breaking the stream + token = json.dumps(token) + return f"{cls.TEXT_PREFIX}{token}\n" + + @classmethod + def convert_data(cls, data: dict): + data_str = json.dumps(data) + return f"{cls.DATA_PREFIX}[{data_str}]\n" + + @staticmethod + def _process_response_nodes( + source_nodes: List[NodeWithScore], + background_tasks: BackgroundTasks, + ): + try: + # Start background tasks to download documents from LlamaCloud if needed + from app.engine.service import LLamaCloudFileService + + LLamaCloudFileService.download_files_from_nodes( + source_nodes, background_tasks + ) + except ImportError: + logger.debug( + "LlamaCloud is not configured. 
diff --git a/templates/types/streaming/fastapi/app/engine/utils/file_helper.py b/templates/types/streaming/fastapi/app/engine/utils/file_helper.py
index c794a3b472fff0902ef4fdb55bacee3addee4639..5270139eabab1900d991ee14c3498c1d115ed063 100644
--- a/templates/types/streaming/fastapi/app/engine/utils/file_helper.py
+++ b/templates/types/streaming/fastapi/app/engine/utils/file_helper.py
@@ -1,17 +1,36 @@
 import logging
 import os
 import uuid
-from typing import Optional
+from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, computed_field
 
 logger = logging.getLogger(__name__)
 
 
 class FileMetadata(BaseModel):
-    outputPath: str
-    filename: str
-    url: str
+    path: str = Field(..., description="The stored path of the file")
+    name: str = Field(..., description="The name of the file")
+    url: str = Field(..., description="The URL of the file")
+    refs: Optional[List[str]] = Field(
+        None, description="The IDs of the indexed documents associated with this file"
+    )
+
+    @computed_field
+    def file_id(self) -> Optional[str]:
+        file_els = self.name.split("_", maxsplit=1)
+        if len(file_els) == 2:
+            return file_els[0]
+        return None
+
+    def to_upload_response(self) -> Dict[str, Any]:
+        response = {
+            "id": self.file_id,
+            "name": self.name,
+            "url": self.url,
+            "refs": self.refs,
+        }
+        return response
 
 
 def save_file(
@@ -58,7 +77,7 @@ def save_file(
     logger.info(f"Saved file to {file_path}")
 
     return FileMetadata(
-        outputPath=file_path,
-        filename=file_name,
+        path=str(file_path),
+        name=file_name,
         url=f"{os.getenv('FILESERVER_URL_PREFIX')}/{file_path}",
     )
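The `file_id` computed field assumes stored file names follow an `<id>_<original-name>` scheme, so the id is recovered by splitting on the first underscore. The same rule, mirrored in TypeScript purely for illustration (the function name is mine, not from the codebase):

```ts
// Mirrors FileMetadata.file_id: split on the first underscore; if there is
// no underscore, the name carries no id.
function fileIdFromName(name: string): string | null {
  const sep = name.indexOf("_");
  return sep === -1 ? null : name.slice(0, sep);
}

// fileIdFromName("1a2b3c_report.pdf") === "1a2b3c"
// fileIdFromName("report.pdf") === null
```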
diff --git a/templates/types/streaming/nextjs/app/api/chat/upload/route.ts b/templates/types/streaming/nextjs/app/api/chat/upload/route.ts
index 001d3a0efe9065bd80acb9c3e62b67f6c5803f9a..382a94c937fd779e347d49b9ba8a59515c12487f 100644
--- a/templates/types/streaming/nextjs/app/api/chat/upload/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/upload/route.ts
@@ -23,11 +23,6 @@ export async function POST(request: NextRequest) {
       );
     }
     const index = await getDataSource(params);
-    if (!index) {
-      throw new Error(
-        `StorageContext is empty - call 'npm run generate' to generate the storage first`,
-      );
-    }
     return NextResponse.json(await uploadDocument(index, filename, base64));
   } catch (error) {
     console.error("[Upload API]", error);
diff --git a/templates/types/streaming/nextjs/app/api/sandbox/route.ts b/templates/types/streaming/nextjs/app/api/sandbox/route.ts
index cfc20087490bd072f55034eb31512ccb5bc35372..6bbd15177dbe33f8a0eb0f50cdbf767f6545b207 100644
--- a/templates/types/streaming/nextjs/app/api/sandbox/route.ts
+++ b/templates/types/streaming/nextjs/app/api/sandbox/route.ts
@@ -19,6 +19,8 @@ import {
   Result,
   Sandbox,
 } from "@e2b/code-interpreter";
+import fs from "node:fs/promises";
+import path from "node:path";
 import { saveDocument } from "../chat/llamaindex/documents/helper";
 
 type CodeArtifact = {
@@ -32,6 +34,7 @@ type CodeArtifact = {
   port: number | null;
   file_path: string;
   code: string;
+  files?: string[];
 };
 
 const sandboxTimeout = 10 * 60 * 1000; // 10 minute in ms
@@ -82,6 +85,18 @@ export async function POST(req: Request) {
     }
  }
 
+  // Copy files into the sandbox, waiting for every write to finish
+  if (artifact.files) {
+    await Promise.all(
+      artifact.files.map(async (sandboxFilePath) => {
+        const fileName = path.basename(sandboxFilePath);
+        const localFilePath = path.join("output", "uploaded", fileName);
+        const fileContent = await fs.readFile(localFilePath);
+        await sbx.files.write(sandboxFilePath, fileContent);
+        console.log(`Copied file to ${sandboxFilePath} in ${sbx.sandboxID}`);
+      }),
+    );
+  }
+
   // Copy code to fs
   if (artifact.code && Array.isArray(artifact.code)) {
     artifact.code.forEach(async (file) => {
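Note that the copy block awaits a `Promise.all` over `map` rather than using `forEach` with an async callback: `forEach` discards the promises its callback returns, so the handler would continue (and the sandbox could start executing code) before the input files exist. The difference as a self-contained sketch (the helper names are hypothetical):

```ts
// forEach ignores the promises returned by an async callback, so copyAllWrong
// returns immediately; Promise.all over map resolves only after every write.
async function copyAllWrong(paths: string[], write: (p: string) => Promise<void>) {
  paths.forEach(async (p) => await write(p)); // fire-and-forget: nothing is awaited
}

async function copyAll(paths: string[], write: (p: string) => Promise<void>) {
  await Promise.all(paths.map((p) => write(p))); // waits for every copy
}
```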
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
index 326cc969505839384e6862cb56240af3617b785b..ce2b02d063f14fc6064b6fbe94c321409dad3f84 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
@@ -95,9 +95,9 @@ export default function ChatInput(
       )}
       {files.length > 0 && (
         <div className="flex gap-4 w-full overflow-auto py-2">
-          {files.map((file) => (
+          {files.map((file, index) => (
             <DocumentPreview
-              key={file.id}
+              key={file.metadata?.id ?? `${file.filename}-${index}`}
               file={file}
               onRemove={() => removeDoc(file)}
             />
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-files.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-files.tsx
index 5139c5411faf381fbe95bf35427735e20f776e71..085963d2b940db8369293b97dfb737e11bce32c6 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-files.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-files.tsx
@@ -5,8 +5,11 @@ export function ChatFiles({ data }: { data: DocumentFileData }) {
   if (!data.files.length) return null;
   return (
     <div className="flex gap-2 items-center">
-      {data.files.map((file) => (
-        <DocumentPreview key={file.id} file={file} />
+      {data.files.map((file, index) => (
+        <DocumentPreview
+          key={file.metadata?.id ?? `${file.filename}-${index}`}
+          file={file}
+        />
       ))}
     </div>
   );
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-sources.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-sources.tsx
index 929e199cafac16c627a762aa93bb034c9bc0ada5..03a1a62783ace6dd33e19293f52c31f7a85e5ba2 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-sources.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-sources.tsx
@@ -1,8 +1,7 @@
-import { Check, Copy, FileText } from "lucide-react";
-import Image from "next/image";
+import { Check, Copy } from "lucide-react";
 import { useMemo } from "react";
 import { Button } from "../../button";
-import { FileIcon } from "../../document-preview";
+import { PreviewCard } from "../../document-preview";
 import {
   HoverCard,
   HoverCardContent,
@@ -49,13 +48,7 @@ export function ChatSources({ data }: { data: SourceData }) {
   );
 }
 
-export function SourceInfo({
-  node,
-  index,
-}: {
-  node?: SourceNode;
-  index: number;
-}) {
+function SourceInfo({ node, index }: { node?: SourceNode; index: number }) {
   if (!node) return <SourceNumberButton index={index} />;
   return (
     <HoverCard>
@@ -97,49 +90,33 @@ export function SourceNumberButton({
   );
 }
 
-function DocumentInfo({ document }: { document: Document }) {
-  if (!document.sources.length) return null;
+export function DocumentInfo({
+  document,
+  className,
+}: {
+  document: Document;
+  className?: string;
+}) {
   const { url, sources } = document;
-  const fileName = sources[0].metadata.file_name as string | undefined;
-  const fileExt = fileName?.split(".").pop();
-  const fileImage = fileExt ? FileIcon[fileExt as DocumentFileType] : null;
+  // Extract filename from URL
+  const urlParts = url.split("/");
+  const fileName = urlParts.length > 0 ? urlParts[urlParts.length - 1] : url;
+  const fileExt = fileName?.split(".").pop() as DocumentFileType | undefined;
+
+  const previewFile = {
+    filename: fileName,
+    filetype: fileExt,
+  };
 
   const DocumentDetail = (
-    <div
-      key={url}
-      className="h-28 w-48 flex flex-col justify-between p-4 border rounded-md shadow-md cursor-pointer"
-    >
-      <p
-        title={fileName}
-        className={cn(
-          fileName ? "truncate" : "text-blue-900 break-words",
-          "text-left",
-        )}
-      >
-        {fileName ?? url}
-      </p>
-      <div className="flex justify-between items-center">
-        <div className="space-x-2 flex">
-          {sources.map((node: SourceNode, index: number) => {
-            return (
-              <div key={node.id}>
-                <SourceInfo node={node} index={index} />
-              </div>
-            );
-          })}
-        </div>
-        {fileImage ? (
-          <div className="relative h-8 w-8 shrink-0 overflow-hidden rounded-md">
-            <Image
-              className="h-full w-auto"
-              priority
-              src={fileImage}
-              alt="Icon"
-            />
+    <div className={`relative ${className}`}>
+      <PreviewCard className={"cursor-pointer"} file={previewFile} />
+      <div className="absolute bottom-2 right-2 space-x-2 flex">
+        {sources.map((node: SourceNode, index: number) => (
+          <div key={node.id}>
+            <SourceInfo node={node} index={index} />
           </div>
-        ) : (
-          <FileText className="text-gray-500" />
-        )}
+        ))}
       </div>
     </div>
   );
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/markdown.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/markdown.tsx
index aa32e40d05e273ec8a3d3330535b1c280d36d269..8682a802ec75e576a1ea938c3a1483cac6ce1b4a 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/markdown.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/markdown.tsx
@@ -5,8 +5,9 @@ import rehypeKatex from "rehype-katex";
 import remarkGfm from "remark-gfm";
 import remarkMath from "remark-math";
 
-import { SourceData } from "..";
-import { SourceNumberButton } from "./chat-sources";
+import { DOCUMENT_FILE_TYPES, DocumentFileType, SourceData } from "..";
+import { useClientConfig } from "../hooks/use-config";
+import { DocumentInfo, SourceNumberButton } from "./chat-sources";
 import { CodeBlock } from "./codeblock";
 
 const MemoizedReactMarkdown: FC<Options> = memo(
@@ -78,6 +79,7 @@ export default function Markdown({
   sources?: SourceData;
 }) {
   const processedContent = preprocessContent(content, sources);
+  const { backend } = useClientConfig();
 
   return (
     <MemoizedReactMarkdown
@@ -86,7 +88,7 @@ export default function Markdown({
       rehypePlugins={[rehypeKatex as any]}
       components={{
         p({ children }) {
-          return <p className="mb-2 last:mb-0">{children}</p>;
+          return <div className="mb-2 last:mb-0">{children}</div>;
         },
         code({ node, inline, className, children, ...props }) {
           if (children.length) {
@@ -120,6 +122,26 @@ export default function Markdown({
           );
         },
         a({ href, children }) {
+          // If href starts with `{backend}/api/files`, it's a locally served document, so render it with DocumentInfo
+          if (href?.startsWith(backend + "/api/files")) {
+            // Check if the file is a document file type
+            const fileExtension = href.split(".").pop()?.toLowerCase();
+
+            if (
+              fileExtension &&
+              DOCUMENT_FILE_TYPES.includes(fileExtension as DocumentFileType)
+            ) {
+              return (
+                <DocumentInfo
+                  document={{
+                    url: new URL(decodeURIComponent(href)).href,
+                    sources: [],
+                  }}
+                  className="mb-2 mt-2"
+                />
+              );
+            }
+          }
           // If a text link starts with 'citation:', then render it as a citation reference
           if (
             Array.isArray(children) &&
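The new `a` renderer treats any link under the backend's `/api/files` route with a known document extension as a locally served file and renders it as a `DocumentInfo` card instead of a plain anchor. The predicate in isolation, restated as a standalone sketch (the function name is mine; the example URL is hypothetical):

```ts
const DOCUMENT_FILE_TYPES = ["csv", "pdf", "txt", "docx"] as const;

// True when the href points at a file served by the backend's /api/files
// route and its extension is one of the supported document types.
function isLocalDocumentLink(href: string, backend: string): boolean {
  if (!href.startsWith(backend + "/api/files")) return false;
  const ext = href.split(".").pop()?.toLowerCase();
  return !!ext && (DOCUMENT_FILE_TYPES as readonly string[]).includes(ext);
}

// isLocalDocumentLink("http://localhost:8000/api/files/data/report.pdf",
//                     "http://localhost:8000") === true
```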
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/hooks/use-file.ts b/templates/types/streaming/nextjs/app/components/ui/chat/hooks/use-file.ts
index cc49169ac62747f1a6c135a6dc9774314dabb313..695202f4ff146c886813e8b9e38c3ceb0e4bfde5 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/hooks/use-file.ts
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/hooks/use-file.ts
@@ -2,12 +2,12 @@
 import { JSONValue } from "llamaindex";
 import { useState } from "react";
-import { v4 as uuidv4 } from "uuid";
 import {
   DocumentFile,
   DocumentFileType,
   MessageAnnotation,
   MessageAnnotationType,
+  UploadedFileMeta,
 } from "..";
 import { useClientConfig } from "./use-config";
 
@@ -25,7 +25,7 @@ export function useFile() {
   const [files, setFiles] = useState<DocumentFile[]>([]);
 
   const docEqual = (a: DocumentFile, b: DocumentFile) => {
-    if (a.id === b.id) return true;
+    if (a.metadata?.id === b.metadata?.id) return true;
     if (a.filename === b.filename && a.filesize === b.filesize) return true;
     return false;
   };
@@ -40,7 +40,9 @@ export function useFile() {
   };
 
   const removeDoc = (file: DocumentFile) => {
-    setFiles((prev) => prev.filter((f) => f.id !== file.id));
+    setFiles((prev) =>
+      prev.filter((f) => f.metadata?.id !== file.metadata?.id),
+    );
   };
 
   const reset = () => {
@@ -51,7 +53,7 @@ export function useFile() {
   const uploadContent = async (
     file: File,
     requestParams: any = {},
-  ): Promise<string[]> => {
+  ): Promise<UploadedFileMeta> => {
     const base64 = await readContent({ file, asUrl: true });
     const uploadAPI = `${backend}/api/chat/upload`;
     const response = await fetch(uploadAPI, {
@@ -66,7 +68,7 @@ export function useFile() {
       }),
     });
     if (!response.ok) throw new Error("Failed to upload document.");
-    return await response.json();
+    return (await response.json()) as UploadedFileMeta;
   };
 
   const getAnnotations = () => {
@@ -112,34 +114,14 @@ export function useFile() {
     const filetype = docMineTypeMap[file.type];
     if (!filetype) throw new Error("Unsupported document type.");
 
-    const newDoc: Omit<DocumentFile, "content"> = {
-      id: uuidv4(),
-      filetype,
+    const uploadedFileMeta = await uploadContent(file, requestParams);
+    const newDoc: DocumentFile = {
       filename: file.name,
       filesize: file.size,
+      filetype,
+      metadata: uploadedFileMeta,
     };
-    switch (file.type) {
-      case "text/csv": {
-        const content = await readContent({ file });
-        return addDoc({
-          ...newDoc,
-          content: {
-            type: "text",
-            value: content,
-          },
-        });
-      }
-      default: {
-        const ids = await uploadContent(file, requestParams);
-        return addDoc({
-          ...newDoc,
-          content: {
-            type: "ref",
-            value: ids,
-          },
-        });
-      }
-    }
+    return addDoc(newDoc);
   };
 
   return {
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts
index a9aba73e504cd354999735e3aa97ee86581feb9e..eedd4ceb25763e8d2b2da43f31454e52700f47fe 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts
@@ -20,18 +20,25 @@ export type ImageData = {
 };
 
 export type DocumentFileType = "csv" | "pdf" | "txt" | "docx";
-
-export type DocumentFileContent = {
-  type: "ref" | "text";
-  value: string[] | string;
+export const DOCUMENT_FILE_TYPES: DocumentFileType[] = [
+  "csv",
+  "pdf",
+  "txt",
+  "docx",
+];
+
+export type UploadedFileMeta = {
+  id: string;
+  name: string; // The stored file name on the backend (uuid-prefixed and sanitized)
+  url?: string;
+  refs?: string[];
 };
 
 export type DocumentFile = {
-  id: string;
-  filename: string;
+  filename: string; // The original file name
   filesize: number;
   filetype: DocumentFileType;
-  content: DocumentFileContent;
+  metadata?: UploadedFileMeta; // undefined when the file is not uploaded yet
 };
 
 export type DocumentFileData = {
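With the client-generated `DocumentFile.id` gone, a file's identity now comes from the backend-supplied `metadata`, which is absent until the upload completes; that is why `docEqual` falls back to comparing filename and size. How a `DocumentFile` is assembled after upload, as a standalone sketch (types restated from `index.ts`; `upload` stands in for the hook's fetch call):

```ts
type UploadedFileMeta = { id: string; name: string; url?: string; refs?: string[] };
type DocumentFileType = "csv" | "pdf" | "txt" | "docx";
type DocumentFile = {
  filename: string; // original name from the browser File
  filesize: number;
  filetype: DocumentFileType;
  metadata?: UploadedFileMeta; // absent until the upload completes
};

async function toDocumentFile(
  file: File,
  filetype: DocumentFileType,
  upload: (f: File) => Promise<UploadedFileMeta>,
): Promise<DocumentFile> {
  const metadata = await upload(file); // the backend response becomes the identity
  return { filename: file.name, filesize: file.size, filetype, metadata };
}
```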
"" : "cursor-pointer", + className, )} > <div className="flex flex-row items-center gap-2"> - <div className="relative h-8 w-8 shrink-0 overflow-hidden rounded-md"> + <div className="relative h-8 w-8 shrink-0 overflow-hidden rounded-md flex items-center justify-center"> <Image - className="h-full w-auto" + className="h-full w-auto object-contain" priority - src={FileIcon[file.filetype]} + src={FileIcon[file.filetype || "txt"]} alt="Icon" /> </div> <div className="overflow-hidden"> <div className="truncate font-semibold"> - {file.filename} ({inKB(file.filesize)} KB) - </div> - <div className="truncate text-token-text-tertiary flex items-center gap-2"> - <span>{file.filetype.toUpperCase()} File</span> + {file.filename} {file.filesize ? `(${inKB(file.filesize)} KB)` : ""} </div> + {file.filetype && ( + <div className="truncate text-token-text-tertiary flex items-center gap-2"> + <span>{file.filetype.toUpperCase()} File</span> + </div> + )} </div> </div> {onRemove && (