From c1b70315bc423957595984aeaad062f059f1f298 Mon Sep 17 00:00:00 2001
From: Andre Wisplinghoff <andre.wisplinghoff@commerzbank.com>
Date: Fri, 14 Feb 2025 23:14:06 +0100
Subject: [PATCH] Keep a reference to asyncio tasks in astream_chat() (#17812)

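The event loop only keeps weak references to tasks created with
asyncio.create_task(): a task that is not referenced anywhere else can
be garbage-collected before it is done, and an exception raised inside
it is then only reported later by the loop's exception handler as
"Task exception was never retrieved".

Store the awrite_response_to_history() task on the
StreamingAgentChatResponse object and await it when
async_response_gen() finishes, so the background write to chat memory
runs to completion and any exception it raised is re-raised to the
caller instead of being lost.

A minimal sketch of the failure mode this guards against (the names
below are illustrative and not part of the library):

    import asyncio
    import gc

    async def background() -> None:
        raise RuntimeError("lost exception")

    async def main() -> None:
        # No reference kept: once the task finishes, its unretrieved
        # exception is only reported by the loop's exception handler
        # ("Task exception was never retrieved") when the task object
        # is destroyed.
        asyncio.create_task(background())
        await asyncio.sleep(0.1)
        gc.collect()

        # Keeping a reference and awaiting the task re-raises the
        # exception here, which is the pattern applied in this patch.
        task = asyncio.create_task(background())
        try:
            await task
        except RuntimeError:
            pass

    asyncio.run(main())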
---
 .../core/agent/legacy/react/base.py           |  4 +-
 .../llama_index/core/agent/react/step.py      |  2 +-
 .../core/chat_engine/condense_question.py     |  5 +-
 .../llama_index/core/chat_engine/simple.py    |  4 +-
 .../llama_index/core/chat_engine/types.py     | 65 +++++++------
 .../tests/chat_engine/test_simple.py          | 95 +++++++++++++++++++
 .../agent/openai_legacy/openai_agent.py       |  2 +-
 .../llama_index/agent/openai/step.py          |  2 +-
 8 files changed, 145 insertions(+), 34 deletions(-)

diff --git a/llama-index-core/llama_index/core/agent/legacy/react/base.py b/llama-index-core/llama_index/core/agent/legacy/react/base.py
index cfc512c23d..36565a8497 100644
--- a/llama-index-core/llama_index/core/agent/legacy/react/base.py
+++ b/llama-index-core/llama_index/core/agent/legacy/react/base.py
@@ -518,10 +518,10 @@ class ReActAgent(BaseAgent):
             achat_stream=response_stream, sources=self.sources
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(self._memory)
         )
-        # thread.start()
+
         return chat_stream_response
 
     def get_tools(self, message: str) -> List[AsyncBaseTool]:
diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py
index b7c8f6dd10..e61e79e6c6 100644
--- a/llama-index-core/llama_index/core/agent/react/step.py
+++ b/llama-index-core/llama_index/core/agent/react/step.py
@@ -797,7 +797,7 @@ class ReActAgentWorker(BaseAgentWorker):
                 sources=task.extra_state["sources"],
             )
             # create task to write chat response to history
-            asyncio.create_task(
+            agent_response_stream.awrite_response_to_history_task = asyncio.create_task(
                 agent_response_stream.awrite_response_to_history(
                     task.extra_state["new_memory"],
                     on_stream_end_fn=partial(self.finalize_task, task),
diff --git a/llama-index-core/llama_index/core/chat_engine/condense_question.py b/llama-index-core/llama_index/core/chat_engine/condense_question.py
index bb1c2717cb..44d84b9879 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_question.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_question.py
@@ -359,7 +359,10 @@ class CondenseQuestionChatEngine(BaseChatEngine):
                 ),
                 sources=[tool_output],
             )
-            asyncio.create_task(response.awrite_response_to_history(self._memory))
+            response.awrite_response_to_history_task = asyncio.create_task(
+                response.awrite_response_to_history(self._memory)
+            )
+
         else:
             raise ValueError("Streaming is not enabled. Please use achat() instead.")
         return response
diff --git a/llama-index-core/llama_index/core/chat_engine/simple.py b/llama-index-core/llama_index/core/chat_engine/simple.py
index c96b000b32..04b62bdfce 100644
--- a/llama-index-core/llama_index/core/chat_engine/simple.py
+++ b/llama-index-core/llama_index/core/chat_engine/simple.py
@@ -202,7 +202,9 @@ class SimpleChatEngine(BaseChatEngine):
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages)
         )
-        asyncio.create_task(chat_response.awrite_response_to_history(self._memory))
+        chat_response.awrite_response_to_history_task = asyncio.create_task(
+            chat_response.awrite_response_to_history(self._memory)
+        )
 
         return chat_response
 
diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py
index a100a697ee..79a0dcfad4 100644
--- a/llama-index-core/llama_index/core/chat_engine/types.py
+++ b/llama-index-core/llama_index/core/chat_engine/types.py
@@ -124,6 +124,7 @@ class StreamingAgentChatResponse:
     is_writing_to_memory: bool = True
     # Track if an exception occurred
     exception: Optional[Exception] = None
+    awrite_response_to_history_task: Optional[asyncio.Task] = None
 
     def set_source_nodes(self) -> None:
         if self.sources and not self.source_nodes:
@@ -300,34 +301,44 @@ class StreamingAgentChatResponse:
         self.response = self.unformatted_response.strip()
 
     async def async_response_gen(self) -> AsyncGenerator[str, None]:
-        self._ensure_async_setup()
-        assert self.aqueue is not None
-
-        if self.is_writing_to_memory:
-            while True:
-                if not self.aqueue.empty() or not self.is_done:
-                    if self.exception is not None:
-                        raise self.exception
-
-                    try:
-                        delta = await asyncio.wait_for(self.aqueue.get(), timeout=0.1)
-                    except asyncio.TimeoutError:
-                        if self.is_done:
-                            break
-                        continue
-                    if delta is not None:
-                        self.unformatted_response += delta
-                        yield delta
-                else:
-                    break
-        else:
-            if self.achat_stream is None:
-                raise ValueError("achat_stream is None!")
+        try:
+            self._ensure_async_setup()
+            assert self.aqueue is not None
+
+            if self.is_writing_to_memory:
+                while True:
+                    if not self.aqueue.empty() or not self.is_done:
+                        if self.exception is not None:
+                            raise self.exception
+
+                        try:
+                            delta = await asyncio.wait_for(
+                                self.aqueue.get(), timeout=0.1
+                            )
+                        except asyncio.TimeoutError:
+                            if self.is_done:
+                                break
+                            continue
+                        if delta is not None:
+                            self.unformatted_response += delta
+                            yield delta
+                    else:
+                        break
+            else:
+                if self.achat_stream is None:
+                    raise ValueError("achat_stream is None!")
 
-            async for chat_response in self.achat_stream:
-                self.unformatted_response += chat_response.delta or ""
-                yield chat_response.delta or ""
-        self.response = self.unformatted_response.strip()
+                async for chat_response in self.achat_stream:
+                    self.unformatted_response += chat_response.delta or ""
+                    yield chat_response.delta or ""
+            self.response = self.unformatted_response.strip()
+        finally:
+            if self.awrite_response_to_history_task:
+                # Make sure the background task ran to completion and retrieve any exception it raised
+                await self.awrite_response_to_history_task
+                self.awrite_response_to_history_task = (
+                    None  # No need to keep the reference to the finished task
+                )
 
     def print_response_stream(self) -> None:
         for token in self.response_gen:
diff --git a/llama-index-core/tests/chat_engine/test_simple.py b/llama-index-core/tests/chat_engine/test_simple.py
index 20dd816a52..a307196492 100644
--- a/llama-index-core/tests/chat_engine/test_simple.py
+++ b/llama-index-core/tests/chat_engine/test_simple.py
@@ -1,3 +1,15 @@
+import gc
+import asyncio
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    CompletionResponse,
+    CompletionResponseGen,
+)
+from typing import Any
+from llama_index.core.llms.callbacks import llm_completion_callback
+from llama_index.core.llms.mock import MockLLM
+import pytest
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from llama_index.core.chat_engine.simple import SimpleChatEngine
 
@@ -34,3 +46,86 @@ def test_simple_chat_engine_with_init_history() -> None:
         str(response) == "user: test human message\nassistant: test ai message\n"
         "user: new human message\nassistant: "
     )
+
+
+@pytest.mark.asyncio()
+async def test_simple_chat_engine_astream() -> None:
+    engine = SimpleChatEngine.from_defaults()
+    response = await engine.astream_chat("Hello World!")
+
+    num_iters = 0
+    async for response_part in response.async_response_gen():
+        num_iters += 1
+
+    assert num_iters > 10
+
+    assert "Hello World!" in response.unformatted_response
+    assert len(engine.chat_history) == 2
+
+    response = await engine.astream_chat("What is the capital of the moon?")
+
+    num_iters = 0
+    async for _ in response.async_response_gen():
+        num_iters += 1
+
+    assert num_iters > 10
+    assert "Hello World!" in response.unformatted_response
+    assert "What is the capital of the moon?" in response.unformatted_response
+
+
+def test_simple_chat_engine_astream_exception_handling() -> None:
+    """Test that an exception thrown while retrieving the streamed LLM response gets bubbled up to the user.
+    Also tests that the non-retrieved exception does not remain in an task that was not awaited leading to
+    a 'Task exception was never retrieved' message during garbage collection.
+    """
+
+    class ExceptionThrownInTest(Exception):
+        pass
+
+    class ExceptionMockLLM(MockLLM):
+        """Raises an exception while streaming back the mocked LLM response."""
+
+        @classmethod
+        def class_name(cls) -> str:
+            return "ExceptionMockLLM"
+
+        @llm_completion_callback()
+        def stream_complete(
+            self, prompt: str, formatted: bool = False, **kwargs: Any
+        ) -> CompletionResponseGen:
+            def gen_prompt() -> CompletionResponseGen:
+                for ch in prompt:
+                    yield CompletionResponse(
+                        text=prompt,
+                        delta=ch,
+                    )
+                raise ExceptionThrownInTest("Exception thrown for testing purposes")
+
+            return gen_prompt()
+
+    async def async_part():
+        engine = SimpleChatEngine.from_defaults(
+            llm=ExceptionMockLLM(), memory=ChatMemoryBuffer.from_defaults()
+        )
+        response = await engine.astream_chat("Hello World!")
+
+        with pytest.raises(ExceptionThrownInTest):
+            async for response_part in response.async_response_gen():
+                pass
+
+    not_retrieved_exception = False
+
+    def custom_exception_handler(loop, context):
+        if context.get("message") == "Task exception was never retrieved":
+            nonlocal not_retrieved_exception
+            not_retrieved_exception = True
+
+    loop = asyncio.new_event_loop()
+    loop.set_exception_handler(custom_exception_handler)
+    loop.run_until_complete(async_part())
+    loop.close()
+    gc.collect()
+    if not_retrieved_exception:
+        pytest.fail(
+            "Exception was not correctly handled - ended up in asyncio cleanup performed during garbage collection"
+        )
diff --git a/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py b/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py
index ebb022f219..64198deffd 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai-legacy/llama_index/agent/openai_legacy/openai_agent.py
@@ -244,7 +244,7 @@ class BaseOpenAIAgent(BaseAgent):
             sources=self.sources,
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(self.memory)
         )
         # wait until openAI functions stop executing
diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
index 0d00c1dc0e..2324b6b0f1 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
@@ -277,7 +277,7 @@ class OpenAIAgentWorker(BaseAgentWorker):
             sources=task.extra_state["sources"],
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(
                 task.extra_state["new_memory"],
                 on_stream_end_fn=partial(self.afinalize_task, task),
-- 
GitLab