From 857c5228543541a52c25eaa09b47c09adf4a576e Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Fri, 28 Feb 2025 15:25:08 -0600
Subject: [PATCH] fix agentworkflow tool call tracking (#17968)

---
 .../core/agent/workflow/multi_agent_workflow.py  | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py
index 4e3a9bc552..9ab4f69134 100644
--- a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py
+++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py
@@ -345,6 +345,12 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
             memory: BaseMemory = await ctx.get("memory")
             output = await agent.finalize(ctx, ev, memory)
 
+            cur_tool_calls: List[ToolCallResult] = await ctx.get(
+                "current_tool_calls", default=[]
+            )
+            output.tool_calls.extend(cur_tool_calls)  # type: ignore
+            await ctx.set("current_tool_calls", [])
+
             return StopEvent(result=output)
 
         await ctx.set("num_tool_calls", len(ev.tool_calls))
@@ -417,6 +423,13 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
         agent_name: str = await ctx.get("current_agent_name")
         agent: BaseWorkflowAgent = self.agents[agent_name]
 
+        # track tool calls made during a .run() call
+        cur_tool_calls: List[ToolCallResult] = await ctx.get(
+            "current_tool_calls", default=[]
+        )
+        cur_tool_calls.extend(tool_call_results)
+        await ctx.set("current_tool_calls", cur_tool_calls)
+
         await agent.handle_tool_call_results(ctx, tool_call_results, memory)
 
         # set the next agent, if needed
@@ -447,7 +460,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
                         tool_name=t.tool_name,
                         tool_kwargs=t.tool_kwargs,
                     )
-                    for t in tool_call_results
+                    for t in cur_tool_calls
                 ],
                 raw=str(return_direct_tool.tool_output.raw_output),
                 current_agent_name=agent.name,
@@ -456,6 +469,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
 
             # we don't want to stop the system if we're just handing off
             if return_direct_tool.tool_name != "handoff":
+                await ctx.set("current_tool_calls", [])
                 return StopEvent(result=result)
 
         user_msg_str = await ctx.get("user_msg_str")
-- 
GitLab