Unverified Commit 6edea6af authored by Huu Le, committed by GitHub

enhance workflow code for Python (#412)


* enhance workflow shared code

* fix streaming

* refactor code

* add missing helper

* update

* update form filling

* add filters

* simplify the code

* simplify the code

* simplify the code

* update form filling

* update e2e

* update function calling agent

* fix unneeded condition

* Create light-parrots-work.md

* revert change on using functioncallingagent

* update readme

* clean code

* extract call one tool function

* update for blog use case

* fix streaming

* fix e2e

* fix missing await

* improve tools code

* improve assertion code

* skip form filling test for TS framework

* update for tools helper

---------

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
parent d79d1652
Showing 574 additions and 408 deletions
---
"create-llama": patch
---
Optimize generated workflow code for Python
......@@ -18,7 +18,7 @@ const templateUI: TemplateUI = "shadcn";
const templatePostInstallAction: TemplatePostInstallAction = "runApp";
const appType: AppType = templateFramework === "nextjs" ? "" : "--frontend";
const userMessage = "Write a blog post about physical standards for letters";
const templateAgents = ["financial_report", "blog"];
const templateAgents = ["financial_report", "blog", "form_filling"];
for (const agents of templateAgents) {
test.describe(`Test multiagent template ${agents} ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
......@@ -26,6 +26,10 @@ for (const agents of templateAgents) {
process.platform !== "linux" || process.env.DATASOURCE === "--no-files",
"The multiagent template currently only works with files. We also only run on Linux to speed up tests.",
);
test.skip(
agents === "form_filling" && templateFramework !== "fastapi",
"Form filling is currently only supported with FastAPI.",
);
let port: number;
let externalPort: number;
let cwd: string;
......@@ -68,6 +72,10 @@ for (const agents of templateAgents) {
test("Frontend should be able to submit a message and receive the start of a streamed response", async ({
page,
}) => {
test.skip(
agents === "financial_report" || agents === "form_filling",
"Skip chat tests for financial report and form filling.",
);
await page.goto(`http://localhost:${port}`);
await page.fill("form textarea", userMessage);
......
......@@ -8,9 +8,9 @@ This example is using three agents to generate a blog post:
There are three different ways the agents can interact to reach their goal:
1. [Choreography](./app/examples/choreography.py) - the agents decide themselves to delegate a task to another agent
1. [Orchestrator](./app/examples/orchestrator.py) - a central orchestrator decides which agent should execute a task
1. [Explicit Workflow](./app/examples/workflow.py) - a pre-defined workflow specific for the task is used to execute the tasks
1. [Choreography](./app/agents/choreography.py) - the agents decide themselves to delegate a task to another agent
1. [Orchestrator](./app/agents/orchestrator.py) - a central orchestrator decides which agent should execute a task
1. [Explicit Workflow](./app/agents/workflow.py) - a pre-defined workflow specific for the task is used to execute the tasks
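To try a different pattern, switch the `EXAMPLE_TYPE` environment variable that the engine factory reads before building the workflow (the factory is shown further down in this diff; the module path `app/engine/engine.py` is an assumption). A minimal, illustrative sketch:

```python
# Illustrative only: pick an agent pattern before creating the workflow.
# Any value other than "choreography" or "orchestrator" falls back to the explicit blog workflow.
import os

from app.engine.engine import create_workflow  # path assumed from the diff below

os.environ["EXAMPLE_TYPE"] = "orchestrator"
workflow = create_workflow(chat_history=[])
```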
## Getting Started
......
from .blog import create_workflow
__all__ = ["create_workflow"]
......@@ -4,17 +4,18 @@ from typing import List, Optional
from app.agents.choreography import create_choreography
from app.agents.orchestrator import create_orchestrator
from app.agents.workflow import create_workflow
from app.agents.workflow import create_workflow as create_blog_workflow
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.workflow import Workflow
logger = logging.getLogger("uvicorn")
def get_chat_engine(
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None, **kwargs
) -> Workflow:
# TODO: the EXAMPLE_TYPE could be passed as a chat config parameter?
# Chat filters are not supported yet
kwargs.pop("filters", None)
agent_type = os.getenv("EXAMPLE_TYPE", "").lower()
match agent_type:
case "choreography":
......@@ -22,7 +23,7 @@ def get_chat_engine(
case "orchestrator":
agent = create_orchestrator(chat_history, **kwargs)
case _:
agent = create_workflow(chat_history, **kwargs)
agent = create_blog_workflow(chat_history, **kwargs)
logger.info(f"Using agent pattern: {agent_type}")
......
......@@ -42,9 +42,9 @@ class AgentRunEvent(Event):
return {
"type": "agent",
"data": {
"name": self.name,
"agent": self.name,
"type": self.event_type.value,
"msg": self.msg,
"text": self.msg,
"data": self.data,
},
}
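With the renamed keys, a default (`TEXT`) event now serializes as shown below; a small illustration grounded in the fields above (the values are made up):

```python
# Illustrative only: AgentRunEvent payload after renaming "name" -> "agent" and "msg" -> "text".
from app.workflows.events import AgentRunEvent

event = AgentRunEvent(name="Researcher", msg="Starting research")
print(event.to_response())
# {"type": "agent", "data": {"agent": "Researcher", "type": "text", "text": "Starting research", "data": None}}
```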
......
......@@ -33,7 +33,7 @@ curl --location 'localhost:8000/api/chat' \
--data '{ "messages": [{ "role": "user", "content": "Create a report comparing the finances of Apple and Tesla" }] }'
```
You can start editing the API by modifying `app/api/routers/chat.py` or `app/financial_report/workflow.py`. The API auto-updates as you save the files.
You can start editing the API by modifying `app/api/routers/chat.py` or `app/workflows/financial_report.py`. The API auto-updates as you save the files.
Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
......
from textwrap import dedent
from typing import List, Tuple
from app.engine.tools import ToolFactory
from app.workflows.single import FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.tools import FunctionTool
def _get_analyst_params() -> Tuple[List[type[FunctionTool]], str, str]:
tools = []
prompt_instructions = dedent(
"""
You are an expert in analyzing financial data.
You are given a task and a set of financial data to analyze. Your task is to analyze the financial data and return a report.
Your response should include a detailed analysis of the financial data, including any trends, patterns, or insights that you find.
Present the analysis in a textual format; including tables is great!
You don't need to synthesize the data; just analyze it and provide your findings.
Always use the provided information; don't make up any information yourself.
"""
)
description = "Expert in analyzing financial data"
configured_tools = ToolFactory.from_env(map_result=True)
# Check if the interpreter tool is configured
if "interpret" in configured_tools.keys():
tools.append(configured_tools["interpret"])
prompt_instructions += dedent("""
You are able to visualize the financial data using the code interpreter tool.
It's very useful to create and include visualizations in the report (make sure you include the right code and data for the visualization).
Never include any code in the report, just the visualization.
""")
description += (
", able to visualize the financial data using code interpreter tool."
)
return tools, prompt_instructions, description
def create_analyst(chat_history: List[ChatMessage]):
tools, prompt_instructions, description = _get_analyst_params()
return FunctionCallingAgent(
name="analyst",
tools=tools,
description=description,
system_prompt=dedent(prompt_instructions),
chat_history=chat_history,
)
from textwrap import dedent
from typing import List, Tuple
from app.engine.tools import ToolFactory
from app.workflows.single import FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.tools import BaseTool
def _get_reporter_params(
chat_history: List[ChatMessage],
) -> Tuple[List[type[BaseTool]], str, str]:
tools: List[type[BaseTool]] = []
description = "Expert in representing a financial report"
prompt_instructions = dedent(
"""
You are a report generation assistant tasked with producing a well-formatted report given parsed context.
Given a comprehensive analysis of the user request, your task is to synthesize the information and return a well-formatted report.
## Instructions
You are responsible for representing the analysis in a well-formatted report. If tables or visualizations are provided, add them to the most relevant sections.
Use only the provided information to create the report. Do not make up any information yourself.
Finally, the report should be presented in markdown format.
"""
)
configured_tools = ToolFactory.from_env(map_result=True)
if "generate_document" in configured_tools: # type: ignore
tools.append(configured_tools["generate_document"]) # type: ignore
prompt_instructions += (
"\nYou are also able to generate a file document (PDF/HTML) of the report."
)
description += " and generate a file document (PDF/HTML) of the report."
return tools, description, prompt_instructions
def create_reporter(chat_history: List[ChatMessage]):
tools, description, prompt_instructions = _get_reporter_params(chat_history)
return FunctionCallingAgent(
name="reporter",
tools=tools,
description=description,
system_prompt=prompt_instructions,
chat_history=chat_history,
)
import os
from textwrap import dedent
from typing import List, Optional
from app.engine.index import IndexConfig, get_index
from app.workflows.single import FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.tools import BaseTool, QueryEngineTool, ToolMetadata
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
def _create_query_engine_tools(params=None) -> Optional[list[type[BaseTool]]]:
"""
Provide an agent worker that can be used to query the index.
"""
# Add query tool if index exists
index_config = IndexConfig(**(params or {}))
index = get_index(index_config)
if index is None:
return None
top_k = int(os.getenv("TOP_K", 5))
# Construct query engine tools
tools = []
# If index is LlamaCloudIndex, we need to add chunk and doc retriever tools
if isinstance(index, LlamaCloudIndex):
# Document retriever
doc_retriever = index.as_query_engine(
retriever_mode="files_via_content",
similarity_top_k=top_k,
)
chunk_retriever = index.as_query_engine(
retriever_mode="chunks",
similarity_top_k=top_k,
)
tools.append(
QueryEngineTool(
query_engine=doc_retriever,
metadata=ToolMetadata(
name="document_retriever",
description=dedent(
"""
Document retriever that retrieves entire documents from the corpus.
ONLY use for research questions that may require searching over entire research reports.
Will be slower and more expensive than chunk-level retrieval but may be necessary.
"""
),
),
)
)
tools.append(
QueryEngineTool(
query_engine=chunk_retriever,
metadata=ToolMetadata(
name="chunk_retriever",
description=dedent(
"""
Retrieves a small set of relevant document chunks from the corpus.
Use for research questions that want to look up specific facts from the knowledge corpus,
and don't need entire documents.
"""
),
),
)
)
else:
query_engine = index.as_query_engine(
**({"similarity_top_k": top_k} if top_k != 0 else {})
)
tools.append(
QueryEngineTool(
query_engine=query_engine,
metadata=ToolMetadata(
name="retrieve_information",
description="Use this tool to retrieve information about the text corpus from the index.",
),
)
)
return tools
def create_researcher(chat_history: List[ChatMessage], **kwargs):
"""
Researcher is an agent that takes responsibility for using tools to complete a given task.
"""
tools = _create_query_engine_tools(**kwargs)
if tools is None:
raise ValueError("No tools found for researcher agent")
return FunctionCallingAgent(
name="researcher",
tools=tools,
description="expert in retrieving any unknown content from the corpus",
system_prompt=dedent(
"""
You are a researcher agent. You are responsible for retrieving information from the corpus.
## Instructions
+ Don't synthesize the information, just return the whole retrieved information.
+ If the information is already provided in the chat history, don't retrieve it again; just respond with: "There is no new information, please reuse the information from the conversation."
"""
),
chat_history=chat_history,
)
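Which query tools the researcher receives depends on the index type; a hedged sketch for inspecting them (it assumes documents were already ingested so `get_index` returns an index, and it calls the private helper purely for illustration):

```python
# Illustrative only: list the query tools the researcher will be given.
# With a LlamaCloudIndex this prints "document_retriever" and "chunk_retriever";
# with a local VectorStoreIndex it prints a single "retrieve_information" tool.
from app.agents.researcher import _create_query_engine_tools

tools = _create_query_engine_tools(params={}) or []
for tool in tools:
    print(tool.metadata.name)
```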
from textwrap import dedent
from typing import AsyncGenerator, List, Optional
from app.agents.analyst import create_analyst
from app.agents.reporter import create_reporter
from app.agents.researcher import create_researcher
from app.workflows.single import AgentRunEvent, AgentRunResult, FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.prompts import PromptTemplate
from llama_index.core.settings import Settings
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
def create_workflow(chat_history: Optional[List[ChatMessage]] = None, **kwargs):
researcher = create_researcher(
chat_history=chat_history,
**kwargs,
)
analyst = create_analyst(chat_history=chat_history)
reporter = create_reporter(chat_history=chat_history)
workflow = FinancialReportWorkflow(timeout=360, chat_history=chat_history)
workflow.add_workflows(
researcher=researcher,
analyst=analyst,
reporter=reporter,
)
return workflow
class ResearchEvent(Event):
input: str
class AnalyzeEvent(Event):
input: str
class ReportEvent(Event):
input: str
class FinancialReportWorkflow(Workflow):
def __init__(
self, timeout: int = 360, chat_history: Optional[List[ChatMessage]] = None
):
super().__init__(timeout=timeout)
self.chat_history = chat_history or []
@step()
async def start(self, ctx: Context, ev: StartEvent) -> ResearchEvent | ReportEvent:
# set streaming
ctx.data["streaming"] = getattr(ev, "streaming", False)
# start the workflow with researching about a topic
ctx.data["task"] = ev.input
ctx.data["user_input"] = ev.input
# Decision-making process
decision = await self._decide_workflow(ev.input, self.chat_history)
if decision != "publish":
return ResearchEvent(input=f"Research for this task: {ev.input}")
else:
chat_history_str = "\n".join(
[f"{msg.role}: {msg.content}" for msg in self.chat_history]
)
return ReportEvent(
input=f"Create a report based on the chat history\n{chat_history_str}\n\n and task: {ev.input}"
)
async def _decide_workflow(
self, input: str, chat_history: List[ChatMessage]
) -> str:
# TODO: Refactor this by using prompt generation
prompt_template = PromptTemplate(
dedent(
"""
You are an expert in decision-making, helping people create financial reports for the provided data.
If the user doesn't need to add or update anything, respond with 'publish'.
Otherwise, respond with 'research'.
Here is the chat history:
{chat_history}
The current user request is:
{input}
Given the chat history and the new user request, decide whether to create a report based on existing information.
Decision (respond with either 'research' or 'publish'):
"""
)
)
chat_history_str = "\n".join(
[f"{msg.role}: {msg.content}" for msg in chat_history]
)
prompt = prompt_template.format(chat_history=chat_history_str, input=input)
output = await Settings.llm.acomplete(prompt)
decision = output.text.strip().lower()
return "publish" if decision == "publish" else "research"
@step()
async def research(
self, ctx: Context, ev: ResearchEvent, researcher: FunctionCallingAgent
) -> AnalyzeEvent:
result: AgentRunResult = await self.run_agent(ctx, researcher, ev.input)
content = result.response.message.content
return AnalyzeEvent(
input=dedent(
f"""
Given the following research content:
{content}
Provide a comprehensive analysis of the data for the user's request: {ctx.data["task"]}
"""
)
)
@step()
async def analyze(
self, ctx: Context, ev: AnalyzeEvent, analyst: FunctionCallingAgent
) -> ReportEvent | StopEvent:
result: AgentRunResult = await self.run_agent(ctx, analyst, ev.input)
content = result.response.message.content
return ReportEvent(
input=dedent(
f"""
Given the following analysis:
{content}
Create a report for the user's request: {ctx.data["task"]}
"""
)
)
@step()
async def report(
self, ctx: Context, ev: ReportEvent, reporter: FunctionCallingAgent
) -> StopEvent:
try:
result: AgentRunResult = await self.run_agent(
ctx, reporter, ev.input, streaming=ctx.data["streaming"]
)
return StopEvent(result=result)
except Exception as e:
ctx.write_event_to_stream(
AgentRunEvent(
name=reporter.name,
msg=f"Error creating a report: {e}",
)
)
return StopEvent(result=None)
async def run_agent(
self,
ctx: Context,
agent: FunctionCallingAgent,
input: str,
streaming: bool = False,
) -> AgentRunResult | AsyncGenerator:
handler = agent.run(input=input, streaming=streaming)
# bubble all events while running the executor to the planner
async for event in handler.stream_events():
# Don't write the StopEvent from sub task to the stream
if type(event) is not StopEvent:
ctx.write_event_to_stream(event)
return await handler
from typing import List, Optional
from app.agents.workflow import create_workflow
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.workflow import Workflow
def get_chat_engine(
chat_history: Optional[List[ChatMessage]] = None, **kwargs
) -> Workflow:
agent_workflow = create_workflow(chat_history, **kwargs)
return agent_workflow
from .financial_report import create_workflow
__all__ = ["create_workflow"]
import os
from typing import Any, Dict, List, Optional
from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
) -> Workflow:
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
configured_tools: Dict[str, FunctionTool] = ToolFactory.from_env(map_result=True) # type: ignore
code_interpreter_tool = configured_tools.get("interpret")
document_generator_tool = configured_tools.get("generate_document")
return FinancialReportWorkflow(
query_engine_tool=query_engine_tool,
code_interpreter_tool=code_interpreter_tool,
document_generator_tool=document_generator_tool,
chat_history=chat_history,
)
class InputEvent(Event):
input: List[ChatMessage]
response: bool = False
class ResearchEvent(Event):
input: list[ToolSelection]
class AnalyzeEvent(Event):
input: list[ToolSelection] | ChatMessage
class ReportEvent(Event):
input: list[ToolSelection]
class FinancialReportWorkflow(Workflow):
"""
A workflow to generate a financial report using indexed documents.
Requirements:
- Indexed documents containing financial data and a query engine tool to search them
- A code interpreter tool to analyze data and generate reports
- A document generator tool to create report files
Steps:
1. LLM Input: The LLM determines the next step based on function calling.
For example, if the model requests the query engine tool, it returns a ResearchEvent;
if it requests document generation, it returns a ReportEvent.
2. Research: Uses the query engine to find relevant chunks from indexed documents.
After gathering information, it requests analysis (step 3).
3. Analyze: Uses a custom prompt to analyze research results and can call the code
interpreter tool for visualization or calculation. Returns results to the LLM.
4. Report: Uses the document generator tool to create a report. Returns results to the LLM.
"""
_default_system_prompt = """
You are a financial analyst who is given a set of tools to help you.
Use the appropriate tools for the user request and always rely on the information from the tools; don't make up anything yourself.
For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries.
"""
def __init__(
self,
query_engine_tool: QueryEngineTool,
code_interpreter_tool: FunctionTool,
document_generator_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
chat_history: Optional[List[ChatMessage]] = None,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.chat_history = chat_history or []
self.query_engine_tool = query_engine_tool
self.code_interpreter_tool = code_interpreter_tool
self.document_generator_tool = document_generator_tool
assert (
query_engine_tool is not None
), "Query engine tool is not found. Try run generation script or upload a document file first."
assert code_interpreter_tool is not None, "Code interpreter tool is required"
assert (
document_generator_tool is not None
), "Document generator tool is required"
self.tools = [
self.query_engine_tool,
self.code_interpreter_tool,
self.document_generator_tool,
]
self.llm: FunctionCallingLLM = llm or Settings.llm
assert isinstance(self.llm, FunctionCallingLLM)
self.memory = ChatMemoryBuffer.from_defaults(
llm=self.llm, chat_history=self.chat_history
)
@step()
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
ctx.data["input"] = ev.input
if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)
# Add user input to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=ev.input))
return InputEvent(input=self.memory.get())
@step()
async def handle_llm_input( # type: ignore
self,
ctx: Context,
ev: InputEvent,
) -> ResearchEvent | AnalyzeEvent | ReportEvent | StopEvent:
"""
Handle an LLM input and decide the next step.
"""
# Always use the latest chat history from the input
chat_history: list[ChatMessage] = ev.input
# Get tool calls
response = await chat_with_tools(
self.llm,
self.tools, # type: ignore
chat_history,
)
if not response.has_tool_calls():
# If no tool call, return the response generator
return StopEvent(result=response.generator)
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Cannot call different tools at the same time. Try calling one tool at a time.",
)
)
return InputEvent(input=self.memory.get())
self.memory.put(response.tool_call_message)
match response.tool_name():
case self.code_interpreter_tool.metadata.name:
return AnalyzeEvent(input=response.tool_calls)
case self.document_generator_tool.metadata.name:
return ReportEvent(input=response.tool_calls)
case self.query_engine_tool.metadata.name:
return ResearchEvent(input=response.tool_calls)
case _:
raise ValueError(f"Unknown tool: {response.tool_name()}")
@step()
async def research(self, ctx: Context, ev: ResearchEvent) -> AnalyzeEvent:
"""
Do research to gather information for the user's request.
A researcher should have these tools: query engine, search engine, etc.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg="Starting research",
)
)
tool_calls = ev.input
tool_messages = await call_tools(
ctx=ctx,
agent_name="Researcher",
tools=[self.query_engine_tool],
tool_calls=tool_calls,
)
self.memory.put_messages(tool_messages)
return AnalyzeEvent(
input=ChatMessage(
role=MessageRole.ASSISTANT,
content="I've finished the research. Please analyze the result.",
),
)
@step()
async def analyze(self, ctx: Context, ev: AnalyzeEvent) -> InputEvent:
"""
Analyze the research result.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Analyst",
msg="Starting analysis",
)
)
event_requested_by_workflow_llm = isinstance(ev.input, list)
# Requested by the workflow LLM Input step, it's a tool call
if event_requested_by_workflow_llm:
# Set the tool calls
tool_calls = ev.input
else:
# Otherwise, it's triggered by the research step
# Use a custom prompt and independent memory for the analyst agent
analysis_prompt = """
You are a financial analyst; you are given a research result and a set of tools to help you.
Always use the given information; don't make up anything yourself. If there is not enough information, you can ask for more.
If you have enough numerical information, it's good to include some charts/visualizations in the report, so use the code interpreter tool to generate them.
"""
# This is handled by analyst agent
# Clone the shared memory to avoid conflicting with the workflow.
chat_history = self.memory.get()
chat_history.append(
ChatMessage(role=MessageRole.SYSTEM, content=analysis_prompt)
)
chat_history.append(ev.input) # type: ignore
# Check if the analyst agent needs to call tools
response = await chat_with_tools(
self.llm,
[self.code_interpreter_tool],
chat_history,
)
if not response.has_tool_calls():
# If no tool call, put the analyst message into memory and return to the workflow input step
analyst_msg = ChatMessage(
role=MessageRole.ASSISTANT,
content=await response.full_response(),
)
self.memory.put(analyst_msg)
return InputEvent(input=self.memory.get())
else:
# Set the tool calls and the tool call message to the memory
tool_calls = response.tool_calls
self.memory.put(response.tool_call_message)
# Call tools
tool_messages = await call_tools(
ctx=ctx,
agent_name="Analyst",
tools=[self.code_interpreter_tool],
tool_calls=tool_calls, # type: ignore
)
self.memory.put_messages(tool_messages)
# Fallback to the input with the latest chat history
return InputEvent(input=self.memory.get())
@step()
async def report(self, ctx: Context, ev: ReportEvent) -> InputEvent:
"""
Generate a report based on the analysis result.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Reporter",
msg="Starting report generation",
)
)
tool_calls = ev.input
tool_messages = await call_tools(
ctx=ctx,
agent_name="Reporter",
tools=[self.document_generator_tool],
tool_calls=tool_calls,
)
self.memory.put_messages(tool_messages)
# After the tool calls, fallback to the input with the latest chat history
return InputEvent(input=self.memory.get())
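Put together, the workflow above is driven like any other LlamaIndex workflow; a hedged usage sketch (it assumes an ingested index plus configured `interpret` and `generate_document` tools, otherwise `create_workflow` fails its assertions):

```python
# Illustrative only: run the financial report workflow and stream its events.
import asyncio

from app.workflows.financial_report import create_workflow


async def main():
    workflow = create_workflow(chat_history=[], params={})
    handler = workflow.run(input="Create a report comparing the finances of Apple and Tesla")
    async for event in handler.stream_events():
        print(event)  # AgentRunEvent progress from the research/analyze/report steps
    result = await handler  # final StopEvent result (a response generator when no tool is called)
    print(result)


asyncio.run(main())
```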
......@@ -39,7 +39,7 @@ curl --location 'localhost:8000/api/chat' \
--data '{ "messages": [{ "role": "user", "content": "What can you do?" }] }'
```
You can start editing the API by modifying `app/api/routers/chat.py` or `app/agents/form_filling.py`. The API auto-updates as you save the files.
You can start editing the API by modifying `app/api/routers/chat.py` or `app/workflows/form_filling.py`. The API auto-updates as you save the files.
Open [http://localhost:8000/docs](http://localhost:8000/docs) with your browser to see the Swagger UI of the API.
......
from typing import List, Optional
from app.agents.form_filling import create_workflow
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.workflow import Workflow
def get_chat_engine(
chat_history: Optional[List[ChatMessage]] = None, **kwargs
) -> Workflow:
return create_workflow(chat_history=chat_history, **kwargs)
from .form_filling import create_workflow
__all__ = ["create_workflow"]
import os
import uuid
from enum import Enum
from typing import AsyncGenerator, List, Optional
from typing import Any, Dict, List, Optional
from app.engine.index import get_index
from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.form_filling import CellValue, MissingCell
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
from llama_index.core.tools.types import ToolOutput
from llama_index.core.workflow import (
Context,
Event,
......@@ -21,31 +22,34 @@ from llama_index.core.workflow import (
Workflow,
step,
)
from pydantic import Field
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None, **kwargs
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
) -> Workflow:
index: VectorStoreIndex = get_index()
if params is None:
params = {}
if filters is None:
filters = []
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k)
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
configured_tools = ToolFactory.from_env(map_result=True)
extractor_tool = configured_tools.get("extract_questions")
filling_tool = configured_tools.get("fill_form")
if extractor_tool is None or filling_tool is None:
raise ValueError("Extractor or filling tool is not found!")
extractor_tool = configured_tools.get("extract_questions") # type: ignore
filling_tool = configured_tools.get("fill_form") # type: ignore
workflow = FormFillingWorkflow(
query_engine_tool=query_engine_tool,
extractor_tool=extractor_tool,
filling_tool=filling_tool,
extractor_tool=extractor_tool, # type: ignore
filling_tool=filling_tool, # type: ignore
chat_history=chat_history,
)
......@@ -58,38 +62,15 @@ class InputEvent(Event):
class ExtractMissingCellsEvent(Event):
tool_call: ToolSelection
tool_calls: list[ToolSelection]
class FindAnswersEvent(Event):
missing_cells: list[MissingCell]
tool_calls: list[ToolSelection]
class FillEvent(Event):
tool_call: ToolSelection
class AgentRunEventType(Enum):
TEXT = "text"
PROGRESS = "progress"
class AgentRunEvent(Event):
name: str
msg: str
event_type: AgentRunEventType = Field(default=AgentRunEventType.TEXT)
data: Optional[dict] = None
def to_response(self) -> dict:
return {
"type": "agent",
"data": {
"agent": self.name,
"type": self.event_type.value,
"text": self.msg,
"data": self.data,
},
}
tool_calls: list[ToolSelection]
class FormFillingWorkflow(Workflow):
......@@ -108,12 +89,14 @@ class FormFillingWorkflow(Workflow):
_default_system_prompt = """
You are a helpful assistant who helps fill missing cells in a CSV file.
Only use provided data, never make up any information yourself. Fill N/A if the answer is not found.
Only extract missing cells from CSV files.
Only use provided data - never make up any information yourself. Fill N/A if an answer is not found.
If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base.
"""
def __init__(
self,
query_engine_tool: QueryEngineTool,
query_engine_tool: Optional[QueryEngineTool],
extractor_tool: FunctionTool,
filling_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
......@@ -127,6 +110,11 @@ class FormFillingWorkflow(Workflow):
self.query_engine_tool = query_engine_tool
self.extractor_tool = extractor_tool
self.filling_tool = filling_tool
if self.extractor_tool is None or self.filling_tool is None:
raise ValueError("Extractor and filling tools are required.")
self.tools = [self.extractor_tool, self.filling_tool]
if self.query_engine_tool is not None:
self.tools.append(self.query_engine_tool) # type: ignore
self.llm: FunctionCallingLLM = llm or Settings.llm
if not isinstance(self.llm, FunctionCallingLLM):
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
......@@ -136,7 +124,6 @@ class FormFillingWorkflow(Workflow):
@step()
async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
ctx.data["streaming"] = getattr(ev, "streaming", False)
ctx.data["input"] = ev.input
if self.system_prompt:
......@@ -152,7 +139,7 @@ class FormFillingWorkflow(Workflow):
chat_history = self.memory.get()
return InputEvent(input=chat_history)
@step(pass_context=True)
@step()
async def handle_llm_input( # type: ignore
self,
ctx: Context,
......@@ -162,22 +149,33 @@ class FormFillingWorkflow(Workflow):
Handle an LLM input and decide the next step.
"""
chat_history: list[ChatMessage] = ev.input
generator = self._tool_call_generator(chat_history)
# Check for immediate tool call
is_tool_call = await generator.__anext__()
if is_tool_call:
full_response = await generator.__anext__()
tool_calls = self.llm.get_tool_calls_from_response(full_response) # type: ignore
for tool_call in tool_calls:
if tool_call.tool_name == self.extractor_tool.metadata.get_name():
ctx.send_event(ExtractMissingCellsEvent(tool_call=tool_call))
elif tool_call.tool_name == self.filling_tool.metadata.get_name():
ctx.send_event(FillEvent(tool_call=tool_call))
else:
# If no tool call, return the generator
return StopEvent(result=generator)
response = await chat_with_tools(
self.llm,
self.tools,
chat_history,
)
if not response.has_tool_calls():
return StopEvent(result=response.generator)
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Cannot call different tools at the same time. Try calling one tool at a time.",
)
)
return InputEvent(input=self.memory.get())
self.memory.put(response.tool_call_message)
match response.tool_name():
case self.extractor_tool.metadata.name:
return ExtractMissingCellsEvent(tool_calls=response.tool_calls)
case self.query_engine_tool.metadata.name:
return FindAnswersEvent(tool_calls=response.tool_calls)
case self.filling_tool.metadata.name:
return FillEvent(tool_calls=response.tool_calls)
case _:
raise ValueError(f"Unknown tool: {response.tool_name()}")
@step()
async def extract_missing_cells(
......@@ -193,38 +191,14 @@ class FormFillingWorkflow(Workflow):
)
)
# Call the extract questions tool
response = self._call_tool(
ctx,
tool_messages = await call_tools(
agent_name="Extractor",
tool=self.extractor_tool,
tool_selection=ev.tool_call,
tools=[self.extractor_tool],
ctx=ctx,
tool_calls=ev.tool_calls,
)
if response.is_error:
return InputEvent(input=self.memory.get())
missing_cells = response.raw_output.get("missing_cells", [])
message = ChatMessage(
role=MessageRole.TOOL,
content=str(missing_cells),
additional_kwargs={
"tool_call_id": ev.tool_call.tool_id,
"name": ev.tool_call.tool_name,
},
)
self.memory.put(message)
if self.query_engine_tool is None:
# Fallback to input that query engine tool is not found so that cannot answer questions
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Extracted missing cells but query engine tool is not found so cannot answer questions. Ask user to upload file or connect to a knowledge base.",
)
)
return InputEvent(input=self.memory.get())
# Forward missing cells information to find answers step
return FindAnswersEvent(missing_cells=missing_cells)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
@step()
async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
......@@ -237,63 +211,13 @@ class FormFillingWorkflow(Workflow):
msg="Finding answers for missing cells",
)
)
missing_cells = ev.missing_cells
# If missing cells information is not found, fallback to other tools
# It means that the extractor tool has not been called yet
# Fallback to input
if missing_cells is None:
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg="Error: Missing cells information not found. Fallback to other tools.",
)
)
message = ChatMessage(
role=MessageRole.TOOL,
content="Error: Missing cells information not found.",
additional_kwargs={
"tool_call_id": ev.tool_call.tool_id,
"name": ev.tool_call.tool_name,
},
)
self.memory.put(message)
return InputEvent(input=self.memory.get())
cell_values: list[CellValue] = []
# Iterate over missing cells and query for the answers
# and stream the progress
progress_id = str(uuid.uuid4())
total_steps = len(missing_cells)
for i, cell in enumerate(missing_cells):
if cell.question_to_answer is None:
continue
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg=f"Querying for: {cell.question_to_answer}",
event_type=AgentRunEventType.PROGRESS,
data={
"id": progress_id,
"total": total_steps,
"current": i,
},
)
)
# Call query engine tool directly
answer = await self.query_engine_tool.acall(query=cell.question_to_answer)
cell_values.append(
CellValue(
row_index=cell.row_index,
column_index=cell.column_index,
value=str(answer),
)
)
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content=str(cell_values),
)
tool_messages = await call_tools(
ctx=ctx,
agent_name="Researcher",
tools=[self.query_engine_tool],
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
@step()
......@@ -307,91 +231,11 @@ class FormFillingWorkflow(Workflow):
msg="Filling missing cells",
)
)
# Call the fill cells tool
result = self._call_tool(
ctx,
tool_messages = await call_tools(
agent_name="Processor",
tool=self.filling_tool,
tool_selection=ev.tool_call,
)
if result.is_error:
return InputEvent(input=self.memory.get())
message = ChatMessage(
role=MessageRole.TOOL,
content=str(result.raw_output),
additional_kwargs={
"tool_call_id": ev.tool_call.tool_id,
"name": ev.tool_call.tool_name,
},
)
self.memory.put(message)
return InputEvent(input=self.memory.get(), response=True)
async def _tool_call_generator(
self, chat_history: list[ChatMessage]
) -> AsyncGenerator[ChatMessage | bool, None]:
response_stream = await self.llm.astream_chat_with_tools(
[self.extractor_tool, self.filling_tool],
chat_history=chat_history,
tools=[self.filling_tool],
ctx=ctx,
tool_calls=ev.tool_calls,
)
full_response = None
yielded_indicator = False
async for chunk in response_stream:
if "tool_calls" not in chunk.message.additional_kwargs:
# Yield a boolean to indicate whether the response is a tool call
if not yielded_indicator:
yield False
yielded_indicator = True
# if not a tool call, yield the chunks!
yield chunk
elif not yielded_indicator:
# Yield the indicator for a tool call
yield True
yielded_indicator = True
full_response = chunk
# Write the full response to memory and yield it
if full_response:
self.memory.put(full_response.message)
yield full_response
def _call_tool(
self,
ctx: Context,
agent_name: str,
tool: FunctionTool,
tool_selection: ToolSelection,
) -> ToolOutput:
"""
Safely call a tool and handle errors.
"""
try:
response: ToolOutput = tool.call(**tool_selection.tool_kwargs)
return response
except Exception as e:
ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=f"Error: {str(e)}",
)
)
message = ChatMessage(
role=MessageRole.TOOL,
content=f"Error: {str(e)}",
additional_kwargs={
"tool_call_id": tool_selection.tool_id,
"name": tool.metadata.get_name(),
},
)
self.memory.put(message)
return ToolOutput(
content=f"Error: {str(e)}",
tool_name=tool.metadata.get_name(),
raw_input=tool_selection.tool_kwargs,
raw_output=None,
is_error=True,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
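The new `app/workflows/tools` helpers (`chat_with_tools`, `call_tools`) that both workflows rely on are not part of this excerpt. The following is only a rough sketch of the interface they would have to satisfy, inferred from the call sites above; the names, fields, and signatures are assumptions, not the actual implementation:

```python
# Hypothetical interface sketch for app/workflows/tools (not shown in this diff).
# Everything below is inferred from how the workflows above use the helpers.
from dataclasses import dataclass
from typing import AsyncGenerator, List, Optional

from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import BaseTool, ToolSelection
from llama_index.core.workflow import Context


@dataclass
class ChatWithToolsResponse:
    tool_calls: List[ToolSelection]
    tool_call_message: Optional[ChatMessage]
    generator: Optional[AsyncGenerator]  # streamed chunks when the LLM answers directly

    def has_tool_calls(self) -> bool:
        return len(self.tool_calls) > 0

    def is_calling_different_tools(self) -> bool:
        return len({call.tool_name for call in self.tool_calls}) > 1

    def tool_name(self) -> str:
        return self.tool_calls[0].tool_name

    async def full_response(self) -> str:
        ...  # drain the generator and return the concatenated text


async def chat_with_tools(
    llm: FunctionCallingLLM,
    tools: List[BaseTool],
    chat_history: List[ChatMessage],
) -> ChatWithToolsResponse:
    ...  # stream a chat-with-tools call and wrap the outcome


async def call_tools(
    ctx: Context,
    agent_name: str,
    tools: List[BaseTool],
    tool_calls: List[ToolSelection],
) -> List[ChatMessage]:
    ...  # execute each requested tool, emit AgentRunEvents to ctx, return TOOL-role messages
```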