From 02b12c34f8645668134be98b91c37e9cf43e3ee3 Mon Sep 17 00:00:00 2001
From: Daniel J <38250010+Kigstn@users.noreply.github.com>
Date: Fri, 15 Mar 2024 02:02:38 +0100
Subject: [PATCH] Feat: Add async streaming support to `query_engine` (#11949)
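
The async synthesis path can now stream tokens instead of raising
"Streaming not supported for async". A minimal usage sketch (illustrative
only: this patch touches the response schema and synthesizers, while
`index` and the exact query-engine wiring are assumed to exist elsewhere):

    query_engine = index.as_query_engine(streaming=True)
    response = await query_engine.aquery("What did the author do growing up?")

    # consume tokens as they arrive ...
    async for token in response.async_response_gen():
        print(token, end="", flush=True)

    # ... or collapse the stream into a regular Response object
    full_response = await response.get_response()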

---
 .../llama_index/core/base/response/schema.py  | 67 ++++++++++++++++++-
 .../core/response_synthesizers/base.py        | 15 +++-
 .../core/response_synthesizers/refine.py      | 41 +++++++++--
 llama-index-core/llama_index/core/types.py    |  2 +-
 4 files changed, 114 insertions(+), 11 deletions(-)

diff --git a/llama-index-core/llama_index/core/base/response/schema.py b/llama-index-core/llama_index/core/base/response/schema.py
index c10730956..abaed38f1 100644
--- a/llama-index-core/llama_index/core/base/response/schema.py
+++ b/llama-index-core/llama_index/core/base/response/schema.py
@@ -1,11 +1,11 @@
 """Response schema."""
-
+import asyncio
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union
 
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.schema import NodeWithScore
-from llama_index.core.types import TokenGen
+from llama_index.core.types import TokenGen, TokenAsyncGen
 from llama_index.core.utils import truncate_text
 
 
@@ -139,4 +139,65 @@ class StreamingResponse:
         return "\n\n".join(texts)
 
 
-RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
+@dataclass
+class AsyncStreamingResponse:
+    """AsyncStreamingResponse object.
+
+    Returned if streaming=True while using an async query engine.
+
+    Attributes:
+        response_gen: The response async generator.
+
+    """
+
+    response_gen: TokenAsyncGen
+    source_nodes: List[NodeWithScore] = field(default_factory=list)
+    metadata: Optional[Dict[str, Any]] = None
+    response_txt: Optional[str] = None
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+
+    async def _yield_response(self) -> TokenAsyncGen:
+        """Yield the string response."""
+        async with self._lock:
+            if self.response_txt is None and self.response_gen is not None:
+                self.response_txt = ""
+                async for text in self.response_gen:
+                    self.response_txt += text
+                    yield text
+            else:
+                yield self.response_txt
+
+    async def async_response_gen(self) -> TokenAsyncGen:
+        """Yield the string response."""
+        async for text in self._yield_response():
+            yield text
+
+    async def get_response(self) -> Response:
+        """Get a standard response object."""
+        async for _ in self._yield_response():
+            ...
+        return Response(self.response_txt, self.source_nodes, self.metadata)
+
+    async def print_response_stream(self) -> None:
+        """Print the response stream."""
+        async for text in self._yield_response():
+            print(text, end="", flush=True)
+        # do an empty print to print on the next line again next time
+        print()
+
+    def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
+        """Get formatted sources text."""
+        texts = []
+        for source_node in self.source_nodes:
+            fmt_text_chunk = source_node.node.get_content()
+            if trim_text:
+                fmt_text_chunk = truncate_text(fmt_text_chunk, length)
+            node_id = source_node.node.node_id or "None"
+            source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
+            texts.append(source_text)
+        return "\n\n".join(texts)
+
+
+RESPONSE_TYPE = Union[
+    Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse
+]
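
For reviewers, a self-contained sketch of how the class above behaves when
driven directly; `fake_llm_stream` is a stand-in for a real LLM token
stream and is not part of the patch:

    import asyncio

    from llama_index.core.base.response.schema import AsyncStreamingResponse

    async def fake_llm_stream():
        for token in ("Hello", ", ", "world"):
            yield token

    async def main() -> None:
        resp = AsyncStreamingResponse(fake_llm_stream())
        await resp.print_response_stream()  # prints "Hello, world" as it streams
        full = await resp.get_response()    # second pass reuses the cached text
        print(full.response)                # "Hello, world"

    asyncio.run(main())
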
diff --git a/llama-index-core/llama_index/core/response_synthesizers/base.py b/llama-index-core/llama_index/core/response_synthesizers/base.py
index 7c0e9cf0f..967a3bf81 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/base.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/base.py
@@ -10,7 +10,7 @@ Will support different modes, from 1) stuffing chunks into prompt,
 
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Sequence, Union
+from typing import Any, Dict, Generator, List, Optional, Sequence, Union, AsyncGenerator
 
 from llama_index.core.base.query_pipeline.query import (
     ChainableMixin,
@@ -24,6 +24,7 @@ from llama_index.core.base.response.schema import (
     PydanticResponse,
     Response,
     StreamingResponse,
+    AsyncStreamingResponse,
 )
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.callbacks.base import CallbackManager
@@ -54,6 +55,10 @@ def empty_response_generator() -> Generator[str, None, None]:
     yield "Empty Response"
 
 
+async def empty_response_agenerator() -> AsyncGenerator[str, None]:
+    yield "Empty Response"
+
+
 class BaseSynthesizer(ChainableMixin, PromptMixin):
     """Response builder class."""
 
@@ -164,6 +169,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
                 source_nodes=source_nodes,
                 metadata=response_metadata,
             )
+        if isinstance(response_str, AsyncGenerator):
+            return AsyncStreamingResponse(
+                response_str,
+                source_nodes=source_nodes,
+                metadata=response_metadata,
+            )
         if isinstance(response_str, self._output_cls):
             return PydanticResponse(
                 response_str, source_nodes=source_nodes, metadata=response_metadata
@@ -218,7 +229,7 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
     ) -> RESPONSE_TYPE:
         if len(nodes) == 0:
             if self._streaming:
-                return StreamingResponse(response_gen=empty_response_generator())
+                return AsyncStreamingResponse(response_gen=empty_response_agenerator())
             else:
                 return Response("Empty Response")
 
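
The new `isinstance` branch above works because `typing.AsyncGenerator`
supports runtime isinstance checks (the check is delegated to
`collections.abc.AsyncGenerator`), so sync and async token streams are
routed to different response classes. A quick illustration:

    from typing import AsyncGenerator, Generator

    def sync_gen():
        yield "token"

    async def async_gen():
        yield "token"

    assert isinstance(sync_gen(), Generator)          # -> StreamingResponse
    assert isinstance(async_gen(), AsyncGenerator)    # -> AsyncStreamingResponse
    assert not isinstance(async_gen(), Generator)     # the two checks do not overlap
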
diff --git a/llama-index-core/llama_index/core/response_synthesizers/refine.py b/llama-index-core/llama_index/core/response_synthesizers/refine.py
index 8dd0fcc50..059c6f4b3 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/refine.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/refine.py
@@ -1,5 +1,14 @@
 import logging
-from typing import Any, Callable, Generator, Optional, Sequence, Type, cast
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    Optional,
+    Sequence,
+    Type,
+    cast,
+    AsyncGenerator,
+)
 
 from llama_index.core.bridge.pydantic import BaseModel, Field, ValidationError
 from llama_index.core.callbacks.base import CallbackManager
@@ -351,7 +360,7 @@ class Refine(BaseSynthesizer):
             else:
                 response = response or "Empty Response"
         else:
-            response = cast(Generator, response)
+            response = cast(AsyncGenerator, response)
         return response
 
     async def _arefine_response_single(
@@ -411,7 +420,24 @@ class Refine(BaseSynthesizer):
                         f"Validation error on structured response: {e}", exc_info=True
                     )
             else:
-                raise ValueError("Streaming not supported for async")
+                if isinstance(response, Generator):
+                    response = "".join(response)
+
+                if isinstance(response, AsyncGenerator):
+                    _r = ""
+                    async for text in response:
+                        _r += text
+                    response = _r
+
+                refine_template = self._refine_template.partial_format(
+                    query_str=query_str, existing_answer=response
+                )
+
+                response = await self._llm.astream(
+                    refine_template,
+                    context_msg=cur_text_chunk,
+                    **response_kwargs,
+                )
 
             if query_satisfied:
                 refine_template = self._refine_template.partial_format(
@@ -451,7 +477,12 @@ class Refine(BaseSynthesizer):
                         f"Validation error on structured response: {e}", exc_info=True
                     )
             elif response is None and self._streaming:
-                raise ValueError("Streaming not supported for async")
+                response = await self._llm.astream(
+                    text_qa_template,
+                    context_str=cur_text_chunk,
+                    **response_kwargs,
+                )
+                query_satisfied = True
             else:
                 response = await self._arefine_response_single(
                     cast(RESPONSE_TEXT_TYPE, response),
@@ -464,5 +495,5 @@ class Refine(BaseSynthesizer):
         if isinstance(response, str):
             response = response or "Empty Response"
         else:
-            response = cast(Generator, response)
+            response = cast(AsyncGenerator, response)
         return response
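
A condensed sketch of the pattern the refine changes above follow: a
previously streamed answer (sync or async) is first collapsed into plain
text so it can be spliced into the refine prompt as `existing_answer`, and
only the newest refinement is requested as a stream again via
`self._llm.astream(...)`. The helper name `collapse_stream` is illustrative
and not part of the patch:

    from typing import AsyncGenerator, Generator, Union

    async def collapse_stream(
        previous: Union[str, Generator, AsyncGenerator],
    ) -> str:
        """Materialize a possibly-streaming previous answer into a string."""
        if isinstance(previous, Generator):
            return "".join(previous)
        if isinstance(previous, AsyncGenerator):
            parts = []
            async for text in previous:
                parts.append(text)
            return "".join(parts)
        return previous
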
diff --git a/llama-index-core/llama_index/core/types.py b/llama-index-core/llama_index/core/types.py
index cfa9a0fa5..236303324 100644
--- a/llama-index-core/llama_index/core/types.py
+++ b/llama-index-core/llama_index/core/types.py
@@ -21,7 +21,7 @@ Model = TypeVar("Model", bound=BaseModel)
 
 TokenGen = Generator[str, None, None]
 TokenAsyncGen = AsyncGenerator[str, None]
-RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen]
+RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen, TokenAsyncGen]
 
 
 # TODO: move into a `core` folder
-- 
GitLab