diff --git a/.changeset/large-parents-exercise.md b/.changeset/large-parents-exercise.md
new file mode 100644
index 0000000000000000000000000000000000000000..52a6072c8a8af6c051f626da45b8ef00ea4d07bf
--- /dev/null
+++ b/.changeset/large-parents-exercise.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add form filling use case (Python)
diff --git a/helpers/tools.ts b/helpers/tools.ts
index 5199f244d0c42d140d495766ebdac7e067b14be4..27ea1e45a07611f462f57e28fa0c80b86352aceb 100644
--- a/helpers/tools.ts
+++ b/helpers/tools.ts
@@ -267,6 +267,22 @@ For better results, you can specify the region parameter to get results from a s
       },
     ],
   },
+  {
+    display: "Form Filling",
+    name: "form_filling",
+    supportedFrameworks: ["fastapi"],
+    type: ToolType.LOCAL,
+    dependencies: [
+      {
+        name: "pandas",
+        version: "^2.2.3",
+      },
+      {
+        name: "tabulate",
+        version: "^0.9.0",
+      },
+    ],
+  },
 ];
 
 export const getTool = (toolName: string): Tool | undefined => {
diff --git a/helpers/types.ts b/helpers/types.ts
index 30a2ab357601d654f4848e8c736439f6bced6c71..cef8ce3b81688269a50ce8127231014aef00b3ad 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -48,7 +48,7 @@ export type TemplateDataSource = {
 };
 export type TemplateDataSourceType = "file" | "web" | "db";
 export type TemplateObservability = "none" | "traceloop" | "llamatrace";
-export type TemplateAgents = "financial_report" | "blog";
+export type TemplateAgents = "financial_report" | "blog" | "form_filling";
 // Config for both file and folder
 export type FileSourceConfig =
   | {
diff --git a/questions/questions.ts b/questions/questions.ts
index 0577de5b0e772943bca15acae84245a70392fe96..45c262f49b6bbb10df690f2affa0404e31cbaff7 100644
--- a/questions/questions.ts
+++ b/questions/questions.ts
@@ -177,6 +177,34 @@ export const askProQuestions = async (program: QuestionArgs) => {
     program.observability = observability;
   }
 
+  // Ask agents
+  if (program.template === "multiagent" && !program.agents) {
+    const { agents } = await prompts(
+      {
+        type: "select",
+        name: "agents",
+        message: "Which agents would you like to use?",
+        choices: [
+          {
+            title: "Financial report (generate a financial report)",
+            value: "financial_report",
+          },
+          {
+            title: "Form filling (fill missing values in a CSV file)",
+            value: "form_filling",
+          },
+          {
+            title: "Blog writer (write a blog post)",
+            value: "blog",
+          },
+        ],
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    program.agents = agents;
+  }
+
   if (!program.modelConfig) {
     const modelConfig = await askModelConfig({
       openAiKey: program.openAiKey,
diff --git a/questions/simple.ts b/questions/simple.ts
index 4888c98566b845a8e2b8a60e692ecc7231c61456..50ae904d759798984991531d3301784a6860eea7 100644
--- a/questions/simple.ts
+++ b/questions/simple.ts
@@ -10,6 +10,7 @@ type AppType =
   | "rag"
   | "code_artifact"
   | "financial_report_agent"
+  | "form_filling"
   | "extractor"
   | "data_scientist";
 
@@ -35,8 +36,12 @@ export const askSimpleQuestions = async (
         title: "Financial Report Generator (using Workflows)",
         value: "financial_report_agent",
       },
+      {
+        title: "Form Filler (using Workflows)",
+        value: "form_filling",
+      },
       { title: "Code Artifact Agent", value: "code_artifact" },
-      { title: "Structured extraction", value: "extractor" },
+      { title: "Information Extractor", value: "extractor" },
     ],
   },
   questionHandlers,
 );
@@ -47,19 +52,22 @@
   let useLlamaCloud = false;
 
   if (appType !== "extractor") {
-    const { language: newLanguage } = await prompts(
-      {
-        type: "select",
-        name: "language",
-        message: "What 
language do you want to use?", - choices: [ - { title: "Python (FastAPI)", value: "fastapi" }, - { title: "Typescript (NextJS)", value: "nextjs" }, - ], - }, - questionHandlers, - ); - language = newLanguage; + // TODO: Add TS support for form filling use case + if (appType !== "form_filling") { + const { language: newLanguage } = await prompts( + { + type: "select", + name: "language", + message: "What language do you want to use?", + choices: [ + { title: "Python (FastAPI)", value: "fastapi" }, + { title: "Typescript (NextJS)", value: "nextjs" }, + ], + }, + questionHandlers, + ); + language = newLanguage; + } const { useLlamaCloud: newUseLlamaCloud } = await prompts( { @@ -152,6 +160,14 @@ const convertAnswers = async ( frontend: true, modelConfig: MODEL_GPT4o, }, + form_filling: { + template: "multiagent", + agents: "form_filling", + tools: getTools(["form_filling"]), + dataSources: EXAMPLE_10K_SEC_FILES, + frontend: true, + modelConfig: MODEL_GPT4o, + }, extractor: { template: "extractor", tools: [], diff --git a/templates/components/agents/python/blog/app/agents/publisher.py b/templates/components/agents/python/blog/app/agents/publisher.py index 5dfc6c7781d777c39181c405c98d165819baf4a1..0245e17a4000139e1292bc28815fbae5acb345e6 100644 --- a/templates/components/agents/python/blog/app/agents/publisher.py +++ b/templates/components/agents/python/blog/app/agents/publisher.py @@ -11,11 +11,11 @@ def get_publisher_tools() -> Tuple[List[FunctionTool], str, str]: tools = [] # Get configured tools from the tools.yaml file configured_tools = ToolFactory.from_env(map_result=True) - if "document_generator" in configured_tools.keys(): - tools.extend(configured_tools["document_generator"]) + if "generate_document" in configured_tools.keys(): + tools.append(configured_tools["generate_document"]) prompt_instructions = dedent(""" Normally, reply the blog post content to the user directly. - But if user requested to generate a file, use the document_generator tool to generate the file and reply the link to the file. + But if user requested to generate a file, use the generate_document tool to generate the file and reply the link to the file. """) description = "Expert in publishing the blog post, able to publish the blog post in PDF or HTML format." 
else: diff --git a/templates/components/agents/python/blog/app/agents/researcher.py b/templates/components/agents/python/blog/app/agents/researcher.py index 3b9ba5ed6665ff34de7a7efb109e7a12acc510a5..1ef6f07e57cf0cd01d72612f6e876c8fa6fe63d0 100644 --- a/templates/components/agents/python/blog/app/agents/researcher.py +++ b/templates/components/agents/python/blog/app/agents/researcher.py @@ -42,11 +42,15 @@ def _get_research_tools(**kwargs) -> QueryEngineTool: query_engine_tool = _create_query_engine_tool(**kwargs) if query_engine_tool is not None: tools.append(query_engine_tool) - researcher_tool_names = ["duckduckgo", "wikipedia.WikipediaToolSpec"] + researcher_tool_names = [ + "duckduckgo_search", + "duckduckgo_image_search", + "wikipedia.WikipediaToolSpec", + ] configured_tools = ToolFactory.from_env(map_result=True) for tool_name, tool in configured_tools.items(): if tool_name in researcher_tool_names: - tools.extend(tool) + tools.append(tool) return tools diff --git a/templates/components/agents/python/financial_report/app/agents/analyst.py b/templates/components/agents/python/financial_report/app/agents/analyst.py index f86b10d92af12d06eb8a62221659de31262e7766..877017f6eb6db04098f4e8882e007e3f325c4048 100644 --- a/templates/components/agents/python/financial_report/app/agents/analyst.py +++ b/templates/components/agents/python/financial_report/app/agents/analyst.py @@ -22,8 +22,8 @@ def _get_analyst_params() -> Tuple[List[type[FunctionTool]], str, str]: description = "Expert in analyzing financial data" configured_tools = ToolFactory.from_env(map_result=True) # Check if the interpreter tool is configured - if "interpreter" in configured_tools.keys(): - tools.extend(configured_tools["interpreter"]) + if "interpret" in configured_tools.keys(): + tools.append(configured_tools["interpret"]) prompt_instructions += dedent(""" You are able to visualize the financial data using code interpreter tool. It's very useful to create and include visualizations to the report (make sure you include the right code and data for the visualization). diff --git a/templates/components/agents/python/financial_report/app/agents/reporter.py b/templates/components/agents/python/financial_report/app/agents/reporter.py index b1337bb033c792a4268573ad73be68fc4fce0cbc..496884ca261f3aee90a578941f2e577ea51fef40 100644 --- a/templates/components/agents/python/financial_report/app/agents/reporter.py +++ b/templates/components/agents/python/financial_report/app/agents/reporter.py @@ -24,8 +24,8 @@ def _get_reporter_params( """ ) configured_tools = ToolFactory.from_env(map_result=True) - if "document_generator" in configured_tools: # type: ignore - tools.extend(configured_tools["document_generator"]) # type: ignore + if "generate_document" in configured_tools: # type: ignore + tools.append(configured_tools["generate_document"]) # type: ignore prompt_instructions += ( "\nYou are also able to generate a file document (PDF/HTML) of the report." ) diff --git a/templates/components/agents/python/form_filling/README-template.md b/templates/components/agents/python/form_filling/README-template.md new file mode 100644 index 0000000000000000000000000000000000000000..a3340c531f455b26e9b15d99c7a0082f7b477e75 --- /dev/null +++ b/templates/components/agents/python/form_filling/README-template.md @@ -0,0 +1,59 @@ +This is a [LlamaIndex](https://www.llamaindex.ai/) multi-agents project using [Workflows](https://docs.llamaindex.ai/en/stable/understanding/workflows/). 
+
+## Getting Started
+
+First, set up the environment with poetry:
+
+> **_Note:_** This step is not needed if you are using the dev-container.
+
+```shell
+poetry install
+```
+
+Then check the parameters that have been pre-configured in the `.env` file in this directory.
+Make sure you have the `OPENAI_API_KEY` set.
+
+Second, run the development server:
+
+```shell
+poetry run python main.py
+```
+
+## Use Case: Filling Financial CSV Template
+
+To reproduce the use case, start the [frontend](../frontend/README.md) and follow these steps in the frontend:
+
+1. Upload the Apple and Tesla financial reports from the [data](./data) directory. Just send an empty message.
+2. Upload the CSV file [sec_10k_template.csv](./sec_10k_template.csv) and send the message "Fill the missing cells in the CSV file".
+
+The agent will fill the missing cells by retrieving the information from the uploaded financial reports and return a new CSV file with the filled cells.
+
+### API endpoints
+
+The example provides one streaming API endpoint `/api/chat`.
+You can test the endpoint with the following curl request:
+
+```shell
+curl --location 'localhost:8000/api/chat' \
+--header 'Content-Type: application/json' \
+--data '{ "messages": [{ "role": "user", "content": "What can you do?" }] }'
+```
+
+You can start editing the API by modifying `app/api/routers/chat.py` or `app/agents/form_filling.py`. The API auto-updates as you save the files.
+
+Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
+
+The API allows CORS for all origins to simplify development. You can change this behavior by setting the `ENVIRONMENT` environment variable to `prod`:
+
+```shell
+ENVIRONMENT=prod poetry run python main.py
+```
+
+## Learn More
+
+To learn more about LlamaIndex, take a look at the following resources:
+
+- [LlamaIndex Documentation](https://docs.llamaindex.ai) - learn about LlamaIndex.
+- [Workflows Introduction](https://docs.llamaindex.ai/en/stable/understanding/workflows/) - learn about LlamaIndex workflows.
+
+You can check out [the LlamaIndex GitHub repository](https://github.com/run-llama/llama_index) - your feedback and contributions are welcome!
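Reviewer note: the workflow this template wires up can also be exercised without the HTTP layer. Below is a minimal, hypothetical driver script — a sketch, not part of the template — based on the `create_workflow` factory and the `start` step defined in `app/agents/form_filling.py` further down. It assumes the template's environment is in place (`.env` loaded, data indexed, and the `extract_questions`/`fill_form` tools enabled in `tools.yaml`).

```python
# Sketch only: drive the form-filling workflow directly from Python.
import asyncio

from app.agents.form_filling import create_workflow


async def main() -> None:
    # Raises ValueError if the form-filling tools are not configured.
    workflow = create_workflow()
    # Keyword arguments to run() become StartEvent attributes; the `start`
    # step reads them as ev.input and (optionally) ev.streaming.
    result = await workflow.run(
        input="Fill the missing cells in sec_10k_template.csv",
        streaming=True,
    )
    # When no tool call is pending, the workflow stops with an async
    # generator of LLM response chunks -- the same object that the
    # /api/chat endpoint streams to the frontend.
    async for chunk in result:
        print(chunk.delta or "", end="", flush=True)


asyncio.run(main())
```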
diff --git a/templates/components/agents/python/form_filling/app/agents/form_filling.py b/templates/components/agents/python/form_filling/app/agents/form_filling.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc7a0f79f45e37e49d7bb9ded07f15216cea6e3 --- /dev/null +++ b/templates/components/agents/python/form_filling/app/agents/form_filling.py @@ -0,0 +1,397 @@ +import os +import uuid +from enum import Enum +from typing import AsyncGenerator, List, Optional + +from app.engine.index import get_index +from app.engine.tools import ToolFactory +from app.engine.tools.form_filling import CellValue, MissingCell +from llama_index.core import Settings +from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.indices.vector_store import VectorStoreIndex +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection +from llama_index.core.tools.types import ToolOutput +from llama_index.core.workflow import ( + Context, + Event, + StartEvent, + StopEvent, + Workflow, + step, +) +from pydantic import Field + + +def create_workflow( + chat_history: Optional[List[ChatMessage]] = None, **kwargs +) -> Workflow: + index: VectorStoreIndex = get_index() + if index is None: + query_engine_tool = None + else: + top_k = int(os.getenv("TOP_K", 10)) + query_engine = index.as_query_engine(similarity_top_k=top_k) + query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) + + configured_tools = ToolFactory.from_env(map_result=True) + extractor_tool = configured_tools.get("extract_questions") + filling_tool = configured_tools.get("fill_form") + + if extractor_tool is None or filling_tool is None: + raise ValueError("Extractor or filling tool is not found!") + + workflow = FormFillingWorkflow( + query_engine_tool=query_engine_tool, + extractor_tool=extractor_tool, + filling_tool=filling_tool, + chat_history=chat_history, + ) + + return workflow + + +class InputEvent(Event): + input: List[ChatMessage] + response: bool = False + + +class ExtractMissingCellsEvent(Event): + tool_call: ToolSelection + + +class FindAnswersEvent(Event): + missing_cells: list[MissingCell] + + +class FillEvent(Event): + tool_call: ToolSelection + + +class AgentRunEventType(Enum): + TEXT = "text" + PROGRESS = "progress" + + +class AgentRunEvent(Event): + name: str + msg: str + event_type: AgentRunEventType = Field(default=AgentRunEventType.TEXT) + data: Optional[dict] = None + + def to_response(self) -> dict: + return { + "type": "agent", + "data": { + "agent": self.name, + "type": self.event_type.value, + "text": self.msg, + "data": self.data, + }, + } + + +class FormFillingWorkflow(Workflow): + """ + A predefined workflow for filling missing cells in a CSV file. + Required tools: + - query_engine: A query engine to query for the answers to the questions. + - extract_question: Extract missing cells in a CSV file and generate questions to fill them. + - answer_question: Query for the answers to the questions. + + Flow: + 1. Extract missing cells in a CSV file and generate questions to fill them. + 2. Query for the answers to the questions. + 3. Fill the missing cells with the answers. + """ + + _default_system_prompt = """ + You are a helpful assistant who helps fill missing cells in a CSV file. + Only use provided data, never make up any information yourself. Fill N/A if the answer is not found. 
+ """ + + def __init__( + self, + query_engine_tool: QueryEngineTool, + extractor_tool: FunctionTool, + filling_tool: FunctionTool, + llm: Optional[FunctionCallingLLM] = None, + timeout: int = 360, + chat_history: Optional[List[ChatMessage]] = None, + system_prompt: Optional[str] = None, + ): + super().__init__(timeout=timeout) + self.system_prompt = system_prompt or self._default_system_prompt + self.chat_history = chat_history or [] + self.query_engine_tool = query_engine_tool + self.extractor_tool = extractor_tool + self.filling_tool = filling_tool + self.llm: FunctionCallingLLM = llm or Settings.llm + if not isinstance(self.llm, FunctionCallingLLM): + raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.") + self.memory = ChatMemoryBuffer.from_defaults( + llm=self.llm, chat_history=self.chat_history + ) + + @step() + async def start(self, ctx: Context, ev: StartEvent) -> InputEvent: + ctx.data["streaming"] = getattr(ev, "streaming", False) + ctx.data["input"] = ev.input + + if self.system_prompt: + system_msg = ChatMessage( + role=MessageRole.SYSTEM, content=self.system_prompt + ) + self.memory.put(system_msg) + + user_input = ev.input + user_msg = ChatMessage(role=MessageRole.USER, content=user_input) + self.memory.put(user_msg) + + chat_history = self.memory.get() + return InputEvent(input=chat_history) + + @step(pass_context=True) + async def handle_llm_input( # type: ignore + self, + ctx: Context, + ev: InputEvent, + ) -> ExtractMissingCellsEvent | FillEvent | StopEvent: + """ + Handle an LLM input and decide the next step. + """ + chat_history: list[ChatMessage] = ev.input + + generator = self._tool_call_generator(chat_history) + + # Check for immediate tool call + is_tool_call = await generator.__anext__() + if is_tool_call: + full_response = await generator.__anext__() + tool_calls = self.llm.get_tool_calls_from_response(full_response) # type: ignore + for tool_call in tool_calls: + if tool_call.tool_name == self.extractor_tool.metadata.get_name(): + ctx.send_event(ExtractMissingCellsEvent(tool_call=tool_call)) + elif tool_call.tool_name == self.filling_tool.metadata.get_name(): + ctx.send_event(FillEvent(tool_call=tool_call)) + else: + # If no tool call, return the generator + return StopEvent(result=generator) + + @step() + async def extract_missing_cells( + self, ctx: Context, ev: ExtractMissingCellsEvent + ) -> InputEvent | FindAnswersEvent: + """ + Extract missing cells in a CSV file and generate questions to fill them. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Extractor", + msg="Extracting missing cells", + ) + ) + # Call the extract questions tool + response = self._call_tool( + ctx, + agent_name="Extractor", + tool=self.extractor_tool, + tool_selection=ev.tool_call, + ) + if response.is_error: + return InputEvent(input=self.memory.get()) + + missing_cells = response.raw_output.get("missing_cells", []) + message = ChatMessage( + role=MessageRole.TOOL, + content=str(missing_cells), + additional_kwargs={ + "tool_call_id": ev.tool_call.tool_id, + "name": ev.tool_call.tool_name, + }, + ) + self.memory.put(message) + + if self.query_engine_tool is None: + # Fallback to input that query engine tool is not found so that cannot answer questions + self.memory.put( + ChatMessage( + role=MessageRole.ASSISTANT, + content="Extracted missing cells but query engine tool is not found so cannot answer questions. 
Ask user to upload file or connect to a knowledge base.", + ) + ) + return InputEvent(input=self.memory.get()) + + # Forward missing cells information to find answers step + return FindAnswersEvent(missing_cells=missing_cells) + + @step() + async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent: + """ + Call answer questions tool to query for the answers to the questions. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Researcher", + msg="Finding answers for missing cells", + ) + ) + missing_cells = ev.missing_cells + # If missing cells information is not found, fallback to other tools + # It means that the extractor tool has not been called yet + # Fallback to input + if missing_cells is None: + ctx.write_event_to_stream( + AgentRunEvent( + name="Researcher", + msg="Error: Missing cells information not found. Fallback to other tools.", + ) + ) + message = ChatMessage( + role=MessageRole.TOOL, + content="Error: Missing cells information not found.", + additional_kwargs={ + "tool_call_id": ev.tool_call.tool_id, + "name": ev.tool_call.tool_name, + }, + ) + self.memory.put(message) + return InputEvent(input=self.memory.get()) + + cell_values: list[CellValue] = [] + # Iterate over missing cells and query for the answers + # and stream the progress + progress_id = str(uuid.uuid4()) + total_steps = len(missing_cells) + for i, cell in enumerate(missing_cells): + if cell.question_to_answer is None: + continue + ctx.write_event_to_stream( + AgentRunEvent( + name="Researcher", + msg=f"Querying for: {cell.question_to_answer}", + event_type=AgentRunEventType.PROGRESS, + data={ + "id": progress_id, + "total": total_steps, + "current": i, + }, + ) + ) + # Call query engine tool directly + answer = await self.query_engine_tool.acall(query=cell.question_to_answer) + cell_values.append( + CellValue( + row_index=cell.row_index, + column_index=cell.column_index, + value=str(answer), + ) + ) + self.memory.put( + ChatMessage( + role=MessageRole.ASSISTANT, + content=str(cell_values), + ) + ) + return InputEvent(input=self.memory.get()) + + @step() + async def fill_cells(self, ctx: Context, ev: FillEvent) -> InputEvent: + """ + Call fill cells tool to fill the missing cells with the answers. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Processor", + msg="Filling missing cells", + ) + ) + # Call the fill cells tool + result = self._call_tool( + ctx, + agent_name="Processor", + tool=self.filling_tool, + tool_selection=ev.tool_call, + ) + if result.is_error: + return InputEvent(input=self.memory.get()) + + message = ChatMessage( + role=MessageRole.TOOL, + content=str(result.raw_output), + additional_kwargs={ + "tool_call_id": ev.tool_call.tool_id, + "name": ev.tool_call.tool_name, + }, + ) + self.memory.put(message) + return InputEvent(input=self.memory.get(), response=True) + + async def _tool_call_generator( + self, chat_history: list[ChatMessage] + ) -> AsyncGenerator[ChatMessage | bool, None]: + response_stream = await self.llm.astream_chat_with_tools( + [self.extractor_tool, self.filling_tool], + chat_history=chat_history, + ) + + full_response = None + yielded_indicator = False + async for chunk in response_stream: + if "tool_calls" not in chunk.message.additional_kwargs: + # Yield a boolean to indicate whether the response is a tool call + if not yielded_indicator: + yield False + yielded_indicator = True + + # if not a tool call, yield the chunks! 
+ yield chunk + elif not yielded_indicator: + # Yield the indicator for a tool call + yield True + yielded_indicator = True + + full_response = chunk + + # Write the full response to memory and yield it + if full_response: + self.memory.put(full_response.message) + yield full_response + + def _call_tool( + self, + ctx: Context, + agent_name: str, + tool: FunctionTool, + tool_selection: ToolSelection, + ) -> ToolOutput: + """ + Safely call a tool and handle errors. + """ + try: + response: ToolOutput = tool.call(**tool_selection.tool_kwargs) + return response + except Exception as e: + ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=f"Error: {str(e)}", + ) + ) + message = ChatMessage( + role=MessageRole.TOOL, + content=f"Error: {str(e)}", + additional_kwargs={ + "tool_call_id": tool_selection.tool_id, + "name": tool.metadata.get_name(), + }, + ) + self.memory.put(message) + return ToolOutput( + content=f"Error: {str(e)}", + tool_name=tool.metadata.get_name(), + raw_input=tool_selection.tool_kwargs, + raw_output=None, + is_error=True, + ) diff --git a/templates/components/agents/python/form_filling/app/engine/engine.py b/templates/components/agents/python/form_filling/app/engine/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..68dbb6ce68b45207c346e46da2a86c0b6debc9fc --- /dev/null +++ b/templates/components/agents/python/form_filling/app/engine/engine.py @@ -0,0 +1,11 @@ +from typing import List, Optional + +from app.agents.form_filling import create_workflow +from llama_index.core.chat_engine.types import ChatMessage +from llama_index.core.workflow import Workflow + + +def get_chat_engine( + chat_history: Optional[List[ChatMessage]] = None, **kwargs +) -> Workflow: + return create_workflow(chat_history=chat_history, **kwargs) diff --git a/templates/components/agents/python/form_filling/sec_10k_template.csv b/templates/components/agents/python/form_filling/sec_10k_template.csv new file mode 100644 index 0000000000000000000000000000000000000000..ae920ab8ac17b81ca0836e4a31fa28752c82c449 --- /dev/null +++ b/templates/components/agents/python/form_filling/sec_10k_template.csv @@ -0,0 +1,17 @@ +Parameter,2023 Apple (AAPL),2023 Tesla (TSLA) +Revenue,, +Net Income,, +Earnings Per Share (EPS),, +Debt-to-Equity Ratio,, +Current Ratio,, +Gross Margin,, +Operating Margin,, +Net Profit Margin,, +Inventory Turnover,, +Accounts Receivable Turnover,, +Capital Expenditure,, +Research and Development Expense,, +Market Cap,, +Price to Earnings Ratio,, +Dividend Yield,, +Year-over-Year Growth Rate,, diff --git a/templates/components/engines/python/agent/tools/__init__.py b/templates/components/engines/python/agent/tools/__init__.py index f9ede661d3268d304334b08078e7c3c072ec1d51..dfc02edeacb39de52683bfdb5e377aa76df83e74 100644 --- a/templates/components/engines/python/agent/tools/__init__.py +++ b/templates/components/engines/python/agent/tools/__init__.py @@ -56,7 +56,7 @@ class ToolFactory: A dictionary of tool names to lists of FunctionTools if map_result is True, otherwise a list of FunctionTools. 
""" - tools: Union[Dict[str, List[FunctionTool]], List[FunctionTool]] = ( + tools: Union[Dict[str, FunctionTool], List[FunctionTool]] = ( {} if map_result else [] ) @@ -69,7 +69,9 @@ class ToolFactory: tool_type, tool_name, config ) if map_result: - tools[tool_name] = loaded_tools # type: ignore + tools.update( # type: ignore + {tool.metadata.name: tool for tool in loaded_tools} + ) else: tools.extend(loaded_tools) # type: ignore diff --git a/templates/components/engines/python/agent/tools/form_filling.py b/templates/components/engines/python/agent/tools/form_filling.py new file mode 100644 index 0000000000000000000000000000000000000000..488faaa81f764a065517e473de70504e7ca60a72 --- /dev/null +++ b/templates/components/engines/python/agent/tools/form_filling.py @@ -0,0 +1,224 @@ +import logging +import os +import uuid +from textwrap import dedent +from typing import Optional + +import pandas as pd +from app.services.file import FileService +from llama_index.core import Settings +from llama_index.core.prompts import PromptTemplate +from llama_index.core.tools import FunctionTool +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class MissingCell(BaseModel): + """ + A missing cell in a table. + """ + + row_index: int = Field(description="The index of the row of the missing cell") + column_index: int = Field(description="The index of the column of the missing cell") + question_to_answer: str = Field( + description="The question to answer to fill the missing cell" + ) + + +class MissingCells(BaseModel): + """ + A list of missing cells. + """ + + missing_cells: list[MissingCell] = Field(description="The missing cells") + + +class CellValue(BaseModel): + row_index: int = Field(description="The row index of the cell") + column_index: int = Field(description="The column index of the cell") + value: str = Field( + description="The value of the cell. Should be a concise value (numerical value or specific value)" + ) + + +class FormFillingTool: + """ + Fill out missing cells in a CSV file using information from the knowledge base. + """ + + save_dir: str = os.path.join("output", "tools") + + # Default prompt for extracting questions + # Replace the default prompt with a custom prompt by setting the EXTRACT_QUESTIONS_PROMPT environment variable. + _default_extract_questions_prompt = dedent( + """ + You are a data analyst. You are given a table with missing cells. + Your task is to identify the missing cells and the questions needed to fill them. + IMPORTANT: Column indices should be 0-based, where the first data column is index 1 + (index 0 is typically the row names/index column). + + # Instructions: + - Understand the entire content of the table and the topics of the table. + - Identify the missing cells and the meaning of the data in the cells. + - For each missing cell, provide the row index and the correct column index (remember: first data column is 1). + - For each missing cell, provide the question needed to fill the cell (it's important to provide the question that is relevant to the topic of the table). + - Since the cell's value should be concise, the question should request a numerical answer or a specific value. + + # Example: + # | | Name | Age | City | + # |----|------|-----|------| + # | 0 | John | | Paris| + # | 1 | Mary | | | + # | 2 | | 30 | | + # + # Your thoughts: + # - The table is about people's names, ages, and cities. + # - Row: 1, Column: 1 (Age column), Question: "How old is Mary? Please provide only the numerical answer." 
+ # - Row: 1, Column: 2 (City column), Question: "In which city does Mary live? Please provide only the city name." + + + Please provide your answer in the requested format. + # Here is your task: + + - Table content: + {table_content} + + - Your answer: + """ + ) + + def extract_questions( + self, + file_path: Optional[str] = None, + file_content: Optional[str] = None, + ) -> dict: + """ + Use this tool to extract missing cells in a CSV file and generate questions to fill them. + Pass either the path to the CSV file or the content of the CSV file. + + Args: + file_path (Optional[str]): The local file path to the CSV file to extract missing cells from (Don't pass a sandbox path). + file_content (Optional[str]): The content of the CSV file to extract missing cells from. + + Returns: + dict: A dictionary containing the missing cells and their corresponding questions. + """ + extract_questions_prompt = os.getenv( + "EXTRACT_QUESTIONS_PROMPT", self._default_extract_questions_prompt + ) + if file_path is None and file_content is None: + raise ValueError("Either `file_path` or `file_content` must be provided") + + table_content = None + + if file_path: + file_name, file_extension = self._get_file_name_and_extension( + file_path, file_content + ) + + try: + df = pd.read_csv(file_path) + except FileNotFoundError as e: + return { + "error": str(e), + "message": "Please check and update the file path and ensure it's a local path - not a sandbox path.", + } + + table_content = df.to_markdown() + if table_content is None: + raise ValueError("Could not convert the table to markdown") + if file_content: + table_content = file_content + + if table_content is None: + raise ValueError("Table content not found") + + response: MissingCells = Settings.llm.structured_predict( + output_cls=MissingCells, + prompt=PromptTemplate(extract_questions_prompt), + table_content=table_content, + ) + return response.model_dump() + + def fill_form( + self, + cell_values: list[CellValue], + file_path: Optional[str] = None, + file_content: Optional[str] = None, + ) -> dict: + """ + Use this tool to fill cell values into a CSV file. + Requires cell values to be used for filling out, as well as either the path to the CSV file or the content of the CSV file. + + Args: + cell_values (list[CellValue]): The cell values used to fill out the CSV file (call `extract_questions` and query engine to construct the cell values). + file_path (Optional[str]): The local file path to the CSV file that should be filled out (not as sandbox path). + file_content (Optional[str]): The content of the CSV file that should be filled out. + + Returns: + dict: A dictionary containing the content and metadata of the filled-out file. 
+ """ + file_name, file_extension = self._get_file_name_and_extension( + file_path, file_content + ) + df = pd.read_csv(file_path) + + # Fill the dataframe with the cell values + filled_df = df.copy() + for cell_value in cell_values: + if not isinstance(cell_value, CellValue): + cell_value = CellValue(**cell_value) + filled_df.iloc[cell_value.row_index, cell_value.column_index] = ( + cell_value.value + ) + + # Save the filled table to a new CSV file + csv_content: str = filled_df.to_csv(index=False) + file_metadata = FileService.save_file( + content=csv_content, + file_name=f"{file_name}_filled.csv", + save_dir=self.save_dir, + ) + + new_content: str = filled_df.to_markdown() + result = { + "filled_content": new_content, + "filled_file": file_metadata, + } + return result + + def _get_file_name_and_extension( + self, file_path: Optional[str], file_content: Optional[str] + ) -> tuple[str, str]: + if file_path is None and file_content is None: + raise ValueError("Either `file_path` or `file_content` must be provided") + + if file_path is None: + file_name = str(uuid.uuid4()) + file_extension = ".csv" + else: + file_name, file_extension = os.path.splitext(file_path) + if file_extension != ".csv": + raise ValueError("Form filling is only supported for CSV files") + + return file_name, file_extension + + def _save_output(self, file_name: str, output: str) -> dict: + """ + Save the output to a file. + """ + file_metadata = FileService.save_file( + content=output, + file_name=file_name, + save_dir=self.save_dir, + ) + return file_metadata.model_dump() + + +def get_tools(**kwargs): + tool = FormFillingTool() + return [ + FunctionTool.from_defaults(tool.extract_questions), + FunctionTool.from_defaults(tool.fill_form), + ] diff --git a/templates/components/multiagent/python/app/api/routers/vercel_response.py b/templates/components/multiagent/python/app/api/routers/vercel_response.py index 2c23b6a758bb0ced0b6f308bd83cb9b0c497c310..4298553243168d806ae52bb6a3ef7662eca3925f 100644 --- a/templates/components/multiagent/python/app/api/routers/vercel_response.py +++ b/templates/components/multiagent/python/app/api/routers/vercel_response.py @@ -1,12 +1,11 @@ import asyncio import json import logging -from typing import AsyncGenerator, List +from typing import AsyncGenerator, Awaitable, List from aiostream import stream from app.api.routers.models import ChatData, Message from app.api.services.suggestion import NextQuestionSuggestion -from app.workflows.single import AgentRunEvent, AgentRunResult from fastapi import Request from fastapi.responses import StreamingResponse @@ -55,8 +54,8 @@ class VercelStreamResponse(StreamingResponse): self, request: Request, chat_data: ChatData, - event_handler: AgentRunResult | AsyncGenerator, - events: AsyncGenerator[AgentRunEvent, None], + event_handler: Awaitable, + events: AsyncGenerator, verbose: bool = True, ): # Yield the text response @@ -64,15 +63,17 @@ class VercelStreamResponse(StreamingResponse): result = await event_handler final_response = "" - if isinstance(result, AgentRunResult): - for token in result.response.message.content: - final_response += token - yield self.convert_text(token) - if isinstance(result, AsyncGenerator): async for token in result: - final_response += token.delta + final_response += str(token.delta) yield self.convert_text(token.delta) + else: + if hasattr(result, "response"): + content = result.response.message.content + if content: + for token in content: + final_response += str(token) + yield self.convert_text(token) # Generate next 
questions if next question prompt is configured
             question_data = await self._generate_next_questions(
@@ -86,7 +87,7 @@
         # Yield the events from the event handler
         async def _event_generator():
             async for event in events:
-                event_response = self._event_to_response(event)
+                event_response = event.to_response()
                 if verbose:
                     logger.debug(event_response)
                 if event_response is not None:
@@ -95,13 +96,6 @@
         combine = stream.merge(_chat_response_generator(), _event_generator())
         return combine
 
-    @staticmethod
-    def _event_to_response(event: AgentRunEvent) -> dict:
-        return {
-            "type": "agent",
-            "data": {"agent": event.name, "text": event.msg},
-        }
-
     @classmethod
     def convert_text(cls, token: str):
         # Escape newlines and double quotes to avoid breaking the stream
diff --git a/templates/components/multiagent/python/app/workflows/single.py b/templates/components/multiagent/python/app/workflows/single.py
index a598bdf65306788ffa55b9af72fa41150769f687..401fb80061ac2f72bd4204786e5524e56ee87280 100644
--- a/templates/components/multiagent/python/app/workflows/single.py
+++ b/templates/components/multiagent/python/app/workflows/single.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from enum import Enum
 from typing import Any, AsyncGenerator, List, Optional
 
 from llama_index.core.llms import ChatMessage, ChatResponse
@@ -15,7 +16,7 @@ from llama_index.core.workflow import (
     Workflow,
     step,
 )
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 class InputEvent(Event):
@@ -26,17 +27,27 @@ class ToolCallEvent(Event):
     tool_calls: list[ToolSelection]
 
 
-class AgentRunEvent(Event):
-    name: str
-    _msg: str
+class AgentRunEventType(Enum):
+    TEXT = "text"
+    PROGRESS = "progress"
 
-    @property
-    def msg(self):
-        return self._msg
 
-    @msg.setter
-    def msg(self, value):
-        self._msg = value
+class AgentRunEvent(Event):
+    name: str
+    msg: str
+    event_type: AgentRunEventType = Field(default=AgentRunEventType.TEXT)
+    data: Optional[dict] = None
+
+    def to_response(self) -> dict:
+        return {
+            "type": "agent",
+            "data": {
+                "agent": self.name,
+                "type": self.event_type.value,
+                "text": self.msg,
+                "data": self.data,
+            },
+        }
 
 
 class AgentRunResult(BaseModel):
diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py
index a8bb991eb746577a14014b2f3d1e70df0f8b6df9..31f2fa46fef212080c059ac25ed65b064167c178 100644
--- a/templates/types/streaming/fastapi/app/api/routers/models.py
+++ b/templates/types/streaming/fastapi/app/api/routers/models.py
@@ -60,9 +60,11 @@ class AnnotationFileData(BaseModel):
         # Include document IDs if it's available
         if file.refs is not None:
             default_content += f"Document IDs: {file.refs}\n"
-        # Include sandbox file path
+        # file path
         sandbox_file_path = f"/tmp/{file.name}"
+        local_file_path = f"output/uploaded/{file.name}"
         default_content += f"Sandbox file path (instruction: only use sandbox path for artifact or code interpreter tool): {sandbox_file_path}\n"
+        default_content += f"Local file path (instruction: Use for local tools: form filling, extractor): {local_file_path}\n"
         return default_content
 
     def to_llm_content(self) -> Optional[str]:
@@ -128,24 +130,29 @@ class ChatData(BaseModel):
 
     def get_last_message_content(self) -> str:
         """
-        Get the content of the last message along with the data content if available.
- Fallback to use data content from previous messages + Get the content of the last message along with the data content from all user messages """ if len(self.messages) == 0: raise ValueError("There is not any message in the chat") + last_message = self.messages[-1] message_content = last_message.content - for message in reversed(self.messages): + + # Collect annotation contents from all user messages + all_annotation_contents: List[str] = [] + for message in self.messages: if message.role == MessageRole.USER and message.annotations is not None: annotation_contents = filter( None, [annotation.to_content() for annotation in message.annotations], ) - if not annotation_contents: - continue - annotation_text = "\n".join(annotation_contents) - message_content = f"{message_content}\n{annotation_text}" - break + all_annotation_contents.extend(annotation_contents) + + # Add all annotation contents if any exist + if len(all_annotation_contents) > 0: + annotation_text = "\n".join(all_annotation_contents) + message_content = f"{message_content}\n{annotation_text}" + return message_content def _get_agent_messages(self, max_messages: int = 10) -> List[str]: diff --git a/templates/types/streaming/fastapi/app/services/file.py b/templates/types/streaming/fastapi/app/services/file.py index a551ea5f1a79736c5c346ee8c2269d59dcb756dd..b9de026e2c3ea6e75c88af70c911539eec722981 100644 --- a/templates/types/streaming/fastapi/app/services/file.py +++ b/templates/types/streaming/fastapi/app/services/file.py @@ -6,7 +6,7 @@ import re import uuid from io import BytesIO from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple from llama_index.core import VectorStoreIndex from llama_index.core.ingestion import IngestionPipeline @@ -14,7 +14,6 @@ from llama_index.core.readers.file.base import ( _try_loading_included_file_formats as get_file_loaders_map, ) from llama_index.core.schema import Document -from llama_index.core.tools.function_tool import FunctionTool from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex from llama_index.readers.file import FlatReader from pydantic import BaseModel, Field @@ -78,10 +77,8 @@ class FileService: save_dir=PRIVATE_STORE_PATH, ) - tools = _get_available_tools() - code_executor_tools = ["interpreter", "artifact"] - # If the file is CSV and there is a code executor tool, we don't need to index. 
- if extension == "csv" and any(tool in tools for tool in code_executor_tools): + # Don't index csv files (they are handled by tools) + if extension == "csv": return document_file else: # Insert the file into the index and update document ids to the file metadata @@ -283,18 +280,3 @@ def _default_file_loaders_map(): default_loaders[".txt"] = FlatReader default_loaders[".csv"] = FlatReader return default_loaders - - -def _get_available_tools() -> Dict[str, List[FunctionTool]]: - try: - from app.engine.tools import ToolFactory # type: ignore - except ImportError: - logger.warning("ToolFactory not found, no tools will be available") - return {} - - try: - tools = ToolFactory.from_env(map_result=True) - return tools # type: ignore - except Exception as e: - logger.error(f"Error loading tools from environment: {str(e)}") - raise ValueError(f"Failed to get available tools: {str(e)}") from e diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx index 33b1b92d90ae83d27df195b8dc821edff14f14e2..0e5c318b31f29a69fac2ce62ca373c4b4c6cfc31 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx @@ -122,13 +122,19 @@ export default function ChatInput( config={{ allowedExtensions: ALLOWED_EXTENSIONS, disabled: props.isLoading, + multiple: true, }} /> {process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && props.setRequestData && ( <LlamaCloudSelector setRequestData={props.setRequestData} /> )} - <Button type="submit" disabled={props.isLoading || !props.input.trim()}> + <Button + type="submit" + disabled={ + props.isLoading || (!props.input.trim() && files.length === 0) + } + > Send message </Button> </div> diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-agent-events.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-agent-events.tsx index 8fea31dfab2aafc37ad5b16952f8bd4227953e64..a385754bc2bb2d5a38dd8dae5b18862210eeca16 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-agent-events.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message/chat-agent-events.tsx @@ -9,7 +9,8 @@ import { DrawerTitle, DrawerTrigger, } from "../../drawer"; -import { AgentEventData } from "../index"; +import { Progress } from "../../progress"; +import { AgentEventData, ProgressData } from "../index"; import Markdown from "./markdown"; const AgentIcons: Record<string, LucideIcon> = { @@ -20,10 +21,19 @@ const AgentIcons: Record<string, LucideIcon> = { publisher: icons.BookCheck, }; +type StepText = { + text: string; +}; + +type StepProgress = { + text: string; + progress: ProgressData; +}; + type MergedEvent = { agent: string; - texts: string[]; icon: LucideIcon; + steps: Array<StepText | StepProgress>; }; export function ChatAgentEvents({ @@ -52,6 +62,53 @@ export function ChatAgentEvents({ const MAX_TEXT_LENGTH = 150; +function TextContent({ agent, step }: { agent: string; step: StepText }) { + const { displayText, showMore } = useMemo( + () => ({ + displayText: step.text.slice(0, MAX_TEXT_LENGTH), + showMore: step.text.length > MAX_TEXT_LENGTH, + }), + [step.text], + ); + + return ( + <> + <div className="whitespace-break-spaces"> + {!showMore && <span>{step.text}</span>} + {showMore && ( + <div> + <span>{displayText}...</span> + <AgentEventDialog content={step.text} title={`Agent "${agent}"`}> + 
<span className="font-semibold underline cursor-pointer ml-2"> + Show more + </span> + </AgentEventDialog> + </div> + )} + </div> + </> + ); +} + +function ProgressContent({ step }: { step: StepProgress }) { + const progressValue = + step.progress.total !== 0 + ? Math.round(((step.progress.current + 1) / step.progress.total) * 100) + : 0; + + return ( + <div className="space-y-2 mt-2"> + {step.text && ( + <p className="text-sm text-muted-foreground">{step.text}</p> + )} + <Progress value={progressValue} className="w-full h-2" /> + <p className="text-sm text-muted-foreground"> + Processing {step.progress.current + 1} of {step.progress.total} steps... + </p> + </div> + ); +} + function AgentEventContent({ event, isLast, @@ -61,8 +118,19 @@ function AgentEventContent({ isLast: boolean; isFinished: boolean; }) { - const { agent, texts } = event; + const { agent, steps } = event; const AgentIcon = event.icon; + const textSteps = steps.filter((step) => !("progress" in step)); + const progressSteps = steps.filter( + (step) => "progress" in step, + ) as StepProgress[]; + // We only show progress at the last step + // TODO: once we support steps that work in parallel, we need to update this + const lastProgressStep = + progressSteps.length > 0 + ? progressSteps[progressSteps.length - 1] + : undefined; + return ( <div className="flex gap-4 border-b pb-4 items-center fadein-agent"> <div className="w-[100px] flex flex-col items-center gap-2"> @@ -79,26 +147,20 @@ function AgentEventContent({ </div> <span className="font-bold">{agent}</span> </div> - <ul className="flex-1 list-decimal space-y-2"> - {texts.map((text, index) => ( - <li className="whitespace-break-spaces" key={index}> - {text.length <= MAX_TEXT_LENGTH && <span>{text}</span>} - {text.length > MAX_TEXT_LENGTH && ( - <div> - <span>{text.slice(0, MAX_TEXT_LENGTH)}...</span> - <AgentEventDialog - content={text} - title={`Agent "${agent}" - Step: ${index + 1}`} - > - <span className="font-semibold underline cursor-pointer ml-2"> - Show more - </span> - </AgentEventDialog> - </div> - )} - </li> - ))} - </ul> + {textSteps.length > 0 && ( + <div className="flex-1"> + <ul className="list-decimal space-y-2"> + {textSteps.map((step, index) => ( + <li key={index}> + <TextContent agent={agent} step={step} /> + </li> + ))} + </ul> + {lastProgressStep && !isFinished && ( + <ProgressContent step={lastProgressStep} /> + )} + </div> + )} </div> ); } @@ -136,15 +198,22 @@ function mergeAdjacentEvents(events: AgentEventData[]): MergedEvent[] { for (const event of events) { const lastMergedEvent = mergedEvents[mergedEvents.length - 1]; + const eventStep: StepText | StepProgress = event.data + ? ({ + text: event.text, + progress: event.data, + } as StepProgress) + : ({ + text: event.text, + } as StepText); + if (lastMergedEvent && lastMergedEvent.agent === event.agent) { - // If the last event in mergedEvents has the same non-null agent, add the title to it - lastMergedEvent.texts.push(event.text); + lastMergedEvent.steps.push(eventStep); } else { - // Otherwise, create a new merged event mergedEvents.push({ agent: event.agent, - texts: [event.text], - icon: AgentIcons[event.agent] ?? icons.Bot, + steps: [eventStep], + icon: AgentIcons[event.agent.toLowerCase()] ?? 
icons.Bot, }); } } diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts index b88dfd423b52f293ecbb6c119717f7f781000173..e78a8c27745e2d3077afe7f3f676e59f9e7285c6 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/index.ts +++ b/templates/types/streaming/nextjs/app/components/ui/chat/index.ts @@ -56,9 +56,17 @@ export type EventData = { title: string; }; +export type ProgressData = { + id: string; + total: number; + current: number; +}; + export type AgentEventData = { agent: string; text: string; + type: "text" | "progress"; + data?: ProgressData; }; export type ToolData = { diff --git a/templates/types/streaming/nextjs/app/components/ui/file-uploader.tsx b/templates/types/streaming/nextjs/app/components/ui/file-uploader.tsx index e42a267d18cbe76391d1decad0acca8fdf4dc295..15f9e4035df338c1b799be003e3977dc4a593508 100644 --- a/templates/types/streaming/nextjs/app/components/ui/file-uploader.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/file-uploader.tsx @@ -12,6 +12,7 @@ export interface FileUploaderProps { allowedExtensions?: string[]; checkExtension?: (extension: string) => string | null; disabled: boolean; + multiple?: boolean; }; onFileUpload: (file: File) => Promise<void>; onFileError?: (errMsg: string) => void; @@ -26,6 +27,7 @@ export default function FileUploader({ onFileError, }: FileUploaderProps) { const [uploading, setUploading] = useState(false); + const [remainingFiles, setRemainingFiles] = useState<number>(0); const inputId = config?.inputId || DEFAULT_INPUT_ID; const fileSizeLimit = config?.fileSizeLimit || DEFAULT_FILE_SIZE_LIMIT; @@ -50,30 +52,51 @@ export default function FileUploader({ }; const onFileChange = async (e: ChangeEvent<HTMLInputElement>) => { - const file = e.target.files?.[0]; - if (!file) return; + const files = Array.from(e.target.files || []); + if (!files.length) return; setUploading(true); - await handleUpload(file); + + await handleUpload(files); + resetInput(); setUploading(false); }; - const handleUpload = async (file: File) => { + const handleUpload = async (files: File[]) => { const onFileUploadError = onFileError || window.alert; - const fileExtension = file.name.split(".").pop() || ""; - const extensionFileError = checkExtension(fileExtension); - if (extensionFileError) { - return onFileUploadError(extensionFileError); + // Validate files + // If multiple files with image or multiple images + if ( + files.length > 1 && + files.some((file) => file.type.startsWith("image/")) + ) { + onFileUploadError("Multiple files with image are not supported"); + return; } - if (isFileSizeExceeded(file)) { - return onFileUploadError( - `File size exceeded. Limit is ${fileSizeLimit / 1024 / 1024} MB`, - ); + for (const file of files) { + const fileExtension = file.name.split(".").pop() || ""; + const extensionFileError = checkExtension(fileExtension); + if (extensionFileError) { + onFileUploadError(extensionFileError); + return; + } + + if (isFileSizeExceeded(file)) { + onFileUploadError( + `File size exceeded. 
Limit is ${fileSizeLimit / 1024 / 1024} MB`, + ); + return; + } } - await onFileUpload(file); + setRemainingFiles(files.length); + for (const file of files) { + await onFileUpload(file); + setRemainingFiles((prev) => prev - 1); + } + setRemainingFiles(0); }; return ( @@ -85,17 +108,25 @@ export default function FileUploader({ onChange={onFileChange} accept={allowedExtensions?.join(",")} disabled={config?.disabled || uploading} + multiple={config?.multiple} /> <label htmlFor={inputId} className={cn( buttonVariants({ variant: "secondary", size: "icon" }), - "cursor-pointer", + "cursor-pointer relative", uploading && "opacity-50", )} > {uploading ? ( - <Loader2 className="h-4 w-4 animate-spin" /> + <div className="relative flex items-center justify-center h-full w-full"> + <Loader2 className="h-6 w-6 animate-spin absolute" /> + {remainingFiles > 0 && ( + <span className="text-xs absolute inset-0 flex items-center justify-center"> + {remainingFiles} + </span> + )} + </div> ) : ( <Paperclip className="-rotate-45 w-4 h-4" /> )} diff --git a/templates/types/streaming/nextjs/app/components/ui/progress.tsx b/templates/types/streaming/nextjs/app/components/ui/progress.tsx new file mode 100644 index 0000000000000000000000000000000000000000..d286f0f39870c4bf112ed4b1415298ec9f8272b5 --- /dev/null +++ b/templates/types/streaming/nextjs/app/components/ui/progress.tsx @@ -0,0 +1,27 @@ +"use client"; + +import * as ProgressPrimitive from "@radix-ui/react-progress"; +import * as React from "react"; +import { cn } from "./lib/utils"; + +const Progress = React.forwardRef< + React.ElementRef<typeof ProgressPrimitive.Root>, + React.ComponentPropsWithoutRef<typeof ProgressPrimitive.Root> +>(({ className, value, ...props }, ref) => ( + <ProgressPrimitive.Root + ref={ref} + className={cn( + "relative h-4 w-full overflow-hidden rounded-full bg-secondary", + className, + )} + {...props} + > + <ProgressPrimitive.Indicator + className="h-full w-full flex-1 bg-primary transition-all" + style={{ transform: `translateX(-${100 - (value || 0)}%)` }} + /> + </ProgressPrimitive.Root> +)); +Progress.displayName = ProgressPrimitive.Root.displayName; + +export { Progress }; diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 090cfc31cd015d13eb1f07ddceac088bd4ad6b82..d54224829d3b496e27ad771780cc4be3cf2ae060 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -15,6 +15,7 @@ "@llamaindex/pdf-viewer": "^1.1.3", "@radix-ui/react-collapsible": "^1.0.3", "@radix-ui/react-hover-card": "^1.0.7", + "@radix-ui/react-progress": "^1.1.0", "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-tabs": "^1.1.0",
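Reviewer note on how the two halves of this change fit together: the payload produced by `AgentRunEvent.to_response()` on the Python side (in both `form_filling.py` and `single.py` above) is exactly what the frontend's `AgentEventData` and `ProgressData` types consume. A small illustrative sketch (the values are made up) of a progress event as it crosses the wire:

```python
# Illustrative only: the event shape consumed by chat-agent-events.tsx.
from app.workflows.single import AgentRunEvent, AgentRunEventType

event = AgentRunEvent(
    name="Researcher",
    msg="Querying for: What was Apple's revenue in 2023?",
    event_type=AgentRunEventType.PROGRESS,
    data={"id": "progress-1", "total": 16, "current": 3},  # ProgressData shape
)

print(event.to_response())
# {
#     "type": "agent",
#     "data": {
#         "agent": "Researcher",
#         "type": "progress",
#         "text": "Querying for: What was Apple's revenue in 2023?",
#         "data": {"id": "progress-1", "total": 16, "current": 3},
#     },
# }
```

On the frontend, `mergeAdjacentEvents` groups these events by `agent` and renders only the latest progress step through the new `Progress` component, which is why the backend tags each update with the same `id` and a running `current`/`total` count.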