diff --git a/.changeset/yellow-walls-happen.md b/.changeset/yellow-walls-happen.md new file mode 100644 index 0000000000000000000000000000000000000000..2385a69f5aead09a74249f2af21c12f2f8ef5283 --- /dev/null +++ b/.changeset/yellow-walls-happen.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/workflow": patch +--- + +Fix: multi-agent handover diff --git a/packages/workflow/src/agent/agent-workflow.ts b/packages/workflow/src/agent/agent-workflow.ts index a1a4a579df6c7ba8be190f0513978232cd062beb..0b2117c269c4188a6df21d922bae23b12051dea9 100644 --- a/packages/workflow/src/agent/agent-workflow.ts +++ b/packages/workflow/src/agent/agent-workflow.ts @@ -163,6 +163,24 @@ export class AgentWorkflow { this.addAgents(processedAgents); } + private addAgents(agents: BaseWorkflowAgent[]): void { + const agentNames = new Set(agents.map((a) => a.name)); + if (agentNames.size !== agents.length) { + throw new Error("The agent names must be unique!"); + } + + agents.forEach((agent) => { + this.agents.set(agent.name, agent); + }); + + if (agents.length > 1) { + agents.forEach((agent) => { + this.validateAgent(agent); + this.addHandoffTool(agent); + }); + } + } + private validateAgent(agent: BaseWorkflowAgent) { // Validate that all canHandoffTo agents exist const invalidAgents = agent.canHandoffTo.filter( @@ -176,7 +194,14 @@ export class AgentWorkflow { } private addHandoffTool(agent: BaseWorkflowAgent) { - const handoffTool = createHandoffTool(this.agents); + if (agent.tools.some((t) => t.metadata.name === "handOff")) { + return; + } + const toHandoffAgents: Map<string, BaseWorkflowAgent> = new Map(); + agent.canHandoffTo.forEach((name) => { + toHandoffAgents.set(name, this.agents.get(name)!); + }); + const handoffTool = createHandoffTool(toHandoffAgents); if ( agent.canHandoffTo.length > 0 && !agent.tools.some((t) => t.metadata.name === handoffTool.metadata.name) @@ -185,24 +210,6 @@ export class AgentWorkflow { } } - private addAgents(agents: BaseWorkflowAgent[]): void { - const agentNames = new Set(agents.map((a) => a.name)); - if (agentNames.size !== agents.length) { - throw new Error("The agent names must be unique!"); - } - - // First pass: add all agents to the map - agents.forEach((agent) => { - this.agents.set(agent.name, agent); - }); - - // Second pass: validate and setup handoff tools - agents.forEach((agent) => { - this.validateAgent(agent); - this.addHandoffTool(agent); - }); - } - /** * Adds a new agent to the workflow */ @@ -226,7 +233,6 @@ export class AgentWorkflow { * @param params - Parameters for the single agent workflow * @returns A new AgentWorkflow instance */ - static fromTools(params: SingleAgentParams): AgentWorkflow { const agent = new FunctionAgent({ name: params.name, @@ -234,6 +240,7 @@ export class AgentWorkflow { tools: params.tools, llm: params.llm, systemPrompt: params.systemPrompt, + canHandoffTo: params.canHandoffTo, }); const workflow = new AgentWorkflow({ diff --git a/packages/workflow/test/agent-workflow.test.ts b/packages/workflow/test/agent-workflow.test.ts index e467d691c9c845eaebb7c814faaa8b54aa724a07..cb9cc895da92043c3509ea65a446126a5bf8d80a 100644 --- a/packages/workflow/test/agent-workflow.test.ts +++ b/packages/workflow/test/agent-workflow.test.ts @@ -3,7 +3,7 @@ import { FunctionTool } from "@llamaindex/core/tools"; import { MockLLM } from "@llamaindex/core/utils"; import { describe, expect, test, vi } from "vitest"; import { z } from "zod"; -import { AgentWorkflow, FunctionAgent } from "../src/agent"; +import { AgentWorkflow, FunctionAgent, agent, multiAgent } from "../src/agent"; import { setupToolCallingMockLLM } from "./mock"; describe("AgentWorkflow", () => { @@ -157,3 +157,125 @@ describe("AgentWorkflow", () => { // }); }); + +describe("Multiple agents", () => { + test("multiple agents are set up correctly with handoff capabilities", () => { + // Create mock LLM + const mockLLM = new MockLLM(); + mockLLM.supportToolCall = true; + + // Create tools for agents + const addTool = FunctionTool.from( + (params: { x: number; y: number }) => params.x + params.y, + { + name: "add", + description: "Adds two numbers", + parameters: z.object({ + x: z.number(), + y: z.number(), + }), + }, + ); + + const multiplyTool = FunctionTool.from( + (params: { x: number; y: number }) => params.x * params.y, + { + name: "multiply", + description: "Multiplies two numbers", + parameters: z.object({ + x: z.number(), + y: z.number(), + }), + }, + ); + + const subtractTool = FunctionTool.from( + (params: { x: number; y: number }) => params.x - params.y, + { + name: "subtract", + description: "Subtracts two numbers", + parameters: z.object({ + x: z.number(), + y: z.number(), + }), + }, + ); + + // Create agents using the agent() function + const addAgent = agent({ + name: "AddAgent", + description: "Agent that can add numbers", + tools: [addTool], + llm: mockLLM, + }); + + const multiplyAgent = agent({ + name: "MultiplyAgent", + description: "Agent that can multiply numbers", + tools: [multiplyTool], + llm: mockLLM, + }); + + const mathAgent = agent({ + name: "MathAgent", + description: "Agent that can do various math operations", + tools: [addTool, multiplyTool, subtractTool], + llm: mockLLM, + canHandoffTo: ["AddAgent", "MultiplyAgent"], + }); + + // Create workflow with multiple agents using multiAgent + const workflow = multiAgent({ + agents: [mathAgent, addAgent, multiplyAgent], + rootAgent: mathAgent, + verbose: false, + }); + + // Verify agents are set up correctly + expect(workflow).toBeDefined(); + expect(workflow.getAgents().length).toBe(3); + + // Verify that the mathAgent has a handoff tool + const mathAgentInstance = workflow + .getAgents() + .find((agent) => agent.name === "MathAgent"); + expect(mathAgentInstance).toBeDefined(); + expect( + mathAgentInstance?.tools.some((tool) => tool.metadata.name === "handOff"), + ).toBe(true); + + // Verify that addAgent and multiplyAgent don't have handoff tools since they don't handoff to other agents + const addAgentInstance = workflow + .getAgents() + .find((agent) => agent.name === "AddAgent"); + expect(addAgentInstance).toBeDefined(); + expect( + addAgentInstance?.tools.some((tool) => tool.metadata.name === "handOff"), + ).toBe(false); + + const multiplyAgentInstance = workflow + .getAgents() + .find((agent) => agent.name === "MultiplyAgent"); + expect(multiplyAgentInstance).toBeDefined(); + expect( + multiplyAgentInstance?.tools.some( + (tool) => tool.metadata.name === "handOff", + ), + ).toBe(false); + + // Verify agent specific tools are preserved + expect( + mathAgentInstance?.tools.some( + (tool) => tool.metadata.name === "subtract", + ), + ).toBe(true); + expect( + addAgentInstance?.tools.some((tool) => tool.metadata.name === "add"), + ).toBe(true); + expect( + multiplyAgentInstance?.tools.some( + (tool) => tool.metadata.name === "multiply", + ), + ).toBe(true); + }); +});