diff --git a/llama-index-core/llama_index/core/agent/legacy/react/base.py b/llama-index-core/llama_index/core/agent/legacy/react/base.py
index cfc512c23d83ab3b6dc0f0612279deb9455a3a85..36565a849756a9fc5d109a6277b8c0637c362e14 100644
--- a/llama-index-core/llama_index/core/agent/legacy/react/base.py
+++ b/llama-index-core/llama_index/core/agent/legacy/react/base.py
@@ -518,10 +518,10 @@ class ReActAgent(BaseAgent):
             achat_stream=response_stream, sources=self.sources
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(self._memory)
         )
-        # thread.start()
+
         return chat_stream_response
 
     def get_tools(self, message: str) -> List[AsyncBaseTool]:
diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py
index b7c8f6dd106d11a6fe9cf407879bd422d40e772d..e61e79e6c6ab88f5c9c3fb67f81f751395066d43 100644
--- a/llama-index-core/llama_index/core/agent/react/step.py
+++ b/llama-index-core/llama_index/core/agent/react/step.py
@@ -797,7 +797,7 @@ class ReActAgentWorker(BaseAgentWorker):
                 sources=task.extra_state["sources"],
             )
             # create task to write chat response to history
-            asyncio.create_task(
+            agent_response_stream.awrite_response_to_history_task = asyncio.create_task(
                 agent_response_stream.awrite_response_to_history(
                     task.extra_state["new_memory"],
                     on_stream_end_fn=partial(self.finalize_task, task),
diff --git a/llama-index-core/llama_index/core/chat_engine/condense_question.py b/llama-index-core/llama_index/core/chat_engine/condense_question.py
index bb1c2717cbaa038c894ca6fd478598f464d03749..44d84b98793d733867f49d963d52b93eeb574a74 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_question.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_question.py
@@ -359,7 +359,10 @@ class CondenseQuestionChatEngine(BaseChatEngine):
                 ),
                 sources=[tool_output],
             )
-            asyncio.create_task(response.awrite_response_to_history(self._memory))
+            response.awrite_response_to_history_task = asyncio.create_task(
+                response.awrite_response_to_history(self._memory)
+            )
+
         else:
             raise ValueError("Streaming is not enabled. Please use achat() instead.")
         return response
diff --git a/llama-index-core/llama_index/core/chat_engine/simple.py b/llama-index-core/llama_index/core/chat_engine/simple.py
index c96b000b3288833037eaece84dab7b517f56345c..04b62bdfce9b36942ccf896077b80bd9f8257e85 100644
--- a/llama-index-core/llama_index/core/chat_engine/simple.py
+++ b/llama-index-core/llama_index/core/chat_engine/simple.py
@@ -202,7 +202,9 @@ class SimpleChatEngine(BaseChatEngine):
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages)
         )
-        asyncio.create_task(chat_response.awrite_response_to_history(self._memory))
+        chat_response.awrite_response_to_history_task = asyncio.create_task(
+            chat_response.awrite_response_to_history(self._memory)
+        )
 
         return chat_response
 
diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py
index a100a697ee08ca760df86c07e17453963b7b74a8..79a0dcfad46b91ff7a30171beebe3078656eeeb3 100644
--- a/llama-index-core/llama_index/core/chat_engine/types.py
+++ b/llama-index-core/llama_index/core/chat_engine/types.py
@@ -124,6 +124,7 @@ class StreamingAgentChatResponse:
     is_writing_to_memory: bool = True
     # Track if an exception occurred
     exception: Optional[Exception] = None
+    awrite_response_to_history_task: Optional[asyncio.Task] = None
 
     def set_source_nodes(self) -> None:
         if self.sources and not self.source_nodes:
@@ -300,34 +301,44 @@
         self.response = self.unformatted_response.strip()
 
     async def async_response_gen(self) -> AsyncGenerator[str, None]:
-        self._ensure_async_setup()
-        assert self.aqueue is not None
-
-        if self.is_writing_to_memory:
-            while True:
-                if not self.aqueue.empty() or not self.is_done:
-                    if self.exception is not None:
-                        raise self.exception
-
-                    try:
-                        delta = await asyncio.wait_for(self.aqueue.get(), timeout=0.1)
-                    except asyncio.TimeoutError:
-                        if self.is_done:
-                            break
-                        continue
-                    if delta is not None:
-                        self.unformatted_response += delta
-                        yield delta
-                else:
-                    break
-        else:
-            if self.achat_stream is None:
-                raise ValueError("achat_stream is None!")
+        try:
+            self._ensure_async_setup()
+            assert self.aqueue is not None
+
+            if self.is_writing_to_memory:
+                while True:
+                    if not self.aqueue.empty() or not self.is_done:
+                        if self.exception is not None:
+                            raise self.exception
+
+                        try:
+                            delta = await asyncio.wait_for(
+                                self.aqueue.get(), timeout=0.1
+                            )
+                        except asyncio.TimeoutError:
+                            if self.is_done:
+                                break
+                            continue
+                        if delta is not None:
+                            self.unformatted_response += delta
+                            yield delta
+                    else:
+                        break
+            else:
+                if self.achat_stream is None:
+                    raise ValueError("achat_stream is None!")
 
-            async for chat_response in self.achat_stream:
-                self.unformatted_response += chat_response.delta or ""
-                yield chat_response.delta or ""
-        self.response = self.unformatted_response.strip()
+                async for chat_response in self.achat_stream:
+                    self.unformatted_response += chat_response.delta or ""
+                    yield chat_response.delta or ""
+            self.response = self.unformatted_response.strip()
+        finally:
+            if self.awrite_response_to_history_task:
+                # Make sure that the background task ran to completion, retrieve any exceptions
+                await self.awrite_response_to_history_task
+                self.awrite_response_to_history_task = (
+                    None  # No need to keep the reference to the finished task
+                )
 
     def print_response_stream(self) -> None:
         for token in self.response_gen:
diff --git a/llama-index-core/tests/chat_engine/test_simple.py b/llama-index-core/tests/chat_engine/test_simple.py
index 20dd816a52591e75190f276ccceb48ab37a48ea2..a307196492fd1124698502792f1b4c505e44ab58 100644
--- a/llama-index-core/tests/chat_engine/test_simple.py
+++ b/llama-index-core/tests/chat_engine/test_simple.py
@@ -1,3 +1,15 @@
+import gc
+import asyncio
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    CompletionResponse,
+    CompletionResponseGen,
+)
+from typing import Any
+from llama_index.core.llms.callbacks import llm_completion_callback
+from llama_index.core.llms.mock import MockLLM
+import pytest
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from llama_index.core.chat_engine.simple import SimpleChatEngine
 
@@ -34,3 +46,86 @@ def test_simple_chat_engine_with_init_history() -> None:
         str(response) == "user: test human message\nassistant: test ai message\n"
         "user: new human message\nassistant: "
     )
+
+
+@pytest.mark.asyncio()
+async def test_simple_chat_engine_astream():
+    engine = SimpleChatEngine.from_defaults()
+    response = await engine.astream_chat("Hello World!")
+
+    num_iters = 0
+    async for response_part in response.async_response_gen():
+        num_iters += 1
+
+    assert num_iters > 10
+
+    assert "Hello World!" in response.unformatted_response
+    assert len(engine.chat_history) == 2
+
+    response = await engine.astream_chat("What is the capital of the moon?")
+
+    num_iters = 0
+    async for _ in response.async_response_gen():
+        num_iters += 1
+
+    assert num_iters > 10
+    assert "Hello World!" in response.unformatted_response
+    assert "What is the capital of the moon?" in response.unformatted_response
+
+
+def test_simple_chat_engine_astream_exception_handling():
+    """Test that an exception thrown while retrieving the streamed LLM response gets bubbled up to the user.
+    Also tests that the non-retrieved exception does not remain in an task that was not awaited leading to
+    a 'Task exception was never retrieved' message during garbage collection.
+    """
+
+    class ExceptionThrownInTest(Exception):
+        pass
+
+    class ExceptionMockLLM(MockLLM):
+        """Raises an exception while streaming back the mocked LLM response."""
+
+        @classmethod
+        def class_name(cls) -> str:
+            return "ExceptionMockLLM"
+
+        @llm_completion_callback()
+        def stream_complete(
+            self, prompt: str, formatted: bool = False, **kwargs: Any
+        ) -> CompletionResponseGen:
+            def gen_prompt() -> CompletionResponseGen:
+                for ch in prompt:
+                    yield CompletionResponse(
+                        text=prompt,
+                        delta=ch,
+                    )
+                raise ExceptionThrownInTest("Exception thrown for testing purposes")
+
+            return gen_prompt()
+
+    async def async_part():
+        engine = SimpleChatEngine.from_defaults(
+            llm=ExceptionMockLLM(), memory=ChatMemoryBuffer.from_defaults()
+        )
+        response = await engine.astream_chat("Hello World!")
+
+        with pytest.raises(ExceptionThrownInTest):
+            async for response_part in response.async_response_gen():
+                pass
+
+    not_retrieved_exception = False
+
+    def custom_exception_handler(loop, context):
+        if context.get("message") == "Task exception was never retrieved":
+            nonlocal not_retrieved_exception
+            not_retrieved_exception = True
+
+    loop = asyncio.new_event_loop()
+    loop.set_exception_handler(custom_exception_handler)
+    result = loop.run_until_complete(async_part())
+    loop.close()
+    gc.collect()
+    if not_retrieved_exception:
+        pytest.fail(
+            "Exception was not correctly handled - ended up in asyncio cleanup performed during garbage collection"
+        )
diff --git a/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py b/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py
index ebb022f21931ac3ce9bc7d34d3ff749542a529bc..64198deffd00047a905eafbc4d068eeee24411ce 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py
@@ -244,7 +244,7 @@ class BaseOpenAIAgent(BaseAgent):
             sources=self.sources,
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(self.memory)
         )
         # wait until openAI functions stop executing
diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
index 0d00c1dc0eb96fe8755e0f68c683f13211b01b6c..2324b6b0f10b15dbe664923c9d632d191c701adf 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
@@ -277,7 +277,7 @@ class OpenAIAgentWorker(BaseAgentWorker):
             sources=task.extra_state["sources"],
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(
                 task.extra_state["new_memory"],
                 on_stream_end_fn=partial(self.afinalize_task, task),