From 02b12c34f8645668134be98b91c37e9cf43e3ee3 Mon Sep 17 00:00:00 2001
From: Daniel J <38250010+Kigstn@users.noreply.github.com>
Date: Fri, 15 Mar 2024 02:02:38 +0100
Subject: [PATCH] Feat: Add async streaming support to `query_engine` (#11949)

---
 .../llama_index/core/base/response/schema.py | 68 ++++++++++++++++++-
 .../core/response_synthesizers/base.py       | 15 +++-
 .../core/response_synthesizers/refine.py     | 41 +++++++++--
 llama-index-core/llama_index/core/types.py   |  2 +-
 4 files changed, 115 insertions(+), 11 deletions(-)

diff --git a/llama-index-core/llama_index/core/base/response/schema.py b/llama-index-core/llama_index/core/base/response/schema.py
index c10730956..abaed38f1 100644
--- a/llama-index-core/llama_index/core/base/response/schema.py
+++ b/llama-index-core/llama_index/core/base/response/schema.py
@@ -1,11 +1,11 @@
 """Response schema."""
-
+import asyncio
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union
 
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.schema import NodeWithScore
-from llama_index.core.types import TokenGen
+from llama_index.core.types import TokenGen, TokenAsyncGen
 from llama_index.core.utils import truncate_text
 
 
@@ -139,4 +139,66 @@ class StreamingResponse:
         return "\n\n".join(texts)
 
 
-RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
+@dataclass
+class AsyncStreamingResponse:
+    """AsyncStreamingResponse object.
+
+    Returned if streaming=True while using async.
+
+    Attributes:
+        async_response_gen: The response async generator.
+
+    """
+
+    async_response_gen: TokenAsyncGen
+    source_nodes: List[NodeWithScore] = field(default_factory=list)
+    metadata: Optional[Dict[str, Any]] = None
+    response_txt: Optional[str] = None
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+
+    async def _yield_response(self) -> TokenAsyncGen:
+        """Yield the string response."""
+        async with self._lock:
+            if self.response_txt is None and self.async_response_gen is not None:
+                self.response_txt = ""
+                async for text in self.async_response_gen:
+                    self.response_txt += text
+                    yield text
+            else:
+                yield self.response_txt
+
+    async def async_response_gen(self) -> TokenAsyncGen:
+        """Yield the string response."""
+        async for text in self._yield_response():
+            yield text
+
+    async def get_response(self) -> Response:
+        """Get a standard response object."""
+        async for _ in self._yield_response():
+            ...
+        return Response(self.response_txt, self.source_nodes, self.metadata)
+
+    async def print_response_stream(self) -> None:
+        """Print the response stream."""
+        streaming = True
+        async for text in self._yield_response():
+            print(text, end="", flush=True)
+        # do an empty print to print on the next line again next time
+        print()
+
+    def get_formatted_sources(self, length: int = 100, trim_text: int = True) -> str:
+        """Get formatted sources text."""
+        texts = []
+        for source_node in self.source_nodes:
+            fmt_text_chunk = source_node.node.get_content()
+            if trim_text:
+                fmt_text_chunk = truncate_text(fmt_text_chunk, length)
+            node_id = source_node.node.node_id or "None"
+            source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
+            texts.append(source_text)
+        return "\n\n".join(texts)
+
+
+RESPONSE_TYPE = Union[
+    Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse
+]
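For orientation, here is a minimal sketch of how the AsyncStreamingResponse class added above could be consumed once the patch is applied. The fake_llm_stream helper and the token strings are invented for illustration only; they stand in for the real LLM token stream produced by a query engine and are not part of the patch.

import asyncio

from llama_index.core.base.response.schema import AsyncStreamingResponse
from llama_index.core.types import TokenAsyncGen


async def fake_llm_stream() -> TokenAsyncGen:
    # Stand-in for an LLM token stream; the real generator comes from the
    # async query path, not from user code.
    for token in ("Deep ", "Thought ", "says ", "42."):
        yield token


async def main() -> None:
    response = AsyncStreamingResponse(fake_llm_stream())
    # Stream tokens to stdout as they arrive; _yield_response() caches the
    # concatenated text on the object while the generator is drained.
    await response.print_response_stream()
    # With the text cached, a standard Response object can be built without
    # re-running the stream.
    final = await response.get_response()
    print(final.response)


asyncio.run(main())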
diff --git a/llama-index-core/llama_index/core/response_synthesizers/base.py b/llama-index-core/llama_index/core/response_synthesizers/base.py
index 7c0e9cf0f..967a3bf81 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/base.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/base.py
@@ -10,7 +10,7 @@ Will support different modes, from 1) stuffing chunks into prompt,
 
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Sequence, Union
+from typing import Any, Dict, Generator, List, Optional, Sequence, Union, AsyncGenerator
 
 from llama_index.core.base.query_pipeline.query import (
     ChainableMixin,
@@ -24,6 +24,7 @@ from llama_index.core.base.response.schema import (
     PydanticResponse,
     Response,
     StreamingResponse,
+    AsyncStreamingResponse,
 )
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.callbacks.base import CallbackManager
@@ -54,6 +55,10 @@ def empty_response_generator() -> Generator[str, None, None]:
     yield "Empty Response"
 
 
+async def empty_response_agenerator() -> AsyncGenerator[str, None]:
+    yield "Empty Response"
+
+
 class BaseSynthesizer(ChainableMixin, PromptMixin):
     """Response builder class."""
 
@@ -164,6 +169,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
                 source_nodes=source_nodes,
                 metadata=response_metadata,
             )
+        if isinstance(response_str, AsyncGenerator):
+            return AsyncStreamingResponse(
+                response_str,
+                source_nodes=source_nodes,
+                metadata=response_metadata,
+            )
         if isinstance(response_str, self._output_cls):
             return PydanticResponse(
                 response_str, source_nodes=source_nodes, metadata=response_metadata
@@ -218,7 +229,7 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
     ) -> RESPONSE_TYPE:
         if len(nodes) == 0:
             if self._streaming:
-                return StreamingResponse(response_gen=empty_response_generator())
+                return AsyncStreamingResponse(response_gen=empty_response_agenerator())
             else:
                 return Response("Empty Response")
 
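The new branch in _prepare_response_output dispatches on an isinstance check against the unsubscripted AsyncGenerator type, which delegates to the collections.abc ABC at runtime. A small self-contained sketch of that dispatch pattern follows; the classify helper and the toy generators are invented names for illustration and are not llama_index code.

from typing import AsyncGenerator, Generator, Union


def classify(obj: Union[str, Generator, AsyncGenerator]) -> str:
    # Unsubscripted typing.Generator/AsyncGenerator support isinstance checks
    # by delegating to the collections.abc ABCs, so generator objects are
    # recognized here without knowing their item type.
    if isinstance(obj, AsyncGenerator):
        return "async token stream"
    if isinstance(obj, Generator):
        return "sync token stream"
    return "plain string"


def sync_tokens() -> Generator[str, None, None]:
    yield "hello"


async def async_tokens() -> AsyncGenerator[str, None]:
    yield "hello"


print(classify("hello"))         # plain string
print(classify(sync_tokens()))   # sync token stream
print(classify(async_tokens()))  # async token stream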
diff --git a/llama-index-core/llama_index/core/response_synthesizers/refine.py b/llama-index-core/llama_index/core/response_synthesizers/refine.py
index 8dd0fcc50..059c6f4b3 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/refine.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/refine.py
@@ -1,5 +1,14 @@
 import logging
-from typing import Any, Callable, Generator, Optional, Sequence, Type, cast
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    Optional,
+    Sequence,
+    Type,
+    cast,
+    AsyncGenerator,
+)
 
 from llama_index.core.bridge.pydantic import BaseModel, Field, ValidationError
 from llama_index.core.callbacks.base import CallbackManager
@@ -351,7 +360,7 @@ class Refine(BaseSynthesizer):
             else:
                 response = response or "Empty Response"
         else:
-            response = cast(Generator, response)
+            response = cast(AsyncGenerator, response)
         return response
 
     async def _arefine_response_single(
@@ -411,7 +420,24 @@ class Refine(BaseSynthesizer):
                         f"Validation error on structured response: {e}", exc_info=True
                     )
             else:
-                raise ValueError("Streaming not supported for async")
+                if isinstance(response, Generator):
+                    response = "".join(response)
+
+                if isinstance(response, AsyncGenerator):
+                    _r = ""
+                    async for text in response:
+                        _r += text
+                    response = _r
+
+                refine_template = self._refine_template.partial_format(
+                    query_str=query_str, existing_answer=response
+                )
+
+                response = await self._llm.astream(
+                    refine_template,
+                    context_msg=cur_text_chunk,
+                    **response_kwargs,
+                )
 
             if query_satisfied:
                 refine_template = self._refine_template.partial_format(
@@ -451,7 +477,12 @@ class Refine(BaseSynthesizer):
                         f"Validation error on structured response: {e}", exc_info=True
                     )
             elif response is None and self._streaming:
-                raise ValueError("Streaming not supported for async")
+                response = await self._llm.astream(
+                    text_qa_template,
+                    context_str=cur_text_chunk,
+                    **response_kwargs,
+                )
+                query_satisfied = True
             else:
                 response = await self._arefine_response_single(
                     cast(RESPONSE_TEXT_TYPE, response),
@@ -464,5 +495,5 @@ class Refine(BaseSynthesizer):
         if isinstance(response, str):
             response = response or "Empty Response"
         else:
-            response = cast(Generator, response)
+            response = cast(AsyncGenerator, response)
         return response
diff --git a/llama-index-core/llama_index/core/types.py b/llama-index-core/llama_index/core/types.py
index cfa9a0fa5..236303324 100644
--- a/llama-index-core/llama_index/core/types.py
+++ b/llama-index-core/llama_index/core/types.py
@@ -21,7 +21,7 @@ Model = TypeVar("Model", bound=BaseModel)
 
 TokenGen = Generator[str, None, None]
 TokenAsyncGen = AsyncGenerator[str, None]
-RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen]
+RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen, TokenAsyncGen]
 
 
 # TODO: move into a `core` folder
--
GitLab
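Taken together, these changes let the async query path stream tokens instead of raising "Streaming not supported for async". The following is a hypothetical end-to-end sketch, assuming a llama-index install that includes this patch, a configured default LLM, and a ./data directory of documents; the path and query text are placeholders, while as_query_engine and aquery are existing entry points.

import asyncio

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex


async def main() -> None:
    # Build a small index over local files; "./data" is a placeholder path.
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    # With streaming=True, the async entry point should now return an
    # AsyncStreamingResponse per this patch instead of raising.
    query_engine = index.as_query_engine(streaming=True)
    response = await query_engine.aquery("Summarize the documents in one paragraph.")

    # Print tokens as they arrive from the async stream.
    await response.print_response_stream()


asyncio.run(main())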