Skip to content
Snippets Groups Projects
Unverified Commit 857c5228 authored by Logan's avatar Logan Committed by GitHub
Browse files

fix agentworkflow tool call tracking (#17968)

parent 3c01c98d
Branches
Tags
No related merge requests found
...@@ -345,6 +345,12 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): ...@@ -345,6 +345,12 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
memory: BaseMemory = await ctx.get("memory") memory: BaseMemory = await ctx.get("memory")
output = await agent.finalize(ctx, ev, memory) output = await agent.finalize(ctx, ev, memory)
cur_tool_calls: List[ToolCallResult] = await ctx.get(
"current_tool_calls", default=[]
)
output.tool_calls.extend(cur_tool_calls) # type: ignore
await ctx.set("current_tool_calls", [])
return StopEvent(result=output) return StopEvent(result=output)
await ctx.set("num_tool_calls", len(ev.tool_calls)) await ctx.set("num_tool_calls", len(ev.tool_calls))
...@@ -417,6 +423,13 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): ...@@ -417,6 +423,13 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
agent_name: str = await ctx.get("current_agent_name") agent_name: str = await ctx.get("current_agent_name")
agent: BaseWorkflowAgent = self.agents[agent_name] agent: BaseWorkflowAgent = self.agents[agent_name]
# track tool calls made during a .run() call
cur_tool_calls: List[ToolCallResult] = await ctx.get(
"current_tool_calls", default=[]
)
cur_tool_calls.extend(tool_call_results)
await ctx.set("current_tool_calls", cur_tool_calls)
await agent.handle_tool_call_results(ctx, tool_call_results, memory) await agent.handle_tool_call_results(ctx, tool_call_results, memory)
# set the next agent, if needed # set the next agent, if needed
...@@ -447,7 +460,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): ...@@ -447,7 +460,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
tool_name=t.tool_name, tool_name=t.tool_name,
tool_kwargs=t.tool_kwargs, tool_kwargs=t.tool_kwargs,
) )
for t in tool_call_results for t in cur_tool_calls
], ],
raw=str(return_direct_tool.tool_output.raw_output), raw=str(return_direct_tool.tool_output.raw_output),
current_agent_name=agent.name, current_agent_name=agent.name,
...@@ -456,6 +469,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): ...@@ -456,6 +469,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
# we don't want to stop the system if we're just handing off # we don't want to stop the system if we're just handing off
if return_direct_tool.tool_name != "handoff": if return_direct_tool.tool_name != "handoff":
await ctx.set("current_tool_calls", [])
return StopEvent(result=result) return StopEvent(result=result)
user_msg_str = await ctx.get("user_msg_str") user_msg_str = await ctx.get("user_msg_str")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment