From 857c5228543541a52c25eaa09b47c09adf4a576e Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Fri, 28 Feb 2025 15:25:08 -0600 Subject: [PATCH] fix agentworkflow tool call tracking (#17968) --- .../core/agent/workflow/multi_agent_workflow.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py index 4e3a9bc552..9ab4f69134 100644 --- a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py +++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py @@ -345,6 +345,12 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): memory: BaseMemory = await ctx.get("memory") output = await agent.finalize(ctx, ev, memory) + cur_tool_calls: List[ToolCallResult] = await ctx.get( + "current_tool_calls", default=[] + ) + output.tool_calls.extend(cur_tool_calls) # type: ignore + await ctx.set("current_tool_calls", []) + return StopEvent(result=output) await ctx.set("num_tool_calls", len(ev.tool_calls)) @@ -417,6 +423,13 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): agent_name: str = await ctx.get("current_agent_name") agent: BaseWorkflowAgent = self.agents[agent_name] + # track tool calls made during a .run() call + cur_tool_calls: List[ToolCallResult] = await ctx.get( + "current_tool_calls", default=[] + ) + cur_tool_calls.extend(tool_call_results) + await ctx.set("current_tool_calls", cur_tool_calls) + await agent.handle_tool_call_results(ctx, tool_call_results, memory) # set the next agent, if needed @@ -447,7 +460,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): tool_name=t.tool_name, tool_kwargs=t.tool_kwargs, ) - for t in tool_call_results + for t in cur_tool_calls ], raw=str(return_direct_tool.tool_output.raw_output), current_agent_name=agent.name, @@ -456,6 +469,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): # we don't want to stop the system if we're just handing off if return_direct_tool.tool_name != "handoff": + await ctx.set("current_tool_calls", []) return StopEvent(result=result) user_msg_str = await ctx.get("user_msg_str") -- GitLab