Skip to content
Snippets Groups Projects
Unverified Commit 3dee9641 authored by Daniel Bustamante Ospina's avatar Daniel Bustamante Ospina Committed by GitHub
Browse files

Add support for running single-agent workflows within the BaseWorkflowAgent class (#18038)

parent 8b3e456c
No related branches found
No related tags found
No related merge requests found
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence, Optional, Union
from typing import Callable, List, Sequence, Optional, Union, Any
from llama_index.core.agent.workflow.workflow_events import (
AgentOutput,
......@@ -18,6 +18,8 @@ from llama_index.core.tools import BaseTool, AsyncBaseTool, FunctionTool
from llama_index.core.workflow import Context
from llama_index.core.objects import ObjectRetriever
from llama_index.core.settings import Settings
from llama_index.core.workflow.checkpointer import CheckpointCallback
from llama_index.core.workflow.handler import WorkflowHandler
def get_default_llm() -> LLM:
......@@ -109,3 +111,16 @@ class BaseWorkflowAgent(BaseModel, PromptMixin, ABC):
self, ctx: Context, output: AgentOutput, memory: BaseMemory
) -> AgentOutput:
"""Finalize the agent's execution."""
@abstractmethod
def run(
    self,
    user_msg: Optional[Union[str, ChatMessage]] = None,
    chat_history: Optional[List[ChatMessage]] = None,
    memory: Optional[BaseMemory] = None,
    ctx: Optional[Context] = None,
    stepwise: bool = False,
    checkpoint_callback: Optional[CheckpointCallback] = None,
    **workflow_kwargs: Any,
) -> WorkflowHandler:
    """Run the agent.

    Args:
        user_msg: Optional user message (string or ChatMessage) to start the run with.
        chat_history: Optional list of prior messages seeding the conversation.
        memory: Optional memory implementation for the run.
        ctx: Optional existing workflow Context to run within.
        stepwise: If True, execute the underlying workflow step by step.
        checkpoint_callback: Optional callback invoked for workflow checkpoints.
        **workflow_kwargs: Extra keyword arguments forwarded to the workflow
            constructed by the concrete implementation.

    Returns:
        WorkflowHandler: Handle used to await the result and stream events.

    NOTE(review): abstract — exact interplay of ``memory`` vs ``chat_history``
    is defined by the implementing class; confirm there before relying on it.
    """
from typing import List, Sequence
from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent
from llama_index.core.agent.workflow.single_agent_workflow import SingleAgentRunnerMixin
from llama_index.core.agent.workflow.workflow_events import (
AgentInput,
AgentOutput,
......@@ -15,7 +16,7 @@ from llama_index.core.tools import AsyncBaseTool
from llama_index.core.workflow import Context
class FunctionAgent(BaseWorkflowAgent):
class FunctionAgent(SingleAgentRunnerMixin, BaseWorkflowAgent):
"""Function calling agent implementation."""
scratchpad_key: str = "scratchpad"
......
......@@ -10,6 +10,7 @@ from llama_index.core.agent.react.types import (
ResponseReasoningStep,
)
from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent
from llama_index.core.agent.workflow.single_agent_workflow import SingleAgentRunnerMixin
from llama_index.core.agent.workflow.workflow_events import (
AgentInput,
AgentOutput,
......@@ -32,7 +33,7 @@ def default_formatter() -> ReActChatFormatter:
return ReActChatFormatter.from_defaults(context="some context")
class ReActAgent(BaseWorkflowAgent):
class ReActAgent(SingleAgentRunnerMixin, BaseWorkflowAgent):
"""React agent implementation."""
reasoning_key: str = "current_reasoning"
......
from abc import ABC
from typing import Any, List, Optional, Union, TypeVar
from llama_index.core.llms import ChatMessage
from llama_index.core.memory import BaseMemory
from llama_index.core.workflow import (
Context,
)
from llama_index.core.workflow.checkpointer import CheckpointCallback
from llama_index.core.workflow.handler import WorkflowHandler
# Bound to BaseWorkflowAgent so `self: T` in the mixin's `run` keeps the
# concrete agent subtype; the string forward reference avoids importing the
# base class here (the ignore silences the unresolved-name check).
T = TypeVar("T", bound="BaseWorkflowAgent")  # type: ignore[name-defined]
class SingleAgentRunnerMixin(ABC):
    """Mixin that lets one agent be executed as a standalone workflow.

    A class combining this mixin with ``BaseWorkflowAgent`` gains a ``run``
    method that wraps the agent in a one-member ``AgentWorkflow`` and
    delegates execution to it.
    """

    def run(
        self: T,
        user_msg: Optional[Union[str, ChatMessage]] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        ctx: Optional[Context] = None,
        stepwise: bool = False,
        checkpoint_callback: Optional[CheckpointCallback] = None,
        **workflow_kwargs: Any,
    ) -> WorkflowHandler:
        """Run this agent via a single-agent ``AgentWorkflow``.

        All arguments are forwarded to ``AgentWorkflow.run``; extra
        ``workflow_kwargs`` go to the ``AgentWorkflow`` constructor.
        """
        # Imported lazily: the workflow package imports this module, so a
        # top-level import would be circular.
        from llama_index.core.agent.workflow import AgentWorkflow

        single_agent_workflow = AgentWorkflow(agents=[self], **workflow_kwargs)
        handler = single_agent_workflow.run(
            user_msg=user_msg,
            chat_history=chat_history,
            memory=memory,
            ctx=ctx,
            stepwise=stepwise,
            checkpoint_callback=checkpoint_callback,
        )
        return handler
from typing import List, Any
import pytest
from llama_index.core.agent.workflow import FunctionAgent, ReActAgent
from llama_index.core.base.llms.types import (
ChatMessage,
LLMMetadata,
ChatResponseAsyncGen,
ChatResponse,
MessageRole,
)
from llama_index.core.llms import MockLLM
from llama_index.core.llms.llm import ToolSelection
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool
class MockLLM(MockLLM):
    """Canned-response LLM used to drive the agents deterministically.

    Cycles round-robin through ``responses``; every streaming call yields the
    next message as a single-chunk async stream.

    NOTE(review): the class deliberately shadows the imported ``MockLLM`` so
    the fixtures below pick up this test double; renaming it would break them.
    """

    def __init__(self, responses: List[ChatMessage]):
        super().__init__()
        self._responses = responses
        self._response_index = 0

    @property
    def metadata(self) -> LLMMetadata:
        # Advertise function-calling support so FunctionAgent accepts this LLM.
        return LLMMetadata(is_function_calling_model=True)

    def _stream_next_response(self) -> ChatResponseAsyncGen:
        """Return a single-chunk async stream carrying the next canned message.

        Shared by ``astream_chat`` and ``astream_chat_with_tools``, which were
        previously verbatim copies of this logic.
        """
        response_msg = None
        if self._responses:
            response_msg = self._responses[self._response_index]
            self._response_index = (self._response_index + 1) % len(self._responses)

        async def _gen():
            if response_msg:
                yield ChatResponse(
                    message=response_msg,
                    delta=response_msg.content,
                    raw={"content": response_msg.content},
                )

        return _gen()

    async def astream_chat(
        self, messages: List[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        return self._stream_next_response()

    async def astream_chat_with_tools(
        self, tools: List[Any], chat_history: List[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        return self._stream_next_response()

    def get_tool_calls_from_response(
        self, response: ChatResponse, **kwargs: Any
    ) -> List[ToolSelection]:
        # Tool calls, when present, are stashed in the message's
        # additional_kwargs by the test scenario.
        return response.message.additional_kwargs.get("tool_calls", [])
@pytest.fixture()
def function_agent():
    """FunctionAgent whose mock LLM always answers with a fixed success message."""
    mock_llm = MockLLM(
        responses=[
            ChatMessage(
                role=MessageRole.ASSISTANT, content="Success with the FunctionAgent"
            )
        ],
    )
    return FunctionAgent(
        name="retriever",
        description="Manages data retrieval",
        system_prompt="You are a retrieval assistant.",
        llm=mock_llm,
    )
def add(a: int, b: int) -> int:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """Return ``a`` minus ``b``."""
    difference = a - b
    return difference
@pytest.fixture()
def calculator_agent():
    """ReActAgent wired with add/subtract tools and scripted ReAct responses."""
    return ReActAgent(
        name="calculator",
        description="Performs basic arithmetic operations",
        system_prompt="You are a calculator assistant.",
        tools=[
            FunctionTool.from_defaults(fn=add),
            FunctionTool.from_defaults(fn=subtract),
        ],
        llm=MockLLM(
            responses=[
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content='Thought: I need to add these numbers\nAction: add\nAction Input: {"a": 5, "b": 3}\n',
                ),
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    # "Answer:" must start on its own line for the ReAct output
                    # parser to terminate the loop; the previous raw string
                    # (r"...8\Answer...") produced a literal backslash-A
                    # instead of a newline.
                    content="Thought: The result is 8\nAnswer: The sum is 8",
                ),
            ]
        ),
    )
@pytest.mark.asyncio()
async def test_single_function_agent(function_agent):
    """Running a FunctionAgent directly yields the mocked success response."""
    handler = function_agent.run(user_msg="test")
    # Drain the event stream so the workflow can progress to completion.
    async for _event in handler.stream_events():
        pass
    result = await handler
    assert "Success with the FunctionAgent" in str(result.response)
@pytest.mark.asyncio()
async def test_single_react_agent(calculator_agent):
    """Running a ReActAgent directly produces the scripted final answer."""
    memory = ChatMemoryBuffer.from_defaults()
    handler = calculator_agent.run(user_msg="Can you add 5 and 3?", memory=memory)
    # Collect (and thereby drain) the event stream so the run completes.
    collected = [event async for event in handler.stream_events()]
    result = await handler
    assert "8" in str(result.response)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment