Skip to content
Snippets Groups Projects
Unverified Commit 60b75cb0 authored by Jerry Liu's avatar Jerry Liu Committed by GitHub
Browse files

fix agent reset (#10562)

parent f6a71735
Branches
Tags
No related merge requests found
...@@ -173,6 +173,10 @@ class AgentState(BaseModel): ...@@ -173,6 +173,10 @@ class AgentState(BaseModel):
"""Get step queue.""" """Get step queue."""
return self.task_dict[task_id].step_queue return self.task_dict[task_id].step_queue
def reset(self) -> None:
"""Reset."""
self.task_dict = {}
class AgentRunner(BaseAgentRunner): class AgentRunner(BaseAgentRunner):
"""Agent runner. """Agent runner.
...@@ -246,6 +250,7 @@ class AgentRunner(BaseAgentRunner): ...@@ -246,6 +250,7 @@ class AgentRunner(BaseAgentRunner):
def reset(self) -> None: def reset(self) -> None:
self.memory.reset() self.memory.reset()
self.state.reset()
def create_task(self, input: str, **kwargs: Any) -> Task: def create_task(self, input: str, **kwargs: Any) -> Task:
"""Create task.""" """Create task."""
......
"""Test agent executor.""" """Test agent executor."""
import uuid import uuid
from typing import Any from typing import Any, cast
from llama_index.agent.runner.base import AgentRunner from llama_index.agent.runner.base import AgentRunner
from llama_index.agent.runner.parallel import ParallelAgentRunner from llama_index.agent.runner.parallel import ParallelAgentRunner
from llama_index.agent.types import BaseAgentWorker, Task, TaskStep, TaskStepOutput from llama_index.agent.types import BaseAgentWorker, Task, TaskStep, TaskStepOutput
from llama_index.chat_engine.types import AgentChatResponse from llama_index.chat_engine.types import AgentChatResponse
from llama_index.core.llms.types import ChatMessage, MessageRole
# define mock agent worker # define mock agent worker
...@@ -64,6 +65,49 @@ class MockAgentWorker(BaseAgentWorker): ...@@ -64,6 +65,49 @@ class MockAgentWorker(BaseAgentWorker):
"""Finalize task, after all the steps are completed.""" """Finalize task, after all the steps are completed."""
# define mock agent worker
class MockAgentWorkerWithMemory(MockAgentWorker):
"""Mock agent worker with memory."""
def __init__(self, limit: int = 2):
"""Initialize."""
self.limit = limit
def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
"""Initialize step from task."""
# counter will be set to the last value in memory
if len(task.memory.get()) > 0:
start = int(cast(Any, task.memory.get()[-1].content))
else:
start = 0
task.extra_state["counter"] = 0
task.extra_state["start"] = start
return TaskStep(
task_id=task.task_id,
step_id=str(uuid.uuid4()),
input=task.input,
memory=task.memory,
)
def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
"""Run step."""
task.extra_state["counter"] += 1
counter = task.extra_state["counter"] + task.extra_state["start"]
is_done = task.extra_state["counter"] >= self.limit
new_steps = [step.get_next_step(step_id=str(uuid.uuid4()))]
if is_done:
task.memory.put(ChatMessage(role=MessageRole.USER, content=str(counter)))
return TaskStepOutput(
output=AgentChatResponse(response=f"counter: {counter}"),
task_step=step,
is_last=is_done,
next_steps=new_steps,
)
# define mock agent worker # define mock agent worker
class MockForkStepEngine(BaseAgentWorker): class MockForkStepEngine(BaseAgentWorker):
"""Mock agent worker that adds an exponential # steps.""" """Mock agent worker that adds an exponential # steps."""
...@@ -167,6 +211,26 @@ def test_agent() -> None: ...@@ -167,6 +211,26 @@ def test_agent() -> None:
assert len(agent_runner.state.task_dict) == 1 assert len(agent_runner.state.task_dict) == 1
def test_agent_with_reset() -> None:
"""Test agents with reset."""
# test e2e chat
# NOTE: to use chat, output needs to be AgentChatResponse
agent_runner = AgentRunner(agent_worker=MockAgentWorkerWithMemory(limit=10))
for idx in range(4):
if idx % 2 == 0:
agent_runner.reset()
response = agent_runner.chat("hello world")
if idx % 2 == 0:
assert str(response) == "counter: 10"
assert len(agent_runner.state.task_dict) == 1
assert len(agent_runner.memory.get()) == 1
elif idx % 2 == 1:
assert str(response) == "counter: 20"
assert len(agent_runner.state.task_dict) == 2
assert len(agent_runner.memory.get()) == 2
def test_dag_agent() -> None: def test_dag_agent() -> None:
"""Test DAG agent executor.""" """Test DAG agent executor."""
agent_runner = ParallelAgentRunner(agent_worker=MockForkStepEngine(limit=2)) agent_runner = ParallelAgentRunner(agent_worker=MockForkStepEngine(limit=2))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment