diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py
index beee2fdc55e1c80e332b1cd32b29e6a4404f50e3..70187062fcc73e792253df21fb7f6530d549a5f5 100644
--- a/llama-index-core/llama_index/core/agent/react/step.py
+++ b/llama-index-core/llama_index/core/agent/react/step.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import uuid
+from functools import partial
 from itertools import chain
 from threading import Thread
 from typing import (
@@ -529,6 +530,7 @@ class ReActAgentWorker(BaseAgentWorker):
         thread = Thread(
             target=agent_response.write_response_to_history,
             args=(task.extra_state["new_memory"],),
+            kwargs={"on_stream_end_fn": partial(self.finalize_task, task)},
         )
         thread.start()
 
@@ -592,7 +594,8 @@ class ReActAgentWorker(BaseAgentWorker):
         # create task to write chat response to history
         asyncio.create_task(
             agent_response.awrite_response_to_history(
-                task.extra_state["new_memory"]
+                task.extra_state["new_memory"],
+                on_stream_end_fn=partial(self.finalize_task, task),
             )
         )
         # wait until response writing is done
@@ -628,7 +631,9 @@
     def finalize_task(self, task: Task, **kwargs: Any) -> None:
         """Finalize task, after all the steps are completed."""
         # add new messages to memory
-        task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
+        task.memory.set(
+            task.memory.get_all() + task.extra_state["new_memory"].get_all()
+        )
         # reset new memory
         task.extra_state["new_memory"].reset()
 
diff --git a/llama-index-core/llama_index/core/agent/runner/base.py b/llama-index-core/llama_index/core/agent/runner/base.py
index bfdc17b94b43b6f411851cf31724db0c2ad88952..3419b1e37bff8f10f8a41683b14918579e16548a 100644
--- a/llama-index-core/llama_index/core/agent/runner/base.py
+++ b/llama-index-core/llama_index/core/agent/runner/base.py
@@ -528,7 +528,10 @@ class AgentRunner(BaseAgentRunner):
             # ensure tool_choice does not cause endless loops
             tool_choice = "auto"
 
-        return self.finalize_response(task.task_id, result_output)
+        return self.finalize_response(
+            task.task_id,
+            result_output,
+        )
 
     async def _achat(
         self,
@@ -556,7 +559,10 @@
             # ensure tool_choice does not cause endless loops
             tool_choice = "auto"
 
-        return self.finalize_response(task.task_id, result_output)
+        return self.finalize_response(
+            task.task_id,
+            result_output,
+        )
 
     @trace_method("chat")
     def chat(
diff --git a/llama-index-core/llama_index/core/agent/runner/parallel.py b/llama-index-core/llama_index/core/agent/runner/parallel.py
index 5c5ae0ace2fdab6324e9a6d380eee67ca3433b8e..55d0adfde25476d3e814b39cd970866af05e5a0b 100644
--- a/llama-index-core/llama_index/core/agent/runner/parallel.py
+++ b/llama-index-core/llama_index/core/agent/runner/parallel.py
@@ -361,7 +361,10 @@ class ParallelAgentRunner(BaseAgentRunner):
                 result_output = cur_step_output
                 break
 
-        return self.finalize_response(task.task_id, result_output)
+        return self.finalize_response(
+            task.task_id,
+            result_output,
+        )
 
     async def _achat(
         self,
@@ -393,7 +396,10 @@
                 result_output = cur_step_output
                 break
 
-        return self.finalize_response(task.task_id, result_output)
+        return self.finalize_response(
+            task.task_id,
+            result_output,
+        )
 
     @trace_method("chat")
     def chat(
diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py
index 124e1f5d6bca491d4bad26616fb2b231762b0501..29d27731bc284f3bd88919d84c4d7702f8711a72 100644
--- a/llama-index-core/llama_index/core/chat_engine/types.py
+++ b/llama-index-core/llama_index/core/chat_engine/types.py
@@ -100,7 +100,10 @@ class StreamingAgentChatResponse:
         self._new_item_event.set()
 
     def write_response_to_history(
-        self, memory: BaseMemory, raise_error: bool = False
+        self,
+        memory: BaseMemory,
+        on_stream_end_fn: Optional[callable] = None,
+        raise_error: bool = False,
     ) -> None:
         if self.chat_stream is None:
             raise ValueError(
@@ -131,10 +134,13 @@
 
         # This act as is_done events for any consumers waiting
         self._is_function_not_none_thread_event.set()
+        if on_stream_end_fn is not None and not self._is_function:
+            on_stream_end_fn()
 
     async def awrite_response_to_history(
         self,
         memory: BaseMemory,
+        on_stream_end_fn: Optional[callable] = None,
     ) -> None:
         if self.achat_stream is None:
             raise ValueError(
@@ -164,6 +170,8 @@
             # These act as is_done events for any consumers waiting
             self._is_function_false_event.set()
             self._new_item_event.set()
+        if on_stream_end_fn is not None and not self._is_function:
+            on_stream_end_fn()
 
     @property
     def response_gen(self) -> Generator[str, None, None]:
diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
index 08d13368aa11415ff07b75ee43447490ddb92b4a..3a3f7b22ea33e3ee4ddbc1c4c0cf8cf4a6e0bfbf 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
@@ -4,6 +4,7 @@ import asyncio
 import json
 import logging
 import uuid
+from functools import partial
 from threading import Thread
 from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
 
@@ -285,6 +286,7 @@ class OpenAIAgentWorker(BaseAgentWorker):
         thread = Thread(
             target=chat_stream_response.write_response_to_history,
             args=(task.extra_state["new_memory"],),
+            kwargs={"on_stream_end_fn": partial(self.finalize_task, task)},
         )
         thread.start()
         # Wait for the event to be set
@@ -306,7 +308,8 @@
         # create task to write chat response to history
         asyncio.create_task(
             chat_stream_response.awrite_response_to_history(
-                task.extra_state["new_memory"]
+                task.extra_state["new_memory"],
+                on_stream_end_fn=partial(self.finalize_task, task),
             )
         )
         # wait until openAI functions stop executing
@@ -605,7 +608,9 @@
     def finalize_task(self, task: Task, **kwargs: Any) -> None:
         """Finalize task, after all the steps are completed."""
         # add new messages to memory
-        task.memory.set(task.memory.get() + task.extra_state["new_memory"].get_all())
+        task.memory.set(
+            task.memory.get_all() + task.extra_state["new_memory"].get_all()
+        )
         # reset new memory
         task.extra_state["new_memory"].reset()
 
diff --git a/llama-index-integrations/agent/llama-index-agent-openai/tests/test_openai_agent.py b/llama-index-integrations/agent/llama-index-agent-openai/tests/test_openai_agent.py
index e1665ffc17ca795337c8dbf5530a8c2530268dd4..1e860584c9e693b101ba1f80fc6aecb1dbbb4a02 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai/tests/test_openai_agent.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai/tests/test_openai_agent.py
@@ -148,6 +148,9 @@ def test_chat_basic(MockSyncOpenAI: MagicMock, add_tool: FunctionTool) -> None:
     response = agent.chat("What is 1 + 1?")
     assert isinstance(response, AgentChatResponse)
     assert response.response == "\n\nThis is a test!"
+    assert len(agent.chat_history) == 2
+    assert agent.chat_history[0].content == "What is 1 + 1?"
+    assert agent.chat_history[1].content == "\n\nThis is a test!"
 
 
 @patch("llama_index.llms.openai.base.AsyncOpenAI")
@@ -165,6 +168,9 @@ async def test_achat_basic(MockAsyncOpenAI: MagicMock, add_tool: FunctionTool) -
     response = await agent.achat("What is 1 + 1?")
     assert isinstance(response, AgentChatResponse)
     assert response.response == "\n\nThis is a test!"
+    assert len(agent.chat_history) == 2
+    assert agent.chat_history[0].content == "What is 1 + 1?"
+    assert agent.chat_history[1].content == "\n\nThis is a test!"
 
 
 @patch("llama_index.llms.openai.base.SyncOpenAI")
@@ -182,6 +188,9 @@ def test_stream_chat_basic(MockSyncOpenAI: MagicMock, add_tool: FunctionTool) ->
     assert isinstance(response, StreamingAgentChatResponse)
     # str() strips newline values
     assert str(response) == "This is a test!"
+    assert len(agent.chat_history) == 2
+    assert agent.chat_history[0].content == "What is 1 + 1?"
+    assert agent.chat_history[1].content == "This is a test!"
 
 
 @patch("llama_index.llms.openai.base.AsyncOpenAI")
@@ -204,6 +213,9 @@ async def test_astream_chat_basic(
     assert isinstance(response_stream, StreamingAgentChatResponse)
     # str() strips newline values
     assert response == "\n\nThis is a test!"
+    assert len(agent.chat_history) == 2
+    assert agent.chat_history[0].content == "What is 1 + 1?"
+    assert agent.chat_history[1].content == "This is a test!"
 
 
 @patch("llama_index.llms.openai.base.SyncOpenAI")
@@ -319,6 +331,11 @@ async def test_async_add_step(
     # add human input (not used but should be in memory)
     task = agent.create_task("What is 1 + 1?")
    mock_instance.chat.completions.create.side_effect = mock_achat_stream
+
+    # stream the output to ensure it gets written to memory
     step_output = await agent.astream_step(task.task_id, input="tmp")
-    chat_history = task.extra_state["new_memory"].get_all()
+    async for _ in step_output.output.async_response_gen():
+        pass
+
+    chat_history = task.memory.get_all()
     assert "tmp" in [m.content for m in chat_history]
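
Note for reviewers: the common thread across these hunks is that `finalize_task` is now bound to the current task with `functools.partial` and handed to the streaming writers as `on_stream_end_fn`, so the per-task `new_memory` buffer is merged into `task.memory` only after the response stream has been fully written, not while the writer thread or asyncio task may still be running. Below is a minimal, self-contained sketch of that callback pattern; `ToyStreamingResponse` and this standalone `finalize_task` are hypothetical stand-ins for illustration, not the llama-index API.

```python
from functools import partial
from threading import Thread
from typing import Callable, List, Optional


class ToyStreamingResponse:
    """Stand-in for a streaming chat response: drains a token stream
    into a memory buffer, then fires an optional end-of-stream callback."""

    def __init__(self, tokens: List[str]) -> None:
        self._tokens = tokens

    def write_response_to_history(
        self,
        memory: List[str],
        on_stream_end_fn: Optional[Callable[[], None]] = None,
    ) -> None:
        final_text = ""
        for token in self._tokens:  # consume the stream token by token
            final_text += token
        memory.append(final_text)  # write the finished message to memory
        if on_stream_end_fn is not None:
            on_stream_end_fn()  # runs only once the stream is fully written


def finalize_task(task_memory: List[str], new_memory: List[str]) -> None:
    """Merge the per-task buffer into long-lived memory, then reset it."""
    task_memory.extend(new_memory)
    new_memory.clear()


task_memory: List[str] = ["What is 1 + 1?"]
new_memory: List[str] = []
response = ToyStreamingResponse(["This ", "is ", "a ", "test!"])

# Mirrors the patch: the finalizer is bound with partial() and passed to the
# background writer, so finalization cannot race ahead of the stream.
thread = Thread(
    target=response.write_response_to_history,
    args=(new_memory,),
    kwargs={"on_stream_end_fn": partial(finalize_task, task_memory, new_memory)},
)
thread.start()
thread.join()

assert task_memory == ["What is 1 + 1?", "This is a test!"]
assert new_memory == []
```

Separately, `finalize_task` now merges with `task.memory.get_all()` instead of `task.memory.get()`: in llama-index memory buffers, `get()` can return a truncated view (for example a token-limited window), so merging from it risked silently dropping older history, while `get_all()` preserves the full message list.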