Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Llama Index
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
run-llama
Llama Index
Commits
60b75cb0
Unverified
Commit
60b75cb0
authored
1 year ago
by
Jerry Liu
Committed by
GitHub
1 year ago
Browse files
Options
Downloads
Patches
Plain Diff
fix agent reset (#10562)
parent
f6a71735
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
llama_index/agent/runner/base.py
+5
-0
5 additions, 0 deletions
llama_index/agent/runner/base.py
tests/agent/runner/test_base.py
+65
-1
65 additions, 1 deletion
tests/agent/runner/test_base.py
with
70 additions
and
1 deletion
llama_index/agent/runner/base.py
+
5
−
0
View file @
60b75cb0
...
@@ -173,6 +173,10 @@ class AgentState(BaseModel):
...
@@ -173,6 +173,10 @@ class AgentState(BaseModel):
"""
Get step queue.
"""
"""
Get step queue.
"""
return
self
.
task_dict
[
task_id
].
step_queue
return
self
.
task_dict
[
task_id
].
step_queue
def
reset
(
self
)
->
None
:
"""
Reset.
"""
self
.
task_dict
=
{}
class
AgentRunner
(
BaseAgentRunner
):
class
AgentRunner
(
BaseAgentRunner
):
"""
Agent runner.
"""
Agent runner.
...
@@ -246,6 +250,7 @@ class AgentRunner(BaseAgentRunner):
...
@@ -246,6 +250,7 @@ class AgentRunner(BaseAgentRunner):
def
reset
(
self
)
->
None
:
def
reset
(
self
)
->
None
:
self
.
memory
.
reset
()
self
.
memory
.
reset
()
self
.
state
.
reset
()
def
create_task
(
self
,
input
:
str
,
**
kwargs
:
Any
)
->
Task
:
def
create_task
(
self
,
input
:
str
,
**
kwargs
:
Any
)
->
Task
:
"""
Create task.
"""
"""
Create task.
"""
...
...
This diff is collapsed.
Click to expand it.
tests/agent/runner/test_base.py
+
65
−
1
View file @
60b75cb0
"""
Test agent executor.
"""
"""
Test agent executor.
"""
import
uuid
import
uuid
from
typing
import
Any
from
typing
import
Any
,
cast
from
llama_index.agent.runner.base
import
AgentRunner
from
llama_index.agent.runner.base
import
AgentRunner
from
llama_index.agent.runner.parallel
import
ParallelAgentRunner
from
llama_index.agent.runner.parallel
import
ParallelAgentRunner
from
llama_index.agent.types
import
BaseAgentWorker
,
Task
,
TaskStep
,
TaskStepOutput
from
llama_index.agent.types
import
BaseAgentWorker
,
Task
,
TaskStep
,
TaskStepOutput
from
llama_index.chat_engine.types
import
AgentChatResponse
from
llama_index.chat_engine.types
import
AgentChatResponse
from
llama_index.core.llms.types
import
ChatMessage
,
MessageRole
# define mock agent worker
# define mock agent worker
...
@@ -64,6 +65,49 @@ class MockAgentWorker(BaseAgentWorker):
...
@@ -64,6 +65,49 @@ class MockAgentWorker(BaseAgentWorker):
"""
Finalize task, after all the steps are completed.
"""
"""
Finalize task, after all the steps are completed.
"""
# define mock agent worker
class
MockAgentWorkerWithMemory
(
MockAgentWorker
):
"""
Mock agent worker with memory.
"""
def
__init__
(
self
,
limit
:
int
=
2
):
"""
Initialize.
"""
self
.
limit
=
limit
def
initialize_step
(
self
,
task
:
Task
,
**
kwargs
:
Any
)
->
TaskStep
:
"""
Initialize step from task.
"""
# counter will be set to the last value in memory
if
len
(
task
.
memory
.
get
())
>
0
:
start
=
int
(
cast
(
Any
,
task
.
memory
.
get
()[
-
1
].
content
))
else
:
start
=
0
task
.
extra_state
[
"
counter
"
]
=
0
task
.
extra_state
[
"
start
"
]
=
start
return
TaskStep
(
task_id
=
task
.
task_id
,
step_id
=
str
(
uuid
.
uuid4
()),
input
=
task
.
input
,
memory
=
task
.
memory
,
)
def
run_step
(
self
,
step
:
TaskStep
,
task
:
Task
,
**
kwargs
:
Any
)
->
TaskStepOutput
:
"""
Run step.
"""
task
.
extra_state
[
"
counter
"
]
+=
1
counter
=
task
.
extra_state
[
"
counter
"
]
+
task
.
extra_state
[
"
start
"
]
is_done
=
task
.
extra_state
[
"
counter
"
]
>=
self
.
limit
new_steps
=
[
step
.
get_next_step
(
step_id
=
str
(
uuid
.
uuid4
()))]
if
is_done
:
task
.
memory
.
put
(
ChatMessage
(
role
=
MessageRole
.
USER
,
content
=
str
(
counter
)))
return
TaskStepOutput
(
output
=
AgentChatResponse
(
response
=
f
"
counter:
{
counter
}
"
),
task_step
=
step
,
is_last
=
is_done
,
next_steps
=
new_steps
,
)
# define mock agent worker
# define mock agent worker
class
MockForkStepEngine
(
BaseAgentWorker
):
class
MockForkStepEngine
(
BaseAgentWorker
):
"""
Mock agent worker that adds an exponential # steps.
"""
"""
Mock agent worker that adds an exponential # steps.
"""
...
@@ -167,6 +211,26 @@ def test_agent() -> None:
...
@@ -167,6 +211,26 @@ def test_agent() -> None:
assert
len
(
agent_runner
.
state
.
task_dict
)
==
1
assert
len
(
agent_runner
.
state
.
task_dict
)
==
1
def
test_agent_with_reset
()
->
None
:
"""
Test agents with reset.
"""
# test e2e chat
# NOTE: to use chat, output needs to be AgentChatResponse
agent_runner
=
AgentRunner
(
agent_worker
=
MockAgentWorkerWithMemory
(
limit
=
10
))
for
idx
in
range
(
4
):
if
idx
%
2
==
0
:
agent_runner
.
reset
()
response
=
agent_runner
.
chat
(
"
hello world
"
)
if
idx
%
2
==
0
:
assert
str
(
response
)
==
"
counter: 10
"
assert
len
(
agent_runner
.
state
.
task_dict
)
==
1
assert
len
(
agent_runner
.
memory
.
get
())
==
1
elif
idx
%
2
==
1
:
assert
str
(
response
)
==
"
counter: 20
"
assert
len
(
agent_runner
.
state
.
task_dict
)
==
2
assert
len
(
agent_runner
.
memory
.
get
())
==
2
def
test_dag_agent
()
->
None
:
def
test_dag_agent
()
->
None
:
"""
Test DAG agent executor.
"""
"""
Test DAG agent executor.
"""
agent_runner
=
ParallelAgentRunner
(
agent_worker
=
MockForkStepEngine
(
limit
=
2
))
agent_runner
=
ParallelAgentRunner
(
agent_worker
=
MockForkStepEngine
(
limit
=
2
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment