From 60b75cb014bccf60a153d5dc5295a91c7cdcd9f6 Mon Sep 17 00:00:00 2001 From: Jerry Liu <jerryjliu98@gmail.com> Date: Sat, 10 Feb 2024 12:08:09 -0800 Subject: [PATCH] fix agent reset (#10562) --- llama_index/agent/runner/base.py | 5 +++ tests/agent/runner/test_base.py | 66 +++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/llama_index/agent/runner/base.py b/llama_index/agent/runner/base.py index eaf81e39f1..bd05cdbaa4 100644 --- a/llama_index/agent/runner/base.py +++ b/llama_index/agent/runner/base.py @@ -173,6 +173,10 @@ class AgentState(BaseModel): """Get step queue.""" return self.task_dict[task_id].step_queue + def reset(self) -> None: + """Reset.""" + self.task_dict = {} + class AgentRunner(BaseAgentRunner): """Agent runner. @@ -246,6 +250,7 @@ class AgentRunner(BaseAgentRunner): def reset(self) -> None: self.memory.reset() + self.state.reset() def create_task(self, input: str, **kwargs: Any) -> Task: """Create task.""" diff --git a/tests/agent/runner/test_base.py b/tests/agent/runner/test_base.py index 374889392e..46cbe93d2d 100644 --- a/tests/agent/runner/test_base.py +++ b/tests/agent/runner/test_base.py @@ -1,12 +1,13 @@ """Test agent executor.""" import uuid -from typing import Any +from typing import Any, cast from llama_index.agent.runner.base import AgentRunner from llama_index.agent.runner.parallel import ParallelAgentRunner from llama_index.agent.types import BaseAgentWorker, Task, TaskStep, TaskStepOutput from llama_index.chat_engine.types import AgentChatResponse +from llama_index.core.llms.types import ChatMessage, MessageRole # define mock agent worker @@ -64,6 +65,49 @@ class MockAgentWorker(BaseAgentWorker): """Finalize task, after all the steps are completed.""" +# define mock agent worker +class MockAgentWorkerWithMemory(MockAgentWorker): + """Mock agent worker with memory.""" + + def __init__(self, limit: int = 2): + """Initialize.""" + self.limit = limit + + def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep: + """Initialize step from task.""" + # counter will be set to the last value in memory + if len(task.memory.get()) > 0: + start = int(cast(Any, task.memory.get()[-1].content)) + else: + start = 0 + task.extra_state["counter"] = 0 + task.extra_state["start"] = start + return TaskStep( + task_id=task.task_id, + step_id=str(uuid.uuid4()), + input=task.input, + memory=task.memory, + ) + + def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput: + """Run step.""" + task.extra_state["counter"] += 1 + counter = task.extra_state["counter"] + task.extra_state["start"] + is_done = task.extra_state["counter"] >= self.limit + + new_steps = [step.get_next_step(step_id=str(uuid.uuid4()))] + + if is_done: + task.memory.put(ChatMessage(role=MessageRole.USER, content=str(counter))) + + return TaskStepOutput( + output=AgentChatResponse(response=f"counter: {counter}"), + task_step=step, + is_last=is_done, + next_steps=new_steps, + ) + + # define mock agent worker class MockForkStepEngine(BaseAgentWorker): """Mock agent worker that adds an exponential # steps.""" @@ -167,6 +211,26 @@ def test_agent() -> None: assert len(agent_runner.state.task_dict) == 1 +def test_agent_with_reset() -> None: + """Test agents with reset.""" + # test e2e chat + # NOTE: to use chat, output needs to be AgentChatResponse + agent_runner = AgentRunner(agent_worker=MockAgentWorkerWithMemory(limit=10)) + for idx in range(4): + if idx % 2 == 0: + agent_runner.reset() + + response = agent_runner.chat("hello world") + if idx % 2 == 0: + assert str(response) == "counter: 10" + assert len(agent_runner.state.task_dict) == 1 + assert len(agent_runner.memory.get()) == 1 + elif idx % 2 == 1: + assert str(response) == "counter: 20" + assert len(agent_runner.state.task_dict) == 2 + assert len(agent_runner.memory.get()) == 2 + + def test_dag_agent() -> None: """Test DAG agent executor.""" agent_runner = ParallelAgentRunner(agent_worker=MockForkStepEngine(limit=2)) -- GitLab