diff --git a/.changeset/light-parrots-work.md b/.changeset/light-parrots-work.md new file mode 100644 index 0000000000000000000000000000000000000000..4fae766ebd074bed228f8d7fb877f8250cd7442b --- /dev/null +++ b/.changeset/light-parrots-work.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Optimize generated workflow code for Python diff --git a/e2e/shared/multiagent_template.spec.ts b/e2e/shared/multiagent_template.spec.ts index f470b34925f674686aec2979b783e06d1e2d6a79..955e73453eac909fb2c296a240bd9c95bdb50d3c 100644 --- a/e2e/shared/multiagent_template.spec.ts +++ b/e2e/shared/multiagent_template.spec.ts @@ -18,7 +18,7 @@ const templateUI: TemplateUI = "shadcn"; const templatePostInstallAction: TemplatePostInstallAction = "runApp"; const appType: AppType = templateFramework === "nextjs" ? "" : "--frontend"; const userMessage = "Write a blog post about physical standards for letters"; -const templateAgents = ["financial_report", "blog"]; +const templateAgents = ["financial_report", "blog", "form_filling"]; for (const agents of templateAgents) { test.describe(`Test multiagent template ${agents} ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => { @@ -26,6 +26,10 @@ for (const agents of templateAgents) { process.platform !== "linux" || process.env.DATASOURCE === "--no-files", "The multiagent template currently only works with files. We also only run on Linux to speed up tests.", ); + test.skip( + agents === "form_filling" && templateFramework !== "fastapi", + "Form filling is currently only supported with FastAPI.", + ); let port: number; let externalPort: number; let cwd: string; @@ -68,6 +72,10 @@ for (const agents of templateAgents) { test("Frontend should be able to submit a message and receive the start of a streamed response", async ({ page, }) => { + test.skip( + agents === "financial_report" || agents === "form_filling", + "Skip chat tests for financial report and form filling.", + ); await page.goto(`http://localhost:${port}`); await page.fill("form textarea", userMessage); diff --git a/templates/components/agents/python/blog/README-template.md b/templates/components/agents/python/blog/README-template.md index 162de0c8a3a2e75dfe72eb67b679965cdf45f874..5d17a5ac883a01cd3d8d417e67eb3331596cbf9e 100644 --- a/templates/components/agents/python/blog/README-template.md +++ b/templates/components/agents/python/blog/README-template.md @@ -8,9 +8,9 @@ This example is using three agents to generate a blog post: There are three different methods how the agents can interact to reach their goal: -1. [Choreography](./app/examples/choreography.py) - the agents decide themselves to delegate a task to another agent -1. [Orchestrator](./app/examples/orchestrator.py) - a central orchestrator decides which agent should execute a task -1. [Explicit Workflow](./app/examples/workflow.py) - a pre-defined workflow specific for the task is used to execute the tasks +1. [Choreography](./app/agents/choreography.py) - the agents decide themselves to delegate a task to another agent +1. [Orchestrator](./app/agents/orchestrator.py) - a central orchestrator decides which agent should execute a task +1. 
[Explicit Workflow](./app/agents/workflow.py) - a pre-defined workflow specific for the task is used to execute the tasks ## Getting Started diff --git a/templates/components/agents/python/blog/app/workflows/__init__.py b/templates/components/agents/python/blog/app/workflows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61ab6589e33242a19f5734aac4519c6f5eeac934 --- /dev/null +++ b/templates/components/agents/python/blog/app/workflows/__init__.py @@ -0,0 +1,3 @@ +from .blog import create_workflow + +__all__ = ["create_workflow"] diff --git a/templates/components/agents/python/blog/app/engine/engine.py b/templates/components/agents/python/blog/app/workflows/blog.py similarity index 76% rename from templates/components/agents/python/blog/app/engine/engine.py rename to templates/components/agents/python/blog/app/workflows/blog.py index 78a79c4593494c8cb4382d2760d3962234b825f4..6c6eba526fbcd07626116c4429a7e396bc7a2c25 100644 --- a/templates/components/agents/python/blog/app/engine/engine.py +++ b/templates/components/agents/python/blog/app/workflows/blog.py @@ -4,17 +4,18 @@ from typing import List, Optional from app.agents.choreography import create_choreography from app.agents.orchestrator import create_orchestrator -from app.agents.workflow import create_workflow +from app.agents.workflow import create_workflow as create_blog_workflow from llama_index.core.chat_engine.types import ChatMessage from llama_index.core.workflow import Workflow logger = logging.getLogger("uvicorn") -def get_chat_engine( +def create_workflow( chat_history: Optional[List[ChatMessage]] = None, **kwargs ) -> Workflow: - # TODO: the EXAMPLE_TYPE could be passed as a chat config parameter? + # Chat filters are not supported yet + kwargs.pop("filters", None) agent_type = os.getenv("EXAMPLE_TYPE", "").lower() match agent_type: case "choreography": @@ -22,7 +23,7 @@ def get_chat_engine( case "orchestrator": agent = create_orchestrator(chat_history, **kwargs) case _: - agent = create_workflow(chat_history, **kwargs) + agent = create_blog_workflow(chat_history, **kwargs) logger.info(f"Using agent pattern: {agent_type}") diff --git a/templates/components/multiagent/python/app/workflows/multi.py b/templates/components/agents/python/blog/app/workflows/multi.py similarity index 100% rename from templates/components/multiagent/python/app/workflows/multi.py rename to templates/components/agents/python/blog/app/workflows/multi.py diff --git a/templates/components/multiagent/python/app/workflows/planner.py b/templates/components/agents/python/blog/app/workflows/planner.py similarity index 100% rename from templates/components/multiagent/python/app/workflows/planner.py rename to templates/components/agents/python/blog/app/workflows/planner.py diff --git a/templates/components/multiagent/python/app/workflows/single.py b/templates/components/agents/python/blog/app/workflows/single.py similarity index 99% rename from templates/components/multiagent/python/app/workflows/single.py rename to templates/components/agents/python/blog/app/workflows/single.py index 401fb80061ac2f72bd4204786e5524e56ee87280..4ab905a51672ec32a6bbe6d84eb2f88e371406f3 100644 --- a/templates/components/multiagent/python/app/workflows/single.py +++ b/templates/components/agents/python/blog/app/workflows/single.py @@ -42,9 +42,9 @@ class AgentRunEvent(Event): return { "type": "agent", "data": { - "name": self.name, + "agent": self.name, "type": self.event_type.value, - "msg": self.msg, + "text": self.msg, "data": self.data, }, } diff 
--git a/templates/components/agents/python/financial_report/README-template.md b/templates/components/agents/python/financial_report/README-template.md index ba6d24fba8d3ae42637a680222dff0bba91060c4..0f3beb238e9c2c4fa8a57565fd951aa1f2270874 100644 --- a/templates/components/agents/python/financial_report/README-template.md +++ b/templates/components/agents/python/financial_report/README-template.md @@ -33,7 +33,7 @@ curl --location 'localhost:8000/api/chat' \ --data '{ "messages": [{ "role": "user", "content": "Create a report comparing the finances of Apple and Tesla" }] }' ``` -You can start editing the API by modifying `app/api/routers/chat.py` or `app/financial_report/workflow.py`. The API auto-updates as you save the files. +You can start editing the API by modifying `app/api/routers/chat.py` or `app/workflows/financial_report.py`. The API auto-updates as you save the files. Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API. diff --git a/templates/components/agents/python/financial_report/app/agents/analyst.py b/templates/components/agents/python/financial_report/app/agents/analyst.py deleted file mode 100644 index 877017f6eb6db04098f4e8882e007e3f325c4048..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/financial_report/app/agents/analyst.py +++ /dev/null @@ -1,47 +0,0 @@ -from textwrap import dedent -from typing import List, Tuple - -from app.engine.tools import ToolFactory -from app.workflows.single import FunctionCallingAgent -from llama_index.core.chat_engine.types import ChatMessage -from llama_index.core.tools import FunctionTool - - -def _get_analyst_params() -> Tuple[List[type[FunctionTool]], str, str]: - tools = [] - prompt_instructions = dedent( - """ - You are an expert in analyzing financial data. - You are given a task and a set of financial data to analyze. Your task is to analyze the financial data and return a report. - Your response should include a detailed analysis of the financial data, including any trends, patterns, or insights that you find. - Construct the analysis in a textual format like tables would be great! - Don't need to synthesize the data, just analyze and provide your findings. - Always use the provided information, don't make up any information yourself. - """ - ) - description = "Expert in analyzing financial data" - configured_tools = ToolFactory.from_env(map_result=True) - # Check if the interpreter tool is configured - if "interpret" in configured_tools.keys(): - tools.append(configured_tools["interpret"]) - prompt_instructions += dedent(""" - You are able to visualize the financial data using code interpreter tool. - It's very useful to create and include visualizations to the report (make sure you include the right code and data for the visualization). - Never include any code into the report, just the visualization. - """) - description += ( - ", able to visualize the financial data using code interpreter tool." 
- ) - return tools, prompt_instructions, description - - -def create_analyst(chat_history: List[ChatMessage]): - tools, prompt_instructions, description = _get_analyst_params() - - return FunctionCallingAgent( - name="analyst", - tools=tools, - description=description, - system_prompt=dedent(prompt_instructions), - chat_history=chat_history, - ) diff --git a/templates/components/agents/python/financial_report/app/agents/reporter.py b/templates/components/agents/python/financial_report/app/agents/reporter.py deleted file mode 100644 index 496884ca261f3aee90a578941f2e577ea51fef40..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/financial_report/app/agents/reporter.py +++ /dev/null @@ -1,44 +0,0 @@ -from textwrap import dedent -from typing import List, Tuple - -from app.engine.tools import ToolFactory -from app.workflows.single import FunctionCallingAgent -from llama_index.core.chat_engine.types import ChatMessage -from llama_index.core.tools import BaseTool - - -def _get_reporter_params( - chat_history: List[ChatMessage], -) -> Tuple[List[type[BaseTool]], str, str]: - tools: List[type[BaseTool]] = [] - description = "Expert in representing a financial report" - prompt_instructions = dedent( - """ - You are a report generation assistant tasked with producing a well-formatted report given parsed context. - Given a comprehensive analysis of the user request, your task is to synthesize the information and return a well-formatted report. - - ## Instructions - You are responsible for representing the analysis in a well-formatted report. If tables or visualizations provided, add them to the right sections that are most relevant. - Use only the provided information to create the report. Do not make up any information yourself. - Finally, the report should be presented in markdown format. - """ - ) - configured_tools = ToolFactory.from_env(map_result=True) - if "generate_document" in configured_tools: # type: ignore - tools.append(configured_tools["generate_document"]) # type: ignore - prompt_instructions += ( - "\nYou are also able to generate a file document (PDF/HTML) of the report." - ) - description += " and generate a file document (PDF/HTML) of the report." - return tools, description, prompt_instructions - - -def create_reporter(chat_history: List[ChatMessage]): - tools, description, prompt_instructions = _get_reporter_params(chat_history) - return FunctionCallingAgent( - name="reporter", - tools=tools, - description=description, - system_prompt=prompt_instructions, - chat_history=chat_history, - ) diff --git a/templates/components/agents/python/financial_report/app/agents/researcher.py b/templates/components/agents/python/financial_report/app/agents/researcher.py deleted file mode 100644 index 4d1459a5f790b7329c43e43a3c4dda8a35179f08..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/financial_report/app/agents/researcher.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -from textwrap import dedent -from typing import List, Optional - -from app.engine.index import IndexConfig, get_index -from app.workflows.single import FunctionCallingAgent -from llama_index.core.chat_engine.types import ChatMessage -from llama_index.core.tools import BaseTool, QueryEngineTool, ToolMetadata -from llama_index.indices.managed.llama_cloud import LlamaCloudIndex - - -def _create_query_engine_tools(params=None) -> Optional[list[type[BaseTool]]]: - """ - Provide an agent worker that can be used to query the index. 
- """ - # Add query tool if index exists - index_config = IndexConfig(**(params or {})) - index = get_index(index_config) - if index is None: - return None - - top_k = int(os.getenv("TOP_K", 5)) - - # Construct query engine tools - tools = [] - # If index is LlamaCloudIndex, we need to add chunk and doc retriever tools - if isinstance(index, LlamaCloudIndex): - # Document retriever - doc_retriever = index.as_query_engine( - retriever_mode="files_via_content", - similarity_top_k=top_k, - ) - chunk_retriever = index.as_query_engine( - retriever_mode="chunks", - similarity_top_k=top_k, - ) - tools.append( - QueryEngineTool( - query_engine=doc_retriever, - metadata=ToolMetadata( - name="document_retriever", - description=dedent( - """ - Document retriever that retrieves entire documents from the corpus. - ONLY use for research questions that may require searching over entire research reports. - Will be slower and more expensive than chunk-level retrieval but may be necessary. - """ - ), - ), - ) - ) - tools.append( - QueryEngineTool( - query_engine=chunk_retriever, - metadata=ToolMetadata( - name="chunk_retriever", - description=dedent( - """ - Retrieves a small set of relevant document chunks from the corpus. - Use for research questions that want to look up specific facts from the knowledge corpus, - and need entire documents. - """ - ), - ), - ) - ) - else: - query_engine = index.as_query_engine( - **({"similarity_top_k": top_k} if top_k != 0 else {}) - ) - tools.append( - QueryEngineTool( - query_engine=query_engine, - metadata=ToolMetadata( - name="retrieve_information", - description="Use this tool to retrieve information about the text corpus from the index.", - ), - ) - ) - return tools - - -def create_researcher(chat_history: List[ChatMessage], **kwargs): - """ - Researcher is an agent that take responsibility for using tools to complete a given task. - """ - tools = _create_query_engine_tools(**kwargs) - - if tools is None: - raise ValueError("No tools found for researcher agent") - - return FunctionCallingAgent( - name="researcher", - tools=tools, - description="expert in retrieving any unknown content from the corpus", - system_prompt=dedent( - """ - You are a researcher agent. You are responsible for retrieving information from the corpus. - ## Instructions - + Don't synthesize the information, just return the whole retrieved information. - + Don't need to retrieve the information that is already provided in the chat history and response with: "There is no new information, please reuse the information from the conversation." 
- """ - ), - chat_history=chat_history, - ) diff --git a/templates/components/agents/python/financial_report/app/agents/workflow.py b/templates/components/agents/python/financial_report/app/agents/workflow.py deleted file mode 100644 index 1c77581904c6f74ebbaa093540fd041dff5eab91..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/financial_report/app/agents/workflow.py +++ /dev/null @@ -1,177 +0,0 @@ -from textwrap import dedent -from typing import AsyncGenerator, List, Optional - -from app.agents.analyst import create_analyst -from app.agents.reporter import create_reporter -from app.agents.researcher import create_researcher -from app.workflows.single import AgentRunEvent, AgentRunResult, FunctionCallingAgent -from llama_index.core.chat_engine.types import ChatMessage -from llama_index.core.prompts import PromptTemplate -from llama_index.core.settings import Settings -from llama_index.core.workflow import ( - Context, - Event, - StartEvent, - StopEvent, - Workflow, - step, -) - - -def create_workflow(chat_history: Optional[List[ChatMessage]] = None, **kwargs): - researcher = create_researcher( - chat_history=chat_history, - **kwargs, - ) - - analyst = create_analyst(chat_history=chat_history) - - reporter = create_reporter(chat_history=chat_history) - - workflow = FinancialReportWorkflow(timeout=360, chat_history=chat_history) - - workflow.add_workflows( - researcher=researcher, - analyst=analyst, - reporter=reporter, - ) - return workflow - - -class ResearchEvent(Event): - input: str - - -class AnalyzeEvent(Event): - input: str - - -class ReportEvent(Event): - input: str - - -class FinancialReportWorkflow(Workflow): - def __init__( - self, timeout: int = 360, chat_history: Optional[List[ChatMessage]] = None - ): - super().__init__(timeout=timeout) - self.chat_history = chat_history or [] - - @step() - async def start(self, ctx: Context, ev: StartEvent) -> ResearchEvent | ReportEvent: - # set streaming - ctx.data["streaming"] = getattr(ev, "streaming", False) - # start the workflow with researching about a topic - ctx.data["task"] = ev.input - ctx.data["user_input"] = ev.input - - # Decision-making process - decision = await self._decide_workflow(ev.input, self.chat_history) - - if decision != "publish": - return ResearchEvent(input=f"Research for this task: {ev.input}") - else: - chat_history_str = "\n".join( - [f"{msg.role}: {msg.content}" for msg in self.chat_history] - ) - return ReportEvent( - input=f"Create a report based on the chat history\n{chat_history_str}\n\n and task: {ev.input}" - ) - - async def _decide_workflow( - self, input: str, chat_history: List[ChatMessage] - ) -> str: - # TODO: Refactor this by using prompt generation - prompt_template = PromptTemplate( - dedent( - """ - You are an expert in decision-making, helping people create financial reports for the provided data. - If the user doesn't need to add or update anything, respond with 'publish'. - Otherwise, respond with 'research'. - - Here is the chat history: - {chat_history} - - The current user request is: - {input} - - Given the chat history and the new user request, decide whether to create a report based on existing information. 
- Decision (respond with either 'not_publish' or 'publish'): - """ - ) - ) - - chat_history_str = "\n".join( - [f"{msg.role}: {msg.content}" for msg in chat_history] - ) - prompt = prompt_template.format(chat_history=chat_history_str, input=input) - - output = await Settings.llm.acomplete(prompt) - decision = output.text.strip().lower() - - return "publish" if decision == "publish" else "research" - - @step() - async def research( - self, ctx: Context, ev: ResearchEvent, researcher: FunctionCallingAgent - ) -> AnalyzeEvent: - result: AgentRunResult = await self.run_agent(ctx, researcher, ev.input) - content = result.response.message.content - return AnalyzeEvent( - input=dedent( - f""" - Given the following research content: - {content} - Provide a comprehensive analysis of the data for the user's request: {ctx.data["task"]} - """ - ) - ) - - @step() - async def analyze( - self, ctx: Context, ev: AnalyzeEvent, analyst: FunctionCallingAgent - ) -> ReportEvent | StopEvent: - result: AgentRunResult = await self.run_agent(ctx, analyst, ev.input) - content = result.response.message.content - return ReportEvent( - input=dedent( - f""" - Given the following analysis: - {content} - Create a report for the user's request: {ctx.data["task"]} - """ - ) - ) - - @step() - async def report( - self, ctx: Context, ev: ReportEvent, reporter: FunctionCallingAgent - ) -> StopEvent: - try: - result: AgentRunResult = await self.run_agent( - ctx, reporter, ev.input, streaming=ctx.data["streaming"] - ) - return StopEvent(result=result) - except Exception as e: - ctx.write_event_to_stream( - AgentRunEvent( - name=reporter.name, - msg=f"Error creating a report: {e}", - ) - ) - return StopEvent(result=None) - - async def run_agent( - self, - ctx: Context, - agent: FunctionCallingAgent, - input: str, - streaming: bool = False, - ) -> AgentRunResult | AsyncGenerator: - handler = agent.run(input=input, streaming=streaming) - # bubble all events while running the executor to the planner - async for event in handler.stream_events(): - # Don't write the StopEvent from sub task to the stream - if type(event) is not StopEvent: - ctx.write_event_to_stream(event) - return await handler diff --git a/templates/components/agents/python/financial_report/app/engine/engine.py b/templates/components/agents/python/financial_report/app/engine/engine.py deleted file mode 100644 index 0ea21ab29ae6a8581ede7b80d392e17a9bab2d87..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/financial_report/app/engine/engine.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import List, Optional - -from app.agents.workflow import create_workflow -from llama_index.core.chat_engine.types import ChatMessage -from llama_index.core.workflow import Workflow - - -def get_chat_engine( - chat_history: Optional[List[ChatMessage]] = None, **kwargs -) -> Workflow: - agent_workflow = create_workflow(chat_history, **kwargs) - return agent_workflow diff --git a/templates/components/agents/python/financial_report/app/workflows/__init__.py b/templates/components/agents/python/financial_report/app/workflows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a509e28ae7cea22a3c9eb459d53d3bb3cdc758c --- /dev/null +++ b/templates/components/agents/python/financial_report/app/workflows/__init__.py @@ -0,0 +1,3 @@ +from .financial_report import create_workflow + +__all__ = ["create_workflow"] diff --git a/templates/components/agents/python/financial_report/app/workflows/financial_report.py 
b/templates/components/agents/python/financial_report/app/workflows/financial_report.py new file mode 100644 index 0000000000000000000000000000000000000000..03ebfaa3df5cad54609c943d2a9c0380dd8c263f --- /dev/null +++ b/templates/components/agents/python/financial_report/app/workflows/financial_report.py @@ -0,0 +1,298 @@ +import os +from typing import Any, Dict, List, Optional + +from app.engine.index import IndexConfig, get_index +from app.engine.tools import ToolFactory +from app.workflows.events import AgentRunEvent +from app.workflows.tools import ( + call_tools, + chat_with_tools, +) +from llama_index.core import Settings +from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.indices.vector_store import VectorStoreIndex +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection +from llama_index.core.workflow import ( + Context, + Event, + StartEvent, + StopEvent, + Workflow, + step, +) + + +def create_workflow( + chat_history: Optional[List[ChatMessage]] = None, + params: Optional[Dict[str, Any]] = None, + filters: Optional[List[Any]] = None, +) -> Workflow: + index_config = IndexConfig(**params) + index: VectorStoreIndex = get_index(config=index_config) + if index is None: + query_engine_tool = None + else: + top_k = int(os.getenv("TOP_K", 10)) + query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters) + query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) + + configured_tools: Dict[str, FunctionTool] = ToolFactory.from_env(map_result=True) # type: ignore + code_interpreter_tool = configured_tools.get("interpret") + document_generator_tool = configured_tools.get("generate_document") + + return FinancialReportWorkflow( + query_engine_tool=query_engine_tool, + code_interpreter_tool=code_interpreter_tool, + document_generator_tool=document_generator_tool, + chat_history=chat_history, + ) + + +class InputEvent(Event): + input: List[ChatMessage] + response: bool = False + + +class ResearchEvent(Event): + input: list[ToolSelection] + + +class AnalyzeEvent(Event): + input: list[ToolSelection] | ChatMessage + + +class ReportEvent(Event): + input: list[ToolSelection] + + +class FinancialReportWorkflow(Workflow): + """ + A workflow to generate a financial report using indexed documents. + + Requirements: + - Indexed documents containing financial data and a query engine tool to search them + - A code interpreter tool to analyze data and generate reports + - A document generator tool to create report files + + Steps: + 1. LLM Input: The LLM determines the next step based on function calling. + For example, if the model requests the query engine tool, it returns a ResearchEvent; + if it requests document generation, it returns a ReportEvent. + 2. Research: Uses the query engine to find relevant chunks from indexed documents. + After gathering information, it requests analysis (step 3). + 3. Analyze: Uses a custom prompt to analyze research results and can call the code + interpreter tool for visualization or calculation. Returns results to the LLM. + 4. Report: Uses the document generator tool to create a report. Returns results to the LLM. + """ + + _default_system_prompt = """ + You are a financial analyst who are given a set of tools to help you. 
+ It's good to using appropriate tools for the user request and always use the information from the tools, don't make up anything yourself. + For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries. + """ + + def __init__( + self, + query_engine_tool: QueryEngineTool, + code_interpreter_tool: FunctionTool, + document_generator_tool: FunctionTool, + llm: Optional[FunctionCallingLLM] = None, + timeout: int = 360, + chat_history: Optional[List[ChatMessage]] = None, + system_prompt: Optional[str] = None, + ): + super().__init__(timeout=timeout) + self.system_prompt = system_prompt or self._default_system_prompt + self.chat_history = chat_history or [] + self.query_engine_tool = query_engine_tool + self.code_interpreter_tool = code_interpreter_tool + self.document_generator_tool = document_generator_tool + assert ( + query_engine_tool is not None + ), "Query engine tool is not found. Try run generation script or upload a document file first." + assert code_interpreter_tool is not None, "Code interpreter tool is required" + assert ( + document_generator_tool is not None + ), "Document generator tool is required" + self.tools = [ + self.query_engine_tool, + self.code_interpreter_tool, + self.document_generator_tool, + ] + self.llm: FunctionCallingLLM = llm or Settings.llm + assert isinstance(self.llm, FunctionCallingLLM) + self.memory = ChatMemoryBuffer.from_defaults( + llm=self.llm, chat_history=self.chat_history + ) + + @step() + async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent: + ctx.data["input"] = ev.input + + if self.system_prompt: + system_msg = ChatMessage( + role=MessageRole.SYSTEM, content=self.system_prompt + ) + self.memory.put(system_msg) + + # Add user input to memory + self.memory.put(ChatMessage(role=MessageRole.USER, content=ev.input)) + + return InputEvent(input=self.memory.get()) + + @step() + async def handle_llm_input( # type: ignore + self, + ctx: Context, + ev: InputEvent, + ) -> ResearchEvent | AnalyzeEvent | ReportEvent | StopEvent: + """ + Handle an LLM input and decide the next step. + """ + # Always use the latest chat history from the input + chat_history: list[ChatMessage] = ev.input + + # Get tool calls + response = await chat_with_tools( + self.llm, + self.tools, # type: ignore + chat_history, + ) + if not response.has_tool_calls(): + # If no tool call, return the response generator + return StopEvent(result=response.generator) + # calling different tools at the same time is not supported at the moment + # add an error message to tell the AI to process step by step + if response.is_calling_different_tools(): + self.memory.put( + ChatMessage( + role=MessageRole.ASSISTANT, + content="Cannot call different tools at the same time. Try calling one tool at a time.", + ) + ) + return InputEvent(input=self.memory.get()) + self.memory.put(response.tool_call_message) + match response.tool_name(): + case self.code_interpreter_tool.metadata.name: + return AnalyzeEvent(input=response.tool_calls) + case self.document_generator_tool.metadata.name: + return ReportEvent(input=response.tool_calls) + case self.query_engine_tool.metadata.name: + return ResearchEvent(input=response.tool_calls) + case _: + raise ValueError(f"Unknown tool: {response.tool_name()}") + + @step() + async def research(self, ctx: Context, ev: ResearchEvent) -> AnalyzeEvent: + """ + Do a research to gather information for the user's request. 
+ A researcher should have these tools: query engine, search engine, etc. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Researcher", + msg="Starting research", + ) + ) + tool_calls = ev.input + + tool_messages = await call_tools( + ctx=ctx, + agent_name="Researcher", + tools=[self.query_engine_tool], + tool_calls=tool_calls, + ) + self.memory.put_messages(tool_messages) + return AnalyzeEvent( + input=ChatMessage( + role=MessageRole.ASSISTANT, + content="I've finished the research. Please analyze the result.", + ), + ) + + @step() + async def analyze(self, ctx: Context, ev: AnalyzeEvent) -> InputEvent: + """ + Analyze the research result. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Analyst", + msg="Starting analysis", + ) + ) + event_requested_by_workflow_llm = isinstance(ev.input, list) + # Requested by the workflow LLM Input step, it's a tool call + if event_requested_by_workflow_llm: + # Set the tool calls + tool_calls = ev.input + else: + # Otherwise, it's triggered by the research step + # Use a custom prompt and independent memory for the analyst agent + analysis_prompt = """ + You are a financial analyst, you are given a research result and a set of tools to help you. + Always use the given information, don't make up anything yourself. If there is not enough information, you can asking for more information. + If you have enough numerical information, it's good to include some charts/visualizations to the report so you can use the code interpreter tool to generate a report. + """ + # This is handled by analyst agent + # Clone the shared memory to avoid conflicting with the workflow. + chat_history = self.memory.get() + chat_history.append( + ChatMessage(role=MessageRole.SYSTEM, content=analysis_prompt) + ) + chat_history.append(ev.input) # type: ignore + # Check if the analyst agent needs to call tools + response = await chat_with_tools( + self.llm, + [self.code_interpreter_tool], + chat_history, + ) + if not response.has_tool_calls(): + # If no tool call, fallback analyst message to the workflow + analyst_msg = ChatMessage( + role=MessageRole.ASSISTANT, + content=await response.full_response(), + ) + self.memory.put(analyst_msg) + return InputEvent(input=self.memory.get()) + else: + # Set the tool calls and the tool call message to the memory + tool_calls = response.tool_calls + self.memory.put(response.tool_call_message) + + # Call tools + tool_messages = await call_tools( + ctx=ctx, + agent_name="Analyst", + tools=[self.code_interpreter_tool], + tool_calls=tool_calls, # type: ignore + ) + self.memory.put_messages(tool_messages) + + # Fallback to the input with the latest chat history + return InputEvent(input=self.memory.get()) + + @step() + async def report(self, ctx: Context, ev: ReportEvent) -> InputEvent: + """ + Generate a report based on the analysis result. 
+ """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Reporter", + msg="Starting report generation", + ) + ) + tool_calls = ev.input + tool_messages = await call_tools( + ctx=ctx, + agent_name="Reporter", + tools=[self.document_generator_tool], + tool_calls=tool_calls, + ) + self.memory.put_messages(tool_messages) + + # After the tool calls, fallback to the input with the latest chat history + return InputEvent(input=self.memory.get()) diff --git a/templates/components/agents/python/form_filling/README-template.md b/templates/components/agents/python/form_filling/README-template.md index a3340c531f455b26e9b15d99c7a0082f7b477e75..be6ec38e39f2cd3632f78ca2fa4627ce4ab39e1b 100644 --- a/templates/components/agents/python/form_filling/README-template.md +++ b/templates/components/agents/python/form_filling/README-template.md @@ -39,7 +39,7 @@ curl --location 'localhost:8000/api/chat' \ --data '{ "messages": [{ "role": "user", "content": "What can you do?" }] }' ``` -You can start editing the API by modifying `app/api/routers/chat.py` or `app/agents/form_filling.py`. The API auto-updates as you save the files. +You can start editing the API by modifying `app/api/routers/chat.py` or `app/workflows/form_filling.py`. The API auto-updates as you save the files. Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API. diff --git a/templates/components/agents/python/form_filling/app/agents/form_filling.py b/templates/components/agents/python/form_filling/app/agents/form_filling.py deleted file mode 100644 index 2cc7a0f79f45e37e49d7bb9ded07f15216cea6e3..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/form_filling/app/agents/form_filling.py +++ /dev/null @@ -1,397 +0,0 @@ -import os -import uuid -from enum import Enum -from typing import AsyncGenerator, List, Optional - -from app.engine.index import get_index -from app.engine.tools import ToolFactory -from app.engine.tools.form_filling import CellValue, MissingCell -from llama_index.core import Settings -from llama_index.core.base.llms.types import ChatMessage, MessageRole -from llama_index.core.indices.vector_store import VectorStoreIndex -from llama_index.core.llms.function_calling import FunctionCallingLLM -from llama_index.core.memory import ChatMemoryBuffer -from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection -from llama_index.core.tools.types import ToolOutput -from llama_index.core.workflow import ( - Context, - Event, - StartEvent, - StopEvent, - Workflow, - step, -) -from pydantic import Field - - -def create_workflow( - chat_history: Optional[List[ChatMessage]] = None, **kwargs -) -> Workflow: - index: VectorStoreIndex = get_index() - if index is None: - query_engine_tool = None - else: - top_k = int(os.getenv("TOP_K", 10)) - query_engine = index.as_query_engine(similarity_top_k=top_k) - query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) - - configured_tools = ToolFactory.from_env(map_result=True) - extractor_tool = configured_tools.get("extract_questions") - filling_tool = configured_tools.get("fill_form") - - if extractor_tool is None or filling_tool is None: - raise ValueError("Extractor or filling tool is not found!") - - workflow = FormFillingWorkflow( - query_engine_tool=query_engine_tool, - extractor_tool=extractor_tool, - filling_tool=filling_tool, - chat_history=chat_history, - ) - - return workflow - - -class InputEvent(Event): - input: List[ChatMessage] - response: bool = False - - 
-class ExtractMissingCellsEvent(Event): - tool_call: ToolSelection - - -class FindAnswersEvent(Event): - missing_cells: list[MissingCell] - - -class FillEvent(Event): - tool_call: ToolSelection - - -class AgentRunEventType(Enum): - TEXT = "text" - PROGRESS = "progress" - - -class AgentRunEvent(Event): - name: str - msg: str - event_type: AgentRunEventType = Field(default=AgentRunEventType.TEXT) - data: Optional[dict] = None - - def to_response(self) -> dict: - return { - "type": "agent", - "data": { - "agent": self.name, - "type": self.event_type.value, - "text": self.msg, - "data": self.data, - }, - } - - -class FormFillingWorkflow(Workflow): - """ - A predefined workflow for filling missing cells in a CSV file. - Required tools: - - query_engine: A query engine to query for the answers to the questions. - - extract_question: Extract missing cells in a CSV file and generate questions to fill them. - - answer_question: Query for the answers to the questions. - - Flow: - 1. Extract missing cells in a CSV file and generate questions to fill them. - 2. Query for the answers to the questions. - 3. Fill the missing cells with the answers. - """ - - _default_system_prompt = """ - You are a helpful assistant who helps fill missing cells in a CSV file. - Only use provided data, never make up any information yourself. Fill N/A if the answer is not found. - """ - - def __init__( - self, - query_engine_tool: QueryEngineTool, - extractor_tool: FunctionTool, - filling_tool: FunctionTool, - llm: Optional[FunctionCallingLLM] = None, - timeout: int = 360, - chat_history: Optional[List[ChatMessage]] = None, - system_prompt: Optional[str] = None, - ): - super().__init__(timeout=timeout) - self.system_prompt = system_prompt or self._default_system_prompt - self.chat_history = chat_history or [] - self.query_engine_tool = query_engine_tool - self.extractor_tool = extractor_tool - self.filling_tool = filling_tool - self.llm: FunctionCallingLLM = llm or Settings.llm - if not isinstance(self.llm, FunctionCallingLLM): - raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.") - self.memory = ChatMemoryBuffer.from_defaults( - llm=self.llm, chat_history=self.chat_history - ) - - @step() - async def start(self, ctx: Context, ev: StartEvent) -> InputEvent: - ctx.data["streaming"] = getattr(ev, "streaming", False) - ctx.data["input"] = ev.input - - if self.system_prompt: - system_msg = ChatMessage( - role=MessageRole.SYSTEM, content=self.system_prompt - ) - self.memory.put(system_msg) - - user_input = ev.input - user_msg = ChatMessage(role=MessageRole.USER, content=user_input) - self.memory.put(user_msg) - - chat_history = self.memory.get() - return InputEvent(input=chat_history) - - @step(pass_context=True) - async def handle_llm_input( # type: ignore - self, - ctx: Context, - ev: InputEvent, - ) -> ExtractMissingCellsEvent | FillEvent | StopEvent: - """ - Handle an LLM input and decide the next step. 
- """ - chat_history: list[ChatMessage] = ev.input - - generator = self._tool_call_generator(chat_history) - - # Check for immediate tool call - is_tool_call = await generator.__anext__() - if is_tool_call: - full_response = await generator.__anext__() - tool_calls = self.llm.get_tool_calls_from_response(full_response) # type: ignore - for tool_call in tool_calls: - if tool_call.tool_name == self.extractor_tool.metadata.get_name(): - ctx.send_event(ExtractMissingCellsEvent(tool_call=tool_call)) - elif tool_call.tool_name == self.filling_tool.metadata.get_name(): - ctx.send_event(FillEvent(tool_call=tool_call)) - else: - # If no tool call, return the generator - return StopEvent(result=generator) - - @step() - async def extract_missing_cells( - self, ctx: Context, ev: ExtractMissingCellsEvent - ) -> InputEvent | FindAnswersEvent: - """ - Extract missing cells in a CSV file and generate questions to fill them. - """ - ctx.write_event_to_stream( - AgentRunEvent( - name="Extractor", - msg="Extracting missing cells", - ) - ) - # Call the extract questions tool - response = self._call_tool( - ctx, - agent_name="Extractor", - tool=self.extractor_tool, - tool_selection=ev.tool_call, - ) - if response.is_error: - return InputEvent(input=self.memory.get()) - - missing_cells = response.raw_output.get("missing_cells", []) - message = ChatMessage( - role=MessageRole.TOOL, - content=str(missing_cells), - additional_kwargs={ - "tool_call_id": ev.tool_call.tool_id, - "name": ev.tool_call.tool_name, - }, - ) - self.memory.put(message) - - if self.query_engine_tool is None: - # Fallback to input that query engine tool is not found so that cannot answer questions - self.memory.put( - ChatMessage( - role=MessageRole.ASSISTANT, - content="Extracted missing cells but query engine tool is not found so cannot answer questions. Ask user to upload file or connect to a knowledge base.", - ) - ) - return InputEvent(input=self.memory.get()) - - # Forward missing cells information to find answers step - return FindAnswersEvent(missing_cells=missing_cells) - - @step() - async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent: - """ - Call answer questions tool to query for the answers to the questions. - """ - ctx.write_event_to_stream( - AgentRunEvent( - name="Researcher", - msg="Finding answers for missing cells", - ) - ) - missing_cells = ev.missing_cells - # If missing cells information is not found, fallback to other tools - # It means that the extractor tool has not been called yet - # Fallback to input - if missing_cells is None: - ctx.write_event_to_stream( - AgentRunEvent( - name="Researcher", - msg="Error: Missing cells information not found. 
Fallback to other tools.", - ) - ) - message = ChatMessage( - role=MessageRole.TOOL, - content="Error: Missing cells information not found.", - additional_kwargs={ - "tool_call_id": ev.tool_call.tool_id, - "name": ev.tool_call.tool_name, - }, - ) - self.memory.put(message) - return InputEvent(input=self.memory.get()) - - cell_values: list[CellValue] = [] - # Iterate over missing cells and query for the answers - # and stream the progress - progress_id = str(uuid.uuid4()) - total_steps = len(missing_cells) - for i, cell in enumerate(missing_cells): - if cell.question_to_answer is None: - continue - ctx.write_event_to_stream( - AgentRunEvent( - name="Researcher", - msg=f"Querying for: {cell.question_to_answer}", - event_type=AgentRunEventType.PROGRESS, - data={ - "id": progress_id, - "total": total_steps, - "current": i, - }, - ) - ) - # Call query engine tool directly - answer = await self.query_engine_tool.acall(query=cell.question_to_answer) - cell_values.append( - CellValue( - row_index=cell.row_index, - column_index=cell.column_index, - value=str(answer), - ) - ) - self.memory.put( - ChatMessage( - role=MessageRole.ASSISTANT, - content=str(cell_values), - ) - ) - return InputEvent(input=self.memory.get()) - - @step() - async def fill_cells(self, ctx: Context, ev: FillEvent) -> InputEvent: - """ - Call fill cells tool to fill the missing cells with the answers. - """ - ctx.write_event_to_stream( - AgentRunEvent( - name="Processor", - msg="Filling missing cells", - ) - ) - # Call the fill cells tool - result = self._call_tool( - ctx, - agent_name="Processor", - tool=self.filling_tool, - tool_selection=ev.tool_call, - ) - if result.is_error: - return InputEvent(input=self.memory.get()) - - message = ChatMessage( - role=MessageRole.TOOL, - content=str(result.raw_output), - additional_kwargs={ - "tool_call_id": ev.tool_call.tool_id, - "name": ev.tool_call.tool_name, - }, - ) - self.memory.put(message) - return InputEvent(input=self.memory.get(), response=True) - - async def _tool_call_generator( - self, chat_history: list[ChatMessage] - ) -> AsyncGenerator[ChatMessage | bool, None]: - response_stream = await self.llm.astream_chat_with_tools( - [self.extractor_tool, self.filling_tool], - chat_history=chat_history, - ) - - full_response = None - yielded_indicator = False - async for chunk in response_stream: - if "tool_calls" not in chunk.message.additional_kwargs: - # Yield a boolean to indicate whether the response is a tool call - if not yielded_indicator: - yield False - yielded_indicator = True - - # if not a tool call, yield the chunks! - yield chunk - elif not yielded_indicator: - # Yield the indicator for a tool call - yield True - yielded_indicator = True - - full_response = chunk - - # Write the full response to memory and yield it - if full_response: - self.memory.put(full_response.message) - yield full_response - - def _call_tool( - self, - ctx: Context, - agent_name: str, - tool: FunctionTool, - tool_selection: ToolSelection, - ) -> ToolOutput: - """ - Safely call a tool and handle errors. 
- """ - try: - response: ToolOutput = tool.call(**tool_selection.tool_kwargs) - return response - except Exception as e: - ctx.write_event_to_stream( - AgentRunEvent( - name=agent_name, - msg=f"Error: {str(e)}", - ) - ) - message = ChatMessage( - role=MessageRole.TOOL, - content=f"Error: {str(e)}", - additional_kwargs={ - "tool_call_id": tool_selection.tool_id, - "name": tool.metadata.get_name(), - }, - ) - self.memory.put(message) - return ToolOutput( - content=f"Error: {str(e)}", - tool_name=tool.metadata.get_name(), - raw_input=tool_selection.tool_kwargs, - raw_output=None, - is_error=True, - ) diff --git a/templates/components/agents/python/form_filling/app/engine/engine.py b/templates/components/agents/python/form_filling/app/engine/engine.py deleted file mode 100644 index 68dbb6ce68b45207c346e46da2a86c0b6debc9fc..0000000000000000000000000000000000000000 --- a/templates/components/agents/python/form_filling/app/engine/engine.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import List, Optional - -from app.agents.form_filling import create_workflow -from llama_index.core.chat_engine.types import ChatMessage -from llama_index.core.workflow import Workflow - - -def get_chat_engine( - chat_history: Optional[List[ChatMessage]] = None, **kwargs -) -> Workflow: - return create_workflow(chat_history=chat_history, **kwargs) diff --git a/templates/components/agents/python/form_filling/app/workflows/__init__.py b/templates/components/agents/python/form_filling/app/workflows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c999f993cac09ba9654eb947dfdd929a53666029 --- /dev/null +++ b/templates/components/agents/python/form_filling/app/workflows/__init__.py @@ -0,0 +1,3 @@ +from .form_filling import create_workflow + +__all__ = ["create_workflow"] diff --git a/templates/components/agents/python/form_filling/app/workflows/form_filling.py b/templates/components/agents/python/form_filling/app/workflows/form_filling.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2497b78fda4236a17730ae5349e7e1a7ef035c --- /dev/null +++ b/templates/components/agents/python/form_filling/app/workflows/form_filling.py @@ -0,0 +1,241 @@ +import os +from typing import Any, Dict, List, Optional + +from app.engine.index import IndexConfig, get_index +from app.engine.tools import ToolFactory +from app.workflows.events import AgentRunEvent +from app.workflows.tools import ( + call_tools, + chat_with_tools, +) +from llama_index.core import Settings +from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.indices.vector_store import VectorStoreIndex +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection +from llama_index.core.workflow import ( + Context, + Event, + StartEvent, + StopEvent, + Workflow, + step, +) + + +def create_workflow( + chat_history: Optional[List[ChatMessage]] = None, + params: Optional[Dict[str, Any]] = None, + filters: Optional[List[Any]] = None, +) -> Workflow: + if params is None: + params = {} + if filters is None: + filters = [] + index_config = IndexConfig(**params) + index: VectorStoreIndex = get_index(config=index_config) + if index is None: + query_engine_tool = None + else: + top_k = int(os.getenv("TOP_K", 10)) + query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters) + query_engine_tool = 
QueryEngineTool.from_defaults(query_engine=query_engine) + + configured_tools = ToolFactory.from_env(map_result=True) + extractor_tool = configured_tools.get("extract_questions") # type: ignore + filling_tool = configured_tools.get("fill_form") # type: ignore + + workflow = FormFillingWorkflow( + query_engine_tool=query_engine_tool, + extractor_tool=extractor_tool, # type: ignore + filling_tool=filling_tool, # type: ignore + chat_history=chat_history, + ) + + return workflow + + +class InputEvent(Event): + input: List[ChatMessage] + response: bool = False + + +class ExtractMissingCellsEvent(Event): + tool_calls: list[ToolSelection] + + +class FindAnswersEvent(Event): + tool_calls: list[ToolSelection] + + +class FillEvent(Event): + tool_calls: list[ToolSelection] + + +class FormFillingWorkflow(Workflow): + """ + A predefined workflow for filling missing cells in a CSV file. + Required tools: + - query_engine: A query engine to query for the answers to the questions. + - extract_question: Extract missing cells in a CSV file and generate questions to fill them. + - answer_question: Query for the answers to the questions. + + Flow: + 1. Extract missing cells in a CSV file and generate questions to fill them. + 2. Query for the answers to the questions. + 3. Fill the missing cells with the answers. + """ + + _default_system_prompt = """ + You are a helpful assistant who helps fill missing cells in a CSV file. + Only extract missing cells from CSV files. + Only use provided data - never make up any information yourself. Fill N/A if an answer is not found. + If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base. 
+ """ + + def __init__( + self, + query_engine_tool: Optional[QueryEngineTool], + extractor_tool: FunctionTool, + filling_tool: FunctionTool, + llm: Optional[FunctionCallingLLM] = None, + timeout: int = 360, + chat_history: Optional[List[ChatMessage]] = None, + system_prompt: Optional[str] = None, + ): + super().__init__(timeout=timeout) + self.system_prompt = system_prompt or self._default_system_prompt + self.chat_history = chat_history or [] + self.query_engine_tool = query_engine_tool + self.extractor_tool = extractor_tool + self.filling_tool = filling_tool + if self.extractor_tool is None or self.filling_tool is None: + raise ValueError("Extractor and filling tools are required.") + self.tools = [self.extractor_tool, self.filling_tool] + if self.query_engine_tool is not None: + self.tools.append(self.query_engine_tool) # type: ignore + self.llm: FunctionCallingLLM = llm or Settings.llm + if not isinstance(self.llm, FunctionCallingLLM): + raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.") + self.memory = ChatMemoryBuffer.from_defaults( + llm=self.llm, chat_history=self.chat_history + ) + + @step() + async def start(self, ctx: Context, ev: StartEvent) -> InputEvent: + ctx.data["input"] = ev.input + + if self.system_prompt: + system_msg = ChatMessage( + role=MessageRole.SYSTEM, content=self.system_prompt + ) + self.memory.put(system_msg) + + user_input = ev.input + user_msg = ChatMessage(role=MessageRole.USER, content=user_input) + self.memory.put(user_msg) + + chat_history = self.memory.get() + return InputEvent(input=chat_history) + + @step() + async def handle_llm_input( # type: ignore + self, + ctx: Context, + ev: InputEvent, + ) -> ExtractMissingCellsEvent | FillEvent | StopEvent: + """ + Handle an LLM input and decide the next step. + """ + chat_history: list[ChatMessage] = ev.input + response = await chat_with_tools( + self.llm, + self.tools, + chat_history, + ) + if not response.has_tool_calls(): + return StopEvent(result=response.generator) + # calling different tools at the same time is not supported at the moment + # add an error message to tell the AI to process step by step + if response.is_calling_different_tools(): + self.memory.put( + ChatMessage( + role=MessageRole.ASSISTANT, + content="Cannot call different tools at the same time. Try calling one tool at a time.", + ) + ) + return InputEvent(input=self.memory.get()) + self.memory.put(response.tool_call_message) + match response.tool_name(): + case self.extractor_tool.metadata.name: + return ExtractMissingCellsEvent(tool_calls=response.tool_calls) + case self.query_engine_tool.metadata.name: + return FindAnswersEvent(tool_calls=response.tool_calls) + case self.filling_tool.metadata.name: + return FillEvent(tool_calls=response.tool_calls) + case _: + raise ValueError(f"Unknown tool: {response.tool_name()}") + + @step() + async def extract_missing_cells( + self, ctx: Context, ev: ExtractMissingCellsEvent + ) -> InputEvent | FindAnswersEvent: + """ + Extract missing cells in a CSV file and generate questions to fill them. 
+ """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Extractor", + msg="Extracting missing cells", + ) + ) + # Call the extract questions tool + tool_messages = await call_tools( + agent_name="Extractor", + tools=[self.extractor_tool], + ctx=ctx, + tool_calls=ev.tool_calls, + ) + self.memory.put_messages(tool_messages) + return InputEvent(input=self.memory.get()) + + @step() + async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent: + """ + Call answer questions tool to query for the answers to the questions. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Researcher", + msg="Finding answers for missing cells", + ) + ) + tool_messages = await call_tools( + ctx=ctx, + agent_name="Researcher", + tools=[self.query_engine_tool], + tool_calls=ev.tool_calls, + ) + self.memory.put_messages(tool_messages) + return InputEvent(input=self.memory.get()) + + @step() + async def fill_cells(self, ctx: Context, ev: FillEvent) -> InputEvent: + """ + Call fill cells tool to fill the missing cells with the answers. + """ + ctx.write_event_to_stream( + AgentRunEvent( + name="Processor", + msg="Filling missing cells", + ) + ) + tool_messages = await call_tools( + agent_name="Processor", + tools=[self.filling_tool], + ctx=ctx, + tool_calls=ev.tool_calls, + ) + self.memory.put_messages(tool_messages) + return InputEvent(input=self.memory.get()) diff --git a/templates/components/multiagent/python/app/api/routers/chat.py b/templates/components/multiagent/python/app/api/routers/chat.py index 9b9b6c9d190996442aa0283d2e4f079e0ace4c50..747e558968c165b2e9de95113f90c5e77617f9fa 100644 --- a/templates/components/multiagent/python/app/api/routers/chat.py +++ b/templates/components/multiagent/python/app/api/routers/chat.py @@ -4,7 +4,8 @@ from app.api.routers.models import ( ChatData, ) from app.api.routers.vercel_response import VercelStreamResponse -from app.engine.engine import get_chat_engine +from app.engine.query_filter import generate_filters +from app.workflows import create_workflow from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status chat_router = r = APIRouter() @@ -22,19 +23,20 @@ async def chat( last_message_content = data.get_last_message_content() messages = data.get_history_messages(include_agent_messages=True) - # The chat API supports passing private document filters and chat params - # but agent workflow does not support them yet - # ignore chat params and use all documents for now - # TODO: generate filters based on doc_ids + doc_ids = data.get_chat_document_ids() + filters = generate_filters(doc_ids) params = data.data or {} - engine = get_chat_engine(chat_history=messages, params=params) - event_handler = engine.run(input=last_message_content, streaming=True) + workflow = create_workflow( + chat_history=messages, params=params, filters=filters + ) + + event_handler = workflow.run(input=last_message_content, streaming=True) return VercelStreamResponse( request=request, chat_data=data, event_handler=event_handler, - events=engine.stream_events(), + events=workflow.stream_events(), ) except Exception as e: logger.exception("Error in chat engine", exc_info=True) diff --git a/templates/components/multiagent/python/app/workflows/events.py b/templates/components/multiagent/python/app/workflows/events.py new file mode 100644 index 0000000000000000000000000000000000000000..f40e9e1abe29bd1425bec248ec2ba29ef9e88b73 --- /dev/null +++ b/templates/components/multiagent/python/app/workflows/events.py @@ -0,0 +1,27 @@ +from enum import Enum +from 
+from typing import Optional
+
+from llama_index.core.workflow import Event
+
+
+class AgentRunEventType(Enum):
+    TEXT = "text"
+    PROGRESS = "progress"
+
+
+class AgentRunEvent(Event):
+    name: str
+    msg: str
+    event_type: AgentRunEventType = AgentRunEventType.TEXT
+    data: Optional[dict] = None
+
+    def to_response(self) -> dict:
+        return {
+            "type": "agent",
+            "data": {
+                "agent": self.name,
+                "type": self.event_type.value,
+                "text": self.msg,
+                "data": self.data,
+            },
+        }
diff --git a/templates/components/multiagent/python/app/workflows/function_calling_agent.py b/templates/components/multiagent/python/app/workflows/function_calling_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..452fc5e7b668d7377cdd1f11a8a05f662e9ed1ed
--- /dev/null
+++ b/templates/components/multiagent/python/app/workflows/function_calling_agent.py
@@ -0,0 +1,121 @@
+from typing import Any, List, Optional
+
+from app.workflows.events import AgentRunEvent
+from app.workflows.tools import ChatWithToolsResponse, call_tools, chat_with_tools
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.settings import Settings
+from llama_index.core.tools.types import BaseTool
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
+
+
+class InputEvent(Event):
+    input: list[ChatMessage]
+
+
+class ToolCallEvent(Event):
+    input: ChatWithToolsResponse
+
+
+class FunctionCallingAgent(Workflow):
+    """
+    A simple workflow that answers a request by letting the LLM call tools on its own.
+    Previous chat history can be passed in to give the LLM context.
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        llm: FunctionCallingLLM | None = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        tools: List[BaseTool] | None = None,
+        system_prompt: str | None = None,
+        verbose: bool = False,
+        timeout: float = 360.0,
+        name: str,
+        write_events: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs)  # type: ignore
+        self.tools = tools or []
+        self.name = name
+        self.write_events = write_events
+
+        if llm is None:
+            llm = Settings.llm
+        self.llm = llm
+        if not self.llm.metadata.is_function_calling_model:
+            raise ValueError("The provided LLM must support function calling.")
+
+        self.system_prompt = system_prompt
+
+        self.memory = ChatMemoryBuffer.from_defaults(
+            llm=self.llm, chat_history=chat_history
+        )
+        self.sources = []  # type: ignore
+
+    @step()
+    async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
+        # clear sources
+        self.sources = []
+
+        # set streaming
+        ctx.data["streaming"] = getattr(ev, "streaming", False)
+
+        # set system prompt
+        if self.system_prompt is not None:
+            system_msg = ChatMessage(role="system", content=self.system_prompt)
+            self.memory.put(system_msg)
+
+        # get user input
+        user_input = ev.input
+        user_msg = ChatMessage(role="user", content=user_input)
+        self.memory.put(user_msg)
+
+        if self.write_events:
+            ctx.write_event_to_stream(
+                AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
+            )
+
+        return InputEvent(input=self.memory.get())
+
+    @step()
+    async def handle_llm_input(
+        self,
+        ctx: Context,
+        ev: InputEvent,
+    ) -> ToolCallEvent | StopEvent:
+        chat_history = ev.input
+
+        response = await chat_with_tools(
+            self.llm,
+            self.tools,
+            chat_history,
+        )
+        is_tool_call = response.has_tool_calls()
+        if not is_tool_call:
+            if ctx.data["streaming"]:
+                return StopEvent(result=response)
+            else:
+                full_response = ""
+                async for chunk in response.generator:
+                    full_response += chunk.message.content
+                return StopEvent(result=full_response)
+        return ToolCallEvent(input=response)
+
+    @step()
+    async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent:
+        tool_calls = ev.input.tool_calls
+        tool_call_message = ev.input.tool_call_message
+        self.memory.put(tool_call_message)
+        tool_messages = await call_tools(ctx, self.name, self.tools, tool_calls)
+        self.memory.put_messages(tool_messages)
+        return InputEvent(input=self.memory.get())
diff --git a/templates/components/multiagent/python/app/workflows/tools.py b/templates/components/multiagent/python/app/workflows/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..b42c31610abe0542083b169cc746d5ac4dabc2ca
--- /dev/null
+++ b/templates/components/multiagent/python/app/workflows/tools.py
@@ -0,0 +1,237 @@
+import logging
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, AsyncGenerator, Callable, Optional
+
+from app.workflows.events import AgentRunEvent, AgentRunEventType
+from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
+from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.tools import (
+    BaseTool,
+    FunctionTool,
+    ToolOutput,
+    ToolSelection,
+)
+from llama_index.core.workflow import Context
+from pydantic import BaseModel, ConfigDict
+
+logger = logging.getLogger("uvicorn")
+
+
+class ContextAwareTool(FunctionTool, ABC):
+    @abstractmethod
+    async def acall(self, ctx: Context, input: Any) -> ToolOutput:  # type: ignore
+        pass
+
+
+class ResponseGenerator(BaseModel):
+    """
+    A response generator from chat_with_tools.
+    """
+
+    generator: AsyncGenerator[ChatResponse | None, None]
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class ChatWithToolsResponse(BaseModel):
+    """
+    A tool call response from chat_with_tools.
+    """
+
+    tool_calls: Optional[list[ToolSelection]]
+    tool_call_message: Optional[ChatMessage]
+    generator: Optional[AsyncGenerator[ChatResponse | None, None]]
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def is_calling_different_tools(self) -> bool:
+        tool_names = {tool_call.tool_name for tool_call in self.tool_calls}
+        return len(tool_names) > 1
+
+    def has_tool_calls(self) -> bool:
+        return self.tool_calls is not None and len(self.tool_calls) > 0
+
+    def tool_name(self) -> str:
+        assert self.has_tool_calls()
+        assert not self.is_calling_different_tools()
+        return self.tool_calls[0].tool_name
+
+    async def full_response(self) -> str:
+        assert self.generator is not None
+        full_response = ""
+        async for chunk in self.generator:
+            full_response += chunk.message.content
+        return full_response
+
+
+async def chat_with_tools(  # type: ignore
+    llm: FunctionCallingLLM,
+    tools: list[BaseTool],
+    chat_history: list[ChatMessage],
+) -> ChatWithToolsResponse:
+    """
+    Request the LLM to decide whether to call tools.
+    This function does not change the memory.
+ """ + generator = _tool_call_generator(llm, tools, chat_history) + is_tool_call = await generator.__anext__() + if is_tool_call: + # Last chunk is the full response + # Wait for the last chunk + full_response = None + async for chunk in generator: + full_response = chunk + assert isinstance(full_response, ChatResponse) + return ChatWithToolsResponse( + tool_calls=llm.get_tool_calls_from_response(full_response), + tool_call_message=full_response.message, + generator=None, + ) + else: + return ChatWithToolsResponse( + tool_calls=None, + tool_call_message=None, + generator=generator, + ) + + +async def call_tools( + ctx: Context, + agent_name: str, + tools: list[BaseTool], + tool_calls: list[ToolSelection], + emit_agent_events: bool = True, +) -> list[ChatMessage]: + if len(tool_calls) == 0: + return [] + + tools_by_name = {tool.metadata.get_name(): tool for tool in tools} + if len(tool_calls) == 1: + return [ + await call_tool( + ctx, + tools_by_name[tool_calls[0].tool_name], + tool_calls[0], + lambda msg: ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=msg, + ) + ), + ) + ] + # Multiple tool calls, show progress + tool_msgs: list[ChatMessage] = [] + + progress_id = str(uuid.uuid4()) + total_steps = len(tool_calls) + if emit_agent_events: + ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=f"Making {total_steps} tool calls", + ) + ) + for i, tool_call in enumerate(tool_calls): + tool = tools_by_name.get(tool_call.tool_name) + if not tool: + tool_msgs.append( + ChatMessage( + role=MessageRole.ASSISTANT, + content=f"Tool {tool_call.tool_name} does not exist", + ) + ) + continue + tool_msg = await call_tool( + ctx, + tool, + tool_call, + event_emitter=lambda msg: ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=msg, + event_type=AgentRunEventType.PROGRESS, + data={ + "id": progress_id, + "total": total_steps, + "current": i, + }, + ) + ), + ) + tool_msgs.append(tool_msg) + return tool_msgs + + +async def call_tool( + ctx: Context, + tool: BaseTool, + tool_call: ToolSelection, + event_emitter: Optional[Callable[[str], None]], +) -> ChatMessage: + if event_emitter: + event_emitter( + f"Calling tool {tool_call.tool_name}, {str(tool_call.tool_kwargs)}" + ) + try: + if isinstance(tool, ContextAwareTool): + if ctx is None: + raise ValueError("Context is required for context aware tool") + # inject context for calling an context aware tool + response = await tool.acall(ctx=ctx, **tool_call.tool_kwargs) + else: + response = await tool.acall(**tool_call.tool_kwargs) # type: ignore + return ChatMessage( + role=MessageRole.TOOL, + content=str(response.raw_output), + additional_kwargs={ + "tool_call_id": tool_call.tool_id, + "name": tool.metadata.get_name(), + }, + ) + except Exception as e: + logger.error(f"Got error in tool {tool_call.tool_name}: {str(e)}") + if event_emitter: + event_emitter(f"Got error in tool {tool_call.tool_name}: {str(e)}") + return ChatMessage( + role=MessageRole.TOOL, + content=f"Error: {str(e)}", + additional_kwargs={ + "tool_call_id": tool_call.tool_id, + "name": tool.metadata.get_name(), + }, + ) + + +async def _tool_call_generator( + llm: FunctionCallingLLM, + tools: list[BaseTool], + chat_history: list[ChatMessage], +) -> AsyncGenerator[ChatResponse | bool, None]: + response_stream = await llm.astream_chat_with_tools( + tools, + chat_history=chat_history, + allow_parallel_tool_calls=False, + ) + + full_response = None + yielded_indicator = False + async for chunk in response_stream: + if "tool_calls" not in 
chunk.message.additional_kwargs: + # Yield a boolean to indicate whether the response is a tool call + if not yielded_indicator: + yield False + yielded_indicator = True + + # if not a tool call, yield the chunks! + yield chunk # type: ignore + elif not yielded_indicator: + # Yield the indicator for a tool call + yield True + yielded_indicator = True + + full_response = chunk + + if full_response: + yield full_response # type: ignore
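
For the progress protocol in `call_tools` above: when an agent issues several tool calls in one batch, each call is streamed as an `AgentRunEvent` with `event_type=PROGRESS`, and the `data` payload ties the updates together via a shared id. The sketch below shows the shape of `to_response()` for one such event; the agent name, tool name, and id are illustrative, not values produced by this patch.

```python
# Illustrative to_response() payload for the second of three tool calls
# made by a "Researcher" agent (names and id are made up for the example).
{
    "type": "agent",
    "data": {
        "agent": "Researcher",
        "type": "progress",  # AgentRunEventType.PROGRESS.value
        "text": "Calling tool query_index, {'input': '...'}",
        "data": {
            "id": "6f3c1c2e-...",  # same uuid for every update in this batch
            "total": 3,            # number of tool calls in the batch
            "current": 1,          # zero-based index of the call in progress
        },
    },
}
```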
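For reference, a minimal sketch of how the refactored pieces are meant to be driven outside the FastAPI router. This is not part of the patch: the calculator tool, agent name, and question are illustrative, it assumes `Settings.llm` has already been configured with a function-calling model, and the `run`/`stream_events` calls simply mirror how `chat.py` drives the workflow above (exact semantics depend on the pinned `llama-index-core` version).

```python
import asyncio

from llama_index.core.tools import FunctionTool

# Modules introduced by this patch
from app.workflows.events import AgentRunEvent
from app.workflows.function_calling_agent import FunctionCallingAgent


def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    return a * b


async def main() -> None:
    # Assumes Settings.llm was configured elsewhere with a function-calling model.
    agent = FunctionCallingAgent(
        name="Calculator",
        tools=[FunctionTool.from_defaults(multiply)],
        system_prompt="You are a calculator. Use the multiply tool for arithmetic.",
    )

    # Same calling convention the chat router uses above.
    handler = agent.run(input="What is 21 * 2?", streaming=False)

    # AgentRunEvent instances carry the per-step progress that the UI renders;
    # to_response() emits the renamed "agent"/"text" fields.
    async for event in agent.stream_events():
        if isinstance(event, AgentRunEvent):
            print(event.to_response())

    # Awaiting the handler yields the final result (here, the full text response).
    print(await handler)


if __name__ == "__main__":
    asyncio.run(main())
```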