Commit d09ae65d authored by leehuwuj

keep the old code for financial report and form-filling

parent 798f3785
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional
 from llama_index.core import Settings
-from llama_index.core.agent.workflow import (
-    AgentWorkflow,
-    FunctionAgent,
-    ReActAgent,
-)
-from llama_index.core.llms import LLM
-from llama_index.core.tools import FunctionTool, QueryEngineTool
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
+from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
 from app.engine.index import IndexConfig, get_index
 from app.engine.tools import ToolFactory
 from app.engine.tools.query_engine import get_query_engine_tool
+from app.workflows.events import AgentRunEvent
+from app.workflows.tools import (
+    call_tools,
+    chat_with_tools,
+)

 def create_workflow(
     params: Optional[Dict[str, Any]] = None,
     **kwargs,
-):
+) -> Workflow:
-    query_engine_tool, code_interpreter_tool, document_generator_tool = _prepare_tools(
-        params, **kwargs
-    )
-    agent_cls = _get_agent_cls_from_llm(Settings.llm)
-    research_agent = agent_cls(
-        name="researcher",
-        description="A financial researcher who is given tasks that require looking up information in the financial documents related to the user's request.",
-        tools=[query_engine_tool],
-        system_prompt="""
-            You are a financial researcher who is given tasks that require looking up information in the financial documents related to the user's request.
-            Use the query engine tool to look up the information and return the result to the user.
-            Always hand off the task to the `analyst` agent after gathering information.
-        """,
-        llm=Settings.llm,
-        can_handoff_to=["analyst"],
-    )
-    analyst_agent = agent_cls(
-        name="analyst",
-        description="A financial analyst who is responsible for analyzing the financial data and generating a report.",
-        tools=[code_interpreter_tool],
-        system_prompt="""
-            Use the given information; don't make up anything yourself.
-            If you have enough numerical information, it's good to include charts/visualizations in the report, so you can use the code interpreter tool to generate them.
-            Always hand off the task, along with the researched information, to the `reporter` agent.
-        """,
-        llm=Settings.llm,
-        can_handoff_to=["reporter"],
-    )
-    reporter_agent = agent_cls(
-        name="reporter",
-        description="A reporter who is responsible for generating a document from the report content.",
-        tools=[document_generator_tool],
-        system_prompt="""
-            Use the document generator tool to generate the document and return the result to the user.
-            Don't update the content of the document; just generate a new one.
-            After generating the document, tell the user about the content, or about any issue that occurred.
-        """,
-        llm=Settings.llm,
-    )
-    workflow = AgentWorkflow(
-        agents=[research_agent, analyst_agent, reporter_agent],
-        root_agent="researcher",
-        verbose=True,
-    )
-    return workflow

-def _get_agent_cls_from_llm(llm: LLM) -> Type[FunctionAgent | ReActAgent]:
-    if llm.metadata.is_function_calling_model:
-        return FunctionAgent
-    else:
-        return ReActAgent

-def _prepare_tools(
-    params: Optional[Dict[str, Any]] = None,
-    **kwargs,
-) -> Tuple[QueryEngineTool, FunctionTool, FunctionTool]:
     # Create query engine tool
     index_config = IndexConfig(**params)
     index = get_index(index_config)
...
@@ -94,4 +41,260 @@ def _prepare_tools(
     code_interpreter_tool = configured_tools.get("interpret")
     document_generator_tool = configured_tools.get("generate_document")
-    return query_engine_tool, code_interpreter_tool, document_generator_tool
+    return FinancialReportWorkflow(
+        query_engine_tool=query_engine_tool,
+        code_interpreter_tool=code_interpreter_tool,
+        document_generator_tool=document_generator_tool,
+    )
class InputEvent(Event):
input: List[ChatMessage]
response: bool = False
class ResearchEvent(Event):
input: list[ToolSelection]
class AnalyzeEvent(Event):
input: list[ToolSelection] | ChatMessage
class ReportEvent(Event):
input: list[ToolSelection]
class FinancialReportWorkflow(Workflow):
"""
A workflow to generate a financial report using indexed documents.
Requirements:
- Indexed documents containing financial data and a query engine tool to search them
- A code interpreter tool to analyze data and generate reports
- A document generator tool to create report files
Steps:
1. LLM Input: The LLM determines the next step based on function calling.
For example, if the model requests the query engine tool, it returns a ResearchEvent;
if it requests document generation, it returns a ReportEvent.
2. Research: Uses the query engine to find relevant chunks from indexed documents.
After gathering information, it requests analysis (step 3).
3. Analyze: Uses a custom prompt to analyze research results and can call the code
interpreter tool for visualization or calculation. Returns results to the LLM.
4. Report: Uses the document generator tool to create a report. Returns results to the LLM.
"""
_default_system_prompt = """
    You are a financial analyst who is given a set of tools to help you.
    Use the appropriate tools for the user's request, and always rely on the information they return; don't make up anything yourself.
For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries.
"""
stream: bool = True
def __init__(
self,
query_engine_tool: QueryEngineTool,
code_interpreter_tool: FunctionTool,
document_generator_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.query_engine_tool = query_engine_tool
self.code_interpreter_tool = code_interpreter_tool
self.document_generator_tool = document_generator_tool
assert query_engine_tool is not None, (
"Query engine tool is not found. Try run generation script or upload a document file first."
)
assert code_interpreter_tool is not None, "Code interpreter tool is required"
assert document_generator_tool is not None, (
"Document generator tool is required"
)
self.tools = [
self.query_engine_tool,
self.code_interpreter_tool,
self.document_generator_tool,
]
self.llm: FunctionCallingLLM = llm or Settings.llm
assert isinstance(self.llm, FunctionCallingLLM)
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
@step()
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg")
chat_history = ev.get("chat_history")
if chat_history is not None:
self.memory.put_messages(chat_history)
# Add user message to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))
if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)
return InputEvent(input=self.memory.get())
@step()
async def handle_llm_input( # type: ignore
self,
ctx: Context,
ev: InputEvent,
) -> ResearchEvent | AnalyzeEvent | ReportEvent | StopEvent:
"""
Handle an LLM input and decide the next step.
"""
# Always use the latest chat history from the input
chat_history: list[ChatMessage] = ev.input
# Get tool calls
response = await chat_with_tools(
self.llm,
self.tools, # type: ignore
chat_history,
)
if not response.has_tool_calls():
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Cannot call different tools at the same time. Try calling one tool at a time.",
)
)
return InputEvent(input=self.memory.get())
self.memory.put(response.tool_call_message)
match response.tool_name():
case self.code_interpreter_tool.metadata.name:
return AnalyzeEvent(input=response.tool_calls)
case self.document_generator_tool.metadata.name:
return ReportEvent(input=response.tool_calls)
case self.query_engine_tool.metadata.name:
return ResearchEvent(input=response.tool_calls)
case _:
raise ValueError(f"Unknown tool: {response.tool_name()}")
@step()
async def research(self, ctx: Context, ev: ResearchEvent) -> AnalyzeEvent:
"""
        Do research to gather information for the user's request.
        A researcher typically has tools such as a query engine or a search engine.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg="Starting research",
)
)
tool_calls = ev.input
tool_messages = await call_tools(
ctx=ctx,
agent_name="Researcher",
tools=[self.query_engine_tool],
tool_calls=tool_calls,
)
self.memory.put_messages(tool_messages)
return AnalyzeEvent(
input=ChatMessage(
role=MessageRole.ASSISTANT,
content="I've finished the research. Please analyze the result.",
),
)
@step()
async def analyze(self, ctx: Context, ev: AnalyzeEvent) -> InputEvent:
"""
Analyze the research result.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Analyst",
msg="Starting analysis",
)
)
event_requested_by_workflow_llm = isinstance(ev.input, list)
# Requested by the workflow LLM Input step, it's a tool call
if event_requested_by_workflow_llm:
# Set the tool calls
tool_calls = ev.input
else:
# Otherwise, it's triggered by the research step
# Use a custom prompt and independent memory for the analyst agent
analysis_prompt = """
            You are a financial analyst. You are given a research result and a set of tools to help you.
            Always use the given information; don't make up anything yourself. If there is not enough information, you can ask for more.
            If you have enough numerical information, it's good to include charts/visualizations in the report, so use the code interpreter tool to generate them.
"""
# This is handled by analyst agent
# Clone the shared memory to avoid conflicting with the workflow.
chat_history = self.memory.get()
chat_history.append(
ChatMessage(role=MessageRole.SYSTEM, content=analysis_prompt)
)
chat_history.append(ev.input) # type: ignore
# Check if the analyst agent needs to call tools
response = await chat_with_tools(
self.llm,
[self.code_interpreter_tool],
chat_history,
)
if not response.has_tool_calls():
            # No tool calls: return the analyst's message to the workflow
analyst_msg = ChatMessage(
role=MessageRole.ASSISTANT,
content=await response.full_response(),
)
self.memory.put(analyst_msg)
return InputEvent(input=self.memory.get())
else:
# Set the tool calls and the tool call message to the memory
tool_calls = response.tool_calls
self.memory.put(response.tool_call_message)
# Call tools
tool_messages = await call_tools(
ctx=ctx,
agent_name="Analyst",
tools=[self.code_interpreter_tool],
tool_calls=tool_calls, # type: ignore
)
self.memory.put_messages(tool_messages)
        # Return to the LLM input step with the latest chat history
return InputEvent(input=self.memory.get())
@step()
async def report(self, ctx: Context, ev: ReportEvent) -> InputEvent:
"""
Generate a report based on the analysis result.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Reporter",
msg="Starting report generation",
)
)
tool_calls = ev.input
tool_messages = await call_tools(
ctx=ctx,
agent_name="Reporter",
tools=[self.document_generator_tool],
tool_calls=tool_calls,
)
self.memory.put_messages(tool_messages)
        # After the tool calls, return to the LLM input step with the latest chat history
return InputEvent(input=self.memory.get())
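For orientation, a minimal driver sketch for the workflow above (an illustration, not part of this commit; it assumes llama_index's standard awaitable Workflow.run() handler, and the empty params dict is a placeholder for whatever IndexConfig expects):

import asyncio

async def main() -> None:
    # create_workflow() wires the tools and returns a FinancialReportWorkflow
    workflow = create_workflow(params={})  # placeholder params for IndexConfig
    # stream=False makes handle_llm_input stop with the full response text
    result = await workflow.run(user_msg="Summarize the revenue trend.", stream=False)
    print(result)

asyncio.run(main())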
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, List, Optional
 from llama_index.core import Settings
-from llama_index.core.agent.workflow import (
-    AgentWorkflow,
-    FunctionAgent,
-    ReActAgent,
-)
-from llama_index.core.llms import LLM
+from llama_index.core.base.llms.types import ChatMessage, MessageRole
+from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
 from app.engine.index import IndexConfig, get_index
 from app.engine.tools import ToolFactory
 from app.engine.tools.query_engine import get_query_engine_tool
+from app.workflows.events import AgentRunEvent
+from app.workflows.tools import (
+    call_tools,
+    chat_with_tools,
+)

 def create_workflow(
     params: Optional[Dict[str, Any]] = None,
     **kwargs,
-) -> AgentWorkflow:
+) -> Workflow:
     # Create query engine tool
     index_config = IndexConfig(**params)
     index = get_index(index_config)
...
@@ -29,60 +40,197 @@ def create_workflow(
     extractor_tool = configured_tools.get("extract_questions")  # type: ignore
     filling_tool = configured_tools.get("fill_form")  # type: ignore
-    if extractor_tool is None or filling_tool is None:
-        raise ValueError("Extractor and filling tools are required.")
-
-    agent_cls = _get_agent_cls_from_llm(Settings.llm)
-    extractor_agent = agent_cls(
-        name="extractor",
-        description="An agent that extracts missing cells from CSV files and generates questions to fill them.",
-        tools=[extractor_tool],
-        system_prompt="""
-            You are a helpful assistant who extracts missing cells from CSV files.
-            Only extract missing cells from CSV files and generate questions to fill them.
-            Always hand off the task to the `researcher` agent after extracting the questions.
-        """,
-        llm=Settings.llm,
-        can_handoff_to=["researcher"],
-    )
-    researcher_agent = agent_cls(
-        name="researcher",
-        description="An agent that finds answers to questions about missing cells.",
-        tools=[query_engine_tool] if query_engine_tool else [],
-        system_prompt="""
-            You are a researcher who finds answers to questions about missing cells.
-            Only use the provided data - never make up any information yourself. Use N/A if an answer is not found.
-            Always hand off the task to the `processor` agent after finding the answers.
-        """,
-        llm=Settings.llm,
-        can_handoff_to=["processor"],
-    )
-    processor_agent = agent_cls(
-        name="processor",
-        description="An agent that fills missing cells with found answers.",
-        tools=[filling_tool],
-        system_prompt="""
-            You are a processor who fills missing cells with found answers.
-            Fill N/A for any missing answers.
-            After filling the cells, tell the user about the results or any issues encountered.
-        """,
-        llm=Settings.llm,
-    )
-    workflow = AgentWorkflow(
-        agents=[extractor_agent, researcher_agent, processor_agent],
-        root_agent="extractor",
-        verbose=True,
-    )
-    return workflow

-def _get_agent_cls_from_llm(llm: LLM) -> Type[FunctionAgent | ReActAgent]:
-    if llm.metadata.is_function_calling_model:
-        return FunctionAgent
-    else:
-        return ReActAgent
+    workflow = FormFillingWorkflow(
+        query_engine_tool=query_engine_tool,
+        extractor_tool=extractor_tool,  # type: ignore
+        filling_tool=filling_tool,  # type: ignore
+    )
+    return workflow

class InputEvent(Event):
    input: List[ChatMessage]
    response: bool = False

class ExtractMissingCellsEvent(Event):
    tool_calls: list[ToolSelection]

class FindAnswersEvent(Event):
    tool_calls: list[ToolSelection]

class FillEvent(Event):
    tool_calls: list[ToolSelection]
class FormFillingWorkflow(Workflow):
"""
A predefined workflow for filling missing cells in a CSV file.
Required tools:
    - query_engine: A query engine to query for the answers to the questions.
    - extract_questions: Extract missing cells in a CSV file and generate questions to fill them.
    - fill_form: Fill the missing cells with the found answers.
Flow:
1. Extract missing cells in a CSV file and generate questions to fill them.
2. Query for the answers to the questions.
3. Fill the missing cells with the answers.
"""
_default_system_prompt = """
You are a helpful assistant who helps fill missing cells in a CSV file.
Only extract missing cells from CSV files.
Only use provided data - never make up any information yourself. Fill N/A if an answer is not found.
    If there is no query engine tool, or the gathered information contains many N/A values (indicating the questions don't match the data), respond with a warning and ask the user to upload a different file or connect to a knowledge base.
"""
stream: bool = True
def __init__(
self,
query_engine_tool: Optional[QueryEngineTool],
extractor_tool: FunctionTool,
filling_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.query_engine_tool = query_engine_tool
self.extractor_tool = extractor_tool
self.filling_tool = filling_tool
if self.extractor_tool is None or self.filling_tool is None:
raise ValueError("Extractor and filling tools are required.")
self.tools = [self.extractor_tool, self.filling_tool]
if self.query_engine_tool is not None:
self.tools.append(self.query_engine_tool) # type: ignore
self.llm: FunctionCallingLLM = llm or Settings.llm
if not isinstance(self.llm, FunctionCallingLLM):
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
@step()
async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg", "")
chat_history = ev.get("chat_history", [])
if chat_history:
self.memory.put_messages(chat_history)
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))
if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)
return InputEvent(input=self.memory.get())
@step()
async def handle_llm_input( # type: ignore
self,
ctx: Context,
ev: InputEvent,
) -> ExtractMissingCellsEvent | FillEvent | StopEvent:
"""
Handle an LLM input and decide the next step.
"""
chat_history: list[ChatMessage] = ev.input
response = await chat_with_tools(
self.llm,
self.tools,
chat_history,
)
if not response.has_tool_calls():
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
self.memory.put(
ChatMessage(
role=MessageRole.ASSISTANT,
content="Cannot call different tools at the same time. Try calling one tool at a time.",
)
)
return InputEvent(input=self.memory.get())
self.memory.put(response.tool_call_message)
match response.tool_name():
case self.extractor_tool.metadata.name:
return ExtractMissingCellsEvent(tool_calls=response.tool_calls)
case self.query_engine_tool.metadata.name:
return FindAnswersEvent(tool_calls=response.tool_calls)
case self.filling_tool.metadata.name:
return FillEvent(tool_calls=response.tool_calls)
case _:
raise ValueError(f"Unknown tool: {response.tool_name()}")
@step()
async def extract_missing_cells(
self, ctx: Context, ev: ExtractMissingCellsEvent
) -> InputEvent | FindAnswersEvent:
"""
Extract missing cells in a CSV file and generate questions to fill them.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Extractor",
msg="Extracting missing cells",
)
)
# Call the extract questions tool
tool_messages = await call_tools(
agent_name="Extractor",
tools=[self.extractor_tool],
ctx=ctx,
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
@step()
async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
"""
        Call the query engine tool to find the answers to the questions.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Researcher",
msg="Finding answers for missing cells",
)
)
tool_messages = await call_tools(
ctx=ctx,
agent_name="Researcher",
tools=[self.query_engine_tool],
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
@step()
async def fill_cells(self, ctx: Context, ev: FillEvent) -> InputEvent:
"""
        Call the fill-form tool to fill the missing cells with the answers.
"""
ctx.write_event_to_stream(
AgentRunEvent(
name="Processor",
msg="Filling missing cells",
)
)
tool_messages = await call_tools(
agent_name="Processor",
tools=[self.filling_tool],
ctx=ctx,
tool_calls=ev.tool_calls,
)
self.memory.put_messages(tool_messages)
return InputEvent(input=self.memory.get())
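Likewise, a sketch (illustration only) that drives the form-filling workflow and prints the AgentRunEvent progress messages the steps above write to the stream, assuming the handler's standard stream_events() API:

import asyncio

async def main() -> None:
    workflow = create_workflow(params={})  # placeholder params for IndexConfig
    handler = workflow.run(user_msg="Fill the missing cells in my CSV.", stream=False)
    async for event in handler.stream_events():
        if isinstance(event, AgentRunEvent):  # progress emitted by each step
            print(f"[{event.name}] {event.msg}")
    print(await handler)  # final result once the stream is drained

asyncio.run(main())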
...
@@ -5,7 +5,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
 from app.api.callbacks.llamacloud import LlamaCloudFileDownload
 from app.api.callbacks.next_question import SuggestNextQuestions
 from app.api.callbacks.stream_handler import StreamHandler
-from app.api.callbacks.add_node_url import AddNodeUrl
+from app.api.callbacks.source_nodes import AddNodeUrl
 from app.api.routers.models import (
     ChatData,
 )
...
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Optional
from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools import (
BaseTool,
FunctionTool,
ToolOutput,
ToolSelection,
)
from llama_index.core.workflow import Context
from pydantic import BaseModel, ConfigDict
from app.workflows.events import AgentRunEvent, AgentRunEventType
logger = logging.getLogger("uvicorn")
class ContextAwareTool(FunctionTool, ABC):
@abstractmethod
async def acall(self, ctx: Context, input: Any) -> ToolOutput: # type: ignore
pass
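# Sketch (not part of this commit): a ContextAwareTool subclass that call_tool()
# below would invoke with the workflow Context injected. The "save_note" name,
# the placeholder fn, and the "notes" context key are illustrative assumptions.
#
# from llama_index.core.tools import ToolMetadata
#
# class SaveNoteTool(ContextAwareTool):
#     def __init__(self) -> None:
#         def _placeholder(note: str) -> str:
#             return note
#         super().__init__(
#             fn=_placeholder,
#             metadata=ToolMetadata(name="save_note", description="Save a note into the workflow context."),
#         )
#
#     async def acall(self, ctx: Context, note: str) -> ToolOutput:  # type: ignore
#         notes = await ctx.get("notes", default=[])
#         notes.append(note)
#         await ctx.set("notes", notes)
#         return ToolOutput(content=note, tool_name="save_note", raw_input={"note": note}, raw_output=note)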
class ChatWithToolsResponse(BaseModel):
"""
A tool call response from chat_with_tools.
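    Exactly one side is populated: tool_calls and tool_call_message when the LLM
    decided to call a tool, or generator when it streamed a plain text answer.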
"""
tool_calls: Optional[list[ToolSelection]]
tool_call_message: Optional[ChatMessage]
generator: Optional[AsyncGenerator[ChatResponse | None, None]]
model_config = ConfigDict(arbitrary_types_allowed=True)
def is_calling_different_tools(self) -> bool:
tool_names = {tool_call.tool_name for tool_call in self.tool_calls}
return len(tool_names) > 1
def has_tool_calls(self) -> bool:
return self.tool_calls is not None and len(self.tool_calls) > 0
def tool_name(self) -> str:
assert self.has_tool_calls()
assert not self.is_calling_different_tools()
return self.tool_calls[0].tool_name
async def full_response(self) -> str:
assert self.generator is not None
full_response = ""
async for chunk in self.generator:
content = chunk.message.content
if content:
full_response += content
return full_response
async def chat_with_tools( # type: ignore
llm: FunctionCallingLLM,
tools: list[BaseTool],
chat_history: list[ChatMessage],
) -> ChatWithToolsResponse:
"""
    Ask the LLM to respond to the chat history, letting it decide whether to call tools.
    This function doesn't modify the memory.
"""
generator = _tool_call_generator(llm, tools, chat_history)
is_tool_call = await generator.__anext__()
if is_tool_call:
        # The last chunk yielded by the generator is the complete response;
        # drain the stream to get it.
full_response = None
async for chunk in generator:
full_response = chunk
assert isinstance(full_response, ChatResponse)
return ChatWithToolsResponse(
tool_calls=llm.get_tool_calls_from_response(full_response),
tool_call_message=full_response.message,
generator=None,
)
else:
return ChatWithToolsResponse(
tool_calls=None,
tool_call_message=None,
generator=generator,
)
async def call_tools(
ctx: Context,
agent_name: str,
tools: list[BaseTool],
tool_calls: list[ToolSelection],
emit_agent_events: bool = True,
) -> list[ChatMessage]:
if len(tool_calls) == 0:
return []
tools_by_name = {tool.metadata.get_name(): tool for tool in tools}
if len(tool_calls) == 1:
return [
await call_tool(
ctx,
tools_by_name[tool_calls[0].tool_name],
tool_calls[0],
lambda msg: ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=msg,
)
),
)
]
# Multiple tool calls, show progress
tool_msgs: list[ChatMessage] = []
progress_id = str(uuid.uuid4())
total_steps = len(tool_calls)
if emit_agent_events:
ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=f"Making {total_steps} tool calls",
)
)
for i, tool_call in enumerate(tool_calls):
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
tool_msgs.append(
ChatMessage(
role=MessageRole.ASSISTANT,
content=f"Tool {tool_call.tool_name} does not exist",
)
)
continue
tool_msg = await call_tool(
ctx,
tool,
tool_call,
event_emitter=lambda msg: ctx.write_event_to_stream(
AgentRunEvent(
name=agent_name,
msg=msg,
event_type=AgentRunEventType.PROGRESS,
data={
"id": progress_id,
"total": total_steps,
"current": i,
},
)
),
)
tool_msgs.append(tool_msg)
return tool_msgs
async def call_tool(
ctx: Context,
tool: BaseTool,
tool_call: ToolSelection,
event_emitter: Optional[Callable[[str], None]],
) -> ChatMessage:
if event_emitter:
event_emitter(
f"Calling tool {tool_call.tool_name}, {str(tool_call.tool_kwargs)}"
)
try:
if isinstance(tool, ContextAwareTool):
if ctx is None:
raise ValueError("Context is required for context aware tool")
            # inject the workflow context when calling a context-aware tool
response = await tool.acall(ctx=ctx, **tool_call.tool_kwargs)
else:
response = await tool.acall(**tool_call.tool_kwargs) # type: ignore
return ChatMessage(
role=MessageRole.TOOL,
content=str(response.raw_output),
additional_kwargs={
"tool_call_id": tool_call.tool_id,
"name": tool.metadata.get_name(),
},
)
except Exception as e:
logger.error(f"Got error in tool {tool_call.tool_name}: {str(e)}")
if event_emitter:
event_emitter(f"Got error in tool {tool_call.tool_name}: {str(e)}")
return ChatMessage(
role=MessageRole.TOOL,
content=f"Error: {str(e)}",
additional_kwargs={
"tool_call_id": tool_call.tool_id,
"name": tool.metadata.get_name(),
},
)
async def _tool_call_generator(
llm: FunctionCallingLLM,
tools: list[BaseTool],
chat_history: list[ChatMessage],
) -> AsyncGenerator[ChatResponse | bool, None]:
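    # Protocol: the first yield is a bool indicating whether the LLM is making a
    # tool call. If False, the subsequent yields are streamed ChatResponse chunks.
    # If True, the single final yield is the complete ChatResponse with tool calls.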
response_stream = await llm.astream_chat_with_tools(
tools,
chat_history=chat_history,
allow_parallel_tool_calls=False,
)
full_response = None
yielded_indicator = False
async for chunk in response_stream:
if "tool_calls" not in chunk.message.additional_kwargs:
# Yield a boolean to indicate whether the response is a tool call
if not yielded_indicator:
yield False
yielded_indicator = True
# if not a tool call, yield the chunks!
yield chunk # type: ignore
elif not yielded_indicator:
# Yield the indicator for a tool call
yield True
yielded_indicator = True
full_response = chunk
if full_response:
yield full_response # type: ignore
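Taken together, these helpers support a simple ask-then-execute round. A sketch of one round outside any workflow (the agent name "Sketch" is an arbitrary label for the emitted events):

async def run_one_round(
    ctx: Context,
    llm: FunctionCallingLLM,
    tools: list[BaseTool],
    chat_history: list[ChatMessage],
) -> list[ChatMessage]:
    # Ask the LLM; it either streams text or requests tool calls.
    response = await chat_with_tools(llm, tools, chat_history)
    if not response.has_tool_calls():
        # Text answer: drain the stream into a single assistant message.
        return [
            ChatMessage(
                role=MessageRole.ASSISTANT, content=await response.full_response()
            )
        ]
    # Tool calls: record the tool-call message, execute the tools,
    # and return their outputs as TOOL messages.
    messages = [response.tool_call_message]
    messages += await call_tools(
        ctx=ctx, agent_name="Sketch", tools=tools, tool_calls=response.tool_calls
    )
    return messages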