Commit 22e4be93 authored by leehuwuj

use agent workflow for financial report use case

parent 6d5749d6
from typing import Any, Dict, List, Optional, Tuple, Type
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
from llama_index.core.agent.workflow import (
AgentWorkflow,
FunctionAgent,
ReActAgent,
)
from llama_index.core.llms import LLM
from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
def create_workflow(
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
query_engine_tool, code_interpreter_tool, document_generator_tool = _prepare_tools(
params, **kwargs
)
agent_cls = _get_agent_cls_from_llm(Settings.llm)
research_agent = agent_cls(
name="researcher",
description="A financial researcher who are given a tasks that need to look up information from the financial documents about the user's request.",
tools=[query_engine_tool],
system_prompt="""
You are a financial researcher who is given tasks that require looking up information in the financial documents relevant to the user's request.
Use the query engine tool to look up the information and return the result to the user.
Always hand off the task to the `analyst` agent after gathering the information.
""",
llm=Settings.llm,
can_handoff_to=["analyst"],
)
analyst_agent = agent_cls(
name="analyst",
description="A financial analyst who takes responsibility to analyze the financial data and generate a report.",
tools=[code_interpreter_tool],
system_prompt="""
Use only the given information; don't make up anything yourself.
If you have enough numerical information, use the code interpreter tool to add charts or visualizations to the report.
Always hand off the task, along with the researched information, to the `reporter` agent.
""",
llm=Settings.llm,
can_handoff_to=["reporter"],
)
reporter_agent = agent_cls(
name="reporter",
description="A reporter who takes responsibility to generate a document for a report content.",
tools=[document_generator_tool],
system_prompt="""
Use the document generator tool to generate the document and return the result to the user.
Don't update the content of an existing document; just generate a new one.
After generating the document, tell the user about its content, or about any issue if there is one.
""",
llm=Settings.llm,
)
workflow = AgentWorkflow(
agents=[research_agent, analyst_agent, reporter_agent],
root_agent="researcher",
verbose=True,
)
return workflow
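# Pick the agent implementation based on the configured LLM: models with native
# function calling can drive FunctionAgent directly, while other models fall back
# to the prompt-based ReActAgent.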
def _get_agent_cls_from_llm(llm: LLM) -> Type[FunctionAgent | ReActAgent]:
if llm.metadata.is_function_calling_model:
return FunctionAgent
else:
return ReActAgent
def _prepare_tools(
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Tuple[QueryEngineTool, FunctionTool, FunctionTool]:
# Create query engine tool
index_config = IndexConfig(**params)
index = get_index(index_config)
code_interpreter_tool = configured_tools.get("interpret")
document_generator_tool = configured_tools.get("generate_document")
return FinancialReportWorkflow(
query_engine_tool=query_engine_tool,
code_interpreter_tool=code_interpreter_tool,
document_generator_tool=document_generator_tool,
)
class InputEvent(Event):
input: List[ChatMessage]
response: bool = False
class ResearchEvent(Event):
input: list[ToolSelection]
class AnalyzeEvent(Event):
input: list[ToolSelection] | ChatMessage
class ReportEvent(Event):
input: list[ToolSelection]
class FinancialReportWorkflow(Workflow):
"""
A workflow to generate a financial report using indexed documents.
Requirements:
- Indexed documents containing financial data and a query engine tool to search them
- A code interpreter tool to analyze data and generate reports
- A document generator tool to create report files
Steps:
1. LLM Input: The LLM determines the next step based on function calling.
For example, if the model requests the query engine tool, it returns a ResearchEvent;
if it requests document generation, it returns a ReportEvent.
2. Research: Uses the query engine to find relevant chunks from indexed documents.
After gathering information, it requests analysis (step 3).
3. Analyze: Uses a custom prompt to analyze research results and can call the code
interpreter tool for visualization or calculation. Returns results to the LLM.
4. Report: Uses the document generator tool to create a report. Returns results to the LLM.
"""
_default_system_prompt = """
You are a financial analyst who is given a set of tools to help you.
Use the appropriate tools for the user's request, and always rely on the information returned by the tools; don't make up anything yourself.
For the query engine tool, break the user's request down into a list of queries and call the tool with those queries.
"""
stream: bool = True
def __init__(
self,
query_engine_tool: QueryEngineTool,
code_interpreter_tool: FunctionTool,
document_generator_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.query_engine_tool = query_engine_tool
self.code_interpreter_tool = code_interpreter_tool
self.document_generator_tool = document_generator_tool
assert query_engine_tool is not None, (
"Query engine tool is not found. Try run generation script or upload a document file first."
)
assert code_interpreter_tool is not None, "Code interpreter tool is required"
assert document_generator_tool is not None, (
"Document generator tool is required"
)
self.tools = [
self.query_engine_tool,
self.code_interpreter_tool,
self.document_generator_tool,
]
self.llm: FunctionCallingLLM = llm or Settings.llm
assert isinstance(self.llm, FunctionCallingLLM)
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
@step()
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg")
chat_history = ev.get("chat_history")
if chat_history is not None:
self.memory.put_messages(chat_history)
# Add user message to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))
if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)
return InputEvent(input=self.memory.get())
@step()
async def handle_llm_input( # type: ignore
self,
ctx: Context,
ev: InputEvent,
) -> ResearchEvent | AnalyzeEvent | ReportEvent | StopEvent:
"""
Handle an LLM input and decide the next step.
"""
# Always use the latest chat history from the input
chat_history: list[ChatMessage] = ev.input
# Get tool calls
response = await chat_with_tools(
self.llm,
self.tools, # type: ignore
chat_history,
)
if not response.has_tool_calls():
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Cannot call different tools at the same time. Try calling one tool at a time.",
)
)
return InputEvent(input=self.memory.get())
self.memory.put(response.tool_call_message)
match response.tool_name():
case self.code_interpreter_tool.metadata.name:
return AnalyzeEvent(input=response.tool_calls)
case self.document_generator_tool.metadata.name:
return ReportEvent(input=response.tool_calls)
case self.query_engine_tool.metadata.name:
return ResearchEvent(input=response.tool_calls)
case _:
raise ValueError(f"Unknown tool: {response.tool_name()}")
@step()
async def research(self, ctx: Context, ev: ResearchEvent) -> AnalyzeEvent:
"""
Do research to gather information for the user's request.
A researcher should have these tools: query engine, search engine, etc.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg="Starting research",
)
)
tool_calls = ev.input
tool_messages = await call_tools(
ctx=ctx,
agent_name="Researcher",
tools=[self.query_engine_tool],
tool_calls=tool_calls,
)
self.memory.put_messages(tool_messages)
return AnalyzeEvent(
input=ChatMessage(
role=MessageRole.ASSISTANT,
content="I've finished the research. Please analyze the result.",
),
)
@step()
async def analyze(self, ctx: Context, ev: AnalyzeEvent) -> InputEvent:
"""
Analyze the research result.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Analyst",
msg="Starting analysis",
)
)
event_requested_by_workflow_llm = isinstance(ev.input, list)
# Requested by the workflow LLM Input step, it's a tool call
if event_requested_by_workflow_llm:
# Set the tool calls
tool_calls = ev.input
else:
# Otherwise, it's triggered by the research step
# Use a custom prompt and independent memory for the analyst agent
analysis_prompt = """
You are a financial analyst. You are given a research result and a set of tools to help you.
Always use the given information; don't make up anything yourself. If there is not enough information, you can ask for more.
If you have enough numerical information, use the code interpreter tool to add charts or visualizations to the report.
"""
# This is handled by analyst agent
# Clone the shared memory to avoid conflicting with the workflow.
chat_history = self.memory.get()
chat_history.append(
ChatMessage(role=MessageRole.SYSTEM, content=analysis_prompt)
)
chat_history.append(ev.input) # type: ignore
# Check if the analyst agent needs to call tools
response = await chat_with_tools(
self.llm,
[self.code_interpreter_tool],
chat_history,
)
if not response.has_tool_calls():
# If there is no tool call, return the analyst message to the workflow
analyst_msg = ChatMessage(
role=MessageRole.ASSISTANT,
content=await response.full_response(),
)
self.memory.put(analyst_msg)
return InputEvent(input=self.memory.get())
else:
# Set the tool calls and the tool call message to the memory
tool_calls = response.tool_calls
self.memory.put(response.tool_call_message)
# Call tools
tool_messages = await call_tools(
ctx=ctx,
agent_name="Analyst",
tools=[self.code_interpreter_tool],
tool_calls=tool_calls, # type: ignore
)
self.memory.put_messages(tool_messages)
# Fall back to the input with the latest chat history
return InputEvent(input=self.memory.get())
@step()
async def report(self, ctx: Context, ev: ReportEvent) -> InputEvent:
"""
Generate a report based on the analysis result.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Reporter",
msg="Starting report generation",
)
)
tool_calls = ev.input
tool_messages = await call_tools(
ctx=ctx,
agent_name="Reporter",
tools=[self.document_generator_tool],
tool_calls=tool_calls,
)
self.memory.put_messages(tool_messages)
# After the tool calls, fall back to the input with the latest chat history
return InputEvent(input=self.memory.get())
return query_engine_tool, code_interpreter_tool, document_generator_tool
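For reference, here is a minimal sketch of how the AgentWorkflow returned by create_workflow could be driven, assuming the standard llama_index handler API (run / stream_events); the module path, params, and user message are illustrative assumptions.
import asyncio
from app.workflows.financial_report import create_workflow  # assumed module path
async def main():
    workflow = create_workflow(params={})
    # run() returns a handler; agent and tool activity can be streamed from it
    handler = workflow.run(user_msg="Summarize the revenue trend in the indexed reports")
    async for event in handler.stream_events():
        print(type(event).__name__)
    result = await handler  # final agent output once the workflow completes
    print(result)
asyncio.run(main())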
from typing import Any, List, Optional
from app.workflows.events import AgentRunEvent
from app.workflows.tools import ToolCallResponse, call_tools, chat_with_tools
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.settings import Settings
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)
class InputEvent(Event):
input: list[ChatMessage]
class ToolCallEvent(Event):
input: ToolCallResponse
class FunctionCallingAgent(Workflow):
"""
A simple workflow that queries an LLM with tools independently.
You can pass in previous chat history to provide context for the LLM.
"""
def __init__(
self,
*args: Any,
llm: FunctionCallingLLM | None = None,
chat_history: Optional[List[ChatMessage]] = None,
tools: List[BaseTool] | None = None,
system_prompt: str | None = None,
verbose: bool = False,
timeout: float = 360.0,
name: str,
write_events: bool = True,
**kwargs: Any,
) -> None:
super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs) # type: ignore
self.tools = tools or []
self.name = name
self.write_events = write_events
if llm is None:
llm = Settings.llm
self.llm = llm
if not self.llm.metadata.is_function_calling_model:
raise ValueError("The provided LLM must support function calling.")
self.system_prompt = system_prompt
self.memory = ChatMemoryBuffer.from_defaults(
llm=self.llm, chat_history=chat_history
)
self.sources = [] # type: ignore
@step()
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
# clear sources
self.sources = []
# set streaming
ctx.data["streaming"] = getattr(ev, "streaming", False)
# set system prompt
if self.system_prompt is not None:
system_msg = ChatMessage(role="system", content=self.system_prompt)
self.memory.put(system_msg)
# get user input
user_input = ev.input
user_msg = ChatMessage(role="user", content=user_input)
self.memory.put(user_msg)
if self.write_events:
ctx.write_event_to_stream(
AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
)
return InputEvent(input=self.memory.get())
@step()
async def handle_llm_input(
self,
ctx: Context,
ev: InputEvent,
) -> ToolCallEvent | StopEvent:
chat_history = ev.input
response = await chat_with_tools(
self.llm,
self.tools,
chat_history,
)
is_tool_call = isinstance(response, ToolCallResponse)
if not is_tool_call:
if ctx.data["streaming"]:
return StopEvent(result=response)
else:
full_response = ""
async for chunk in response.generator:
full_response += chunk.message.content
return StopEvent(result=full_response)
return ToolCallEvent(input=response)
@step()
async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent:
tool_calls = ev.input.tool_calls
tool_call_message = ev.input.tool_call_message
self.memory.put(tool_call_message)
tool_messages = await call_tools(self.name, self.tools, ctx, tool_calls)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
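A minimal usage sketch for the FunctionCallingAgent above, assuming Settings.llm is configured with a function-calling model; the multiply tool and the question are placeholders.
import asyncio
from llama_index.core.tools import FunctionTool
def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    return a * b
async def main():
    agent = FunctionCallingAgent(
        name="calculator",
        tools=[FunctionTool.from_defaults(multiply)],
        system_prompt="You are a helpful calculator.",
    )
    # Kwargs passed to run() become StartEvent attributes read by prepare_chat_history.
    result = await agent.run(input="What is 12.3 multiplied by 4.5?", streaming=False)
    print(result)
asyncio.run(main())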
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Optional
from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import (
BaseTool,
FunctionTool,
ToolOutput,
ToolSelection,
)
from llama_index.core.workflow import Context
from pydantic import BaseModel, ConfigDict
from app.workflows.events import AgentRunEvent, AgentRunEventType
logger = logging.getLogger("uvicorn")
class ContextAwareTool(FunctionTool, ABC):
@abstractmethod
async def acall(self, ctx: Context, input: Any) -> ToolOutput: # type: ignore
pass
class ChatWithToolsResponse(BaseModel):
"""
A tool call response from chat_with_tools.
"""
tool_calls: Optional[list[ToolSelection]]
tool_call_message: Optional[ChatMessage]
generator: Optional[AsyncGenerator[ChatResponse | None, None]]
model_config = ConfigDict(arbitrary_types_allowed=True)
def is_calling_different_tools(self) -> bool:
tool_names = {tool_call.tool_name for tool_call in self.tool_calls}
return len(tool_names) > 1
def has_tool_calls(self) -> bool:
return self.tool_calls is not None and len(self.tool_calls) > 0
def tool_name(self) -> str:
assert self.has_tool_calls()
assert not self.is_calling_different_tools()
return self.tool_calls[0].tool_name
async def full_response(self) -> str:
assert self.generator is not None
full_response = ""
async for chunk in self.generator:
content = chunk.message.content
if content:
full_response += content
return full_response
async def chat_with_tools( # type: ignore
llm: FunctionCallingLLM,
tools: list[BaseTool],
chat_history: list[ChatMessage],
) -> ChatWithToolsResponse:
"""
Ask the LLM for a response, letting it decide whether to call tools.
This function doesn't modify the memory.
"""
generator = _tool_call_generator(llm, tools, chat_history)
is_tool_call = await generator.__anext__()
if is_tool_call:
# Last chunk is the full response
# Wait for the last chunk
full_response = None
async for chunk in generator:
full_response = chunk
assert isinstance(full_response, ChatResponse)
return ChatWithToolsResponse(
tool_calls=llm.get_tool_calls_from_response(full_response),
tool_call_message=full_response.message,
generator=None,
)
else:
return ChatWithToolsResponse(
tool_calls=None,
tool_call_message=None,
generator=generator,
)
async def call_tools(
ctx: Context,
agent_name: str,
tools: list[BaseTool],
tool_calls: list[ToolSelection],
emit_agent_events: bool = True,
) -> list[ChatMessage]:
if len(tool_calls) == 0:
return []
tools_by_name = {tool.metadata.get_name(): tool for tool in tools}
if len(tool_calls) == 1:
return [
await call_tool(
ctx,
tools_by_name[tool_calls[0].tool_name],
tool_calls[0],
lambda msg: ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=msg,
)
),
)
]
# Multiple tool calls, show progress
tool_msgs: list[ChatMessage] = []
progress_id = str(uuid.uuid4())
total_steps = len(tool_calls)
if emit_agent_events:
ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=f"Making {total_steps} tool calls",
)
)
for i, tool_call in enumerate(tool_calls):
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
tool_msgs.append(
ChatMessage(
role=MessageRole.ASSISTANT,
content=f"Tool {tool_call.tool_name} does not exist",
)
)
continue
tool_msg = await call_tool(
ctx,
tool,
tool_call,
event_emitter=lambda msg: ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=msg,
event_type=AgentRunEventType.PROGRESS,
data={
"id": progress_id,
"total": total_steps,
"current": i,
},
)
),
)
tool_msgs.append(tool_msg)
return tool_msgs
async def call_tool(
ctx: Context,
tool: BaseTool,
tool_call: ToolSelection,
event_emitter: Optional[Callable[[str], None]],
) -> ChatMessage:
if event_emitter:
event_emitter(
f"Calling tool {tool_call.tool_name}, {str(tool_call.tool_kwargs)}"
)
try:
if isinstance(tool, ContextAwareTool):
if ctx is None:
raise ValueError("Context is required for context aware tool")
# inject the context when calling a context-aware tool
response = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
else:
response = await tool.acall(**tool_call.tool_kwargs) # type: ignore
return ChatMessage(
role=MessageRole.TOOL,
content=str(response.raw_output),
additional_kwargs={
"tool_call_id": tool_call.tool_id,
"name": tool.metadata.get_name(),
},
)
except Exception as e:
logger.error(f"Got error in tool {tool_call.tool_name}: {str(e)}")
if event_emitter:
event_emitter(f"Got error in tool {tool_call.tool_name}: {str(e)}")
return ChatMessage(
role=MessageRole.TOOL,
content=f"Error: {str(e)}",
additional_kwargs={
"tool_call_id": tool_call.tool_id,
"name": tool.metadata.get_name(),
},
)
async def _tool_call_generator(
llm: FunctionCallingLLM,
tools: list[BaseTool],
chat_history: list[ChatMessage],
) -> AsyncGenerator[ChatResponse | bool, None]:
response_stream = await llm.astream_chat_with_tools(
tools,
chat_history=chat_history,
allow_parallel_tool_calls=False,
)
full_response = None
yielded_indicator = False
async for chunk in response_stream:
if "tool_calls" not in chunk.message.additional_kwargs:
# Yield a boolean to indicate whether the response is a tool call
if not yielded_indicator:
yield False
yielded_indicator = True
# if not a tool call, yield the chunks!
yield chunk # type: ignore
elif not yielded_indicator:
# Yield the indicator for a tool call
yield True
yielded_indicator = True
full_response = chunk
if full_response:
yield full_response # type: ignore
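A condensed sketch of the chat_with_tools / call_tools pattern used by the workflow steps in this commit; the one-step workflow, the dummy tool, and the assumption that Settings.llm is a function-calling model are all illustrative.
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
def get_revenue(year: int) -> str:
    """Return a dummy revenue figure for the given year."""
    return f"Revenue in {year}: $1.2M"
class OneShotToolWorkflow(Workflow):
    """A single round of tool calling built on chat_with_tools and call_tools."""
    @step()
    async def run_llm(self, ctx: Context, ev: StartEvent) -> StopEvent:
        tools = [FunctionTool.from_defaults(get_revenue)]
        history = [ChatMessage(role=MessageRole.USER, content=ev.input)]
        response = await chat_with_tools(Settings.llm, tools, history)
        if not response.has_tool_calls():
            return StopEvent(result=await response.full_response())
        # Execute the requested tool calls, append the results, then ask the LLM again.
        history.append(response.tool_call_message)
        tool_messages = await call_tools(
            ctx=ctx, agent_name="assistant", tools=tools, tool_calls=response.tool_calls
        )
        history.extend(tool_messages)
        # Assumes the second pass answers in plain text rather than calling tools again.
        final = await chat_with_tools(Settings.llm, tools, history)
        return StopEvent(result=await final.full_response())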