Unverified commit 02b12c34, authored by Daniel J, committed by GitHub

Feat: Add async streaming support to `query_engine` (#11949)

parent 74502938
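This commit closes the async streaming gap: a query engine built with streaming=True can now be driven through aquery(), returning the new AsyncStreamingResponse instead of raising ValueError("Streaming not supported for async"). A minimal usage sketch of what the change enables (the data loading and default-model configuration are illustrative assumptions, not part of the diff):

import asyncio

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

async def main() -> None:
    # Illustrative setup; assumes an LLM/embedding model is configured.
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    query_engine = index.as_query_engine(streaming=True)
    response = await query_engine.aquery("Summarize the documents.")

    # New in this commit: consume tokens as they arrive on the async path.
    async for token in response.async_response_gen():
        print(token, end="", flush=True)

asyncio.run(main())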
"""Response schema."""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.schema import NodeWithScore
from llama_index.core.types import TokenGen
from llama_index.core.types import TokenGen, TokenAsyncGen
from llama_index.core.utils import truncate_text
@@ -139,4 +139,66 @@ class StreamingResponse:
return "\n\n".join(texts)
RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
+@dataclass
+class AsyncStreamingResponse:
+    """AsyncStreamingResponse object.
+
+    Returned if streaming=True while using async.
+
+    Attributes:
+        async_response_gen: The response async generator.
+
+    """
+
+    response_gen: TokenAsyncGen
+    source_nodes: List[NodeWithScore] = field(default_factory=list)
+    metadata: Optional[Dict[str, Any]] = None
+    response_txt: Optional[str] = None
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    async def _yield_response(self) -> TokenAsyncGen:
+        """Yield the string response."""
+        async with self._lock:
+            if self.response_txt is None and self.response_gen is not None:
+                self.response_txt = ""
+                async for text in self.response_gen:
+                    self.response_txt += text
+                    yield text
+            else:
+                yield self.response_txt
+    async def async_response_gen(self) -> TokenAsyncGen:
+        """Yield the string response."""
+        async for text in self._yield_response():
+            yield text
+    async def get_response(self) -> Response:
+        """Get a standard response object."""
+        async for _ in self._yield_response():
+            ...
+        return Response(self.response_txt, self.source_nodes, self.metadata)
+    async def print_response_stream(self) -> None:
+        """Print the response stream."""
+        async for text in self._yield_response():
+            print(text, end="", flush=True)
+
+        # do an empty print to print on the next line again next time
+        print()
+    def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
+        """Get formatted sources text."""
+        texts = []
+        for source_node in self.source_nodes:
+            fmt_text_chunk = source_node.node.get_content()
+            if trim_text:
+                fmt_text_chunk = truncate_text(fmt_text_chunk, length)
+            node_id = source_node.node.node_id or "None"
+            source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
+            texts.append(source_text)
+        return "\n\n".join(texts)
+RESPONSE_TYPE = Union[
+    Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse
+]
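The class mirrors the existing StreamingResponse: _yield_response drains the wrapped async generator exactly once, caching the accumulated text in response_txt, and the asyncio.Lock prevents two concurrent consumers from double-draining the generator; later calls simply replay the cached string. A small consumption sketch against the class above (the token generator is a stand-in for an LLM stream):

import asyncio

from llama_index.core.base.response.schema import AsyncStreamingResponse

async def fake_tokens():
    # Stand-in for an LLM token stream.
    for token in ("Hello", ", ", "world!"):
        yield token

async def main() -> None:
    response = AsyncStreamingResponse(fake_tokens())
    await response.print_response_stream()  # streams: Hello, world!
    final = await response.get_response()   # replays the cached text
    print(final.response)

asyncio.run(main())

The next hunks are in the response synthesizer base class, which is what constructs these objects.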
@@ -10,7 +10,7 @@ Will support different modes, from 1) stuffing chunks into prompt,
import logging
from abc import abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Sequence, Union
+from typing import Any, Dict, Generator, List, Optional, Sequence, Union, AsyncGenerator
from llama_index.core.base.query_pipeline.query import (
ChainableMixin,
@@ -24,6 +24,7 @@ from llama_index.core.base.response.schema import (
    PydanticResponse,
    Response,
    StreamingResponse,
+    AsyncStreamingResponse,
)
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.callbacks.base import CallbackManager
@@ -54,6 +55,10 @@ def empty_response_generator() -> Generator[str, None, None]:
yield "Empty Response"
async def empty_response_agenerator() -> AsyncGenerator[str, None]:
yield "Empty Response"
class BaseSynthesizer(ChainableMixin, PromptMixin):
"""Response builder class."""
@@ -164,6 +169,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
                source_nodes=source_nodes,
                metadata=response_metadata,
            )
+        if isinstance(response_str, AsyncGenerator):
+            return AsyncStreamingResponse(
+                response_str,
+                source_nodes=source_nodes,
+                metadata=response_metadata,
+            )
if isinstance(response_str, self._output_cls):
return PydanticResponse(
response_str, source_nodes=source_nodes, metadata=response_metadata
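The new branch mirrors the sync Generator branch directly above it; it relies on async generator objects passing an isinstance check against the unsubscripted AsyncGenerator alias (typing.AsyncGenerator aliases collections.abc.AsyncGenerator, so runtime checks work). A quick illustration:

from typing import AsyncGenerator

async def agen():
    yield "Empty Response"

print(isinstance(agen(), AsyncGenerator))  # True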
@@ -218,7 +229,7 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
) -> RESPONSE_TYPE:
if len(nodes) == 0:
if self._streaming:
-                return StreamingResponse(response_gen=empty_response_generator())
+                return AsyncStreamingResponse(response_gen=empty_response_agenerator())
else:
return Response("Empty Response")
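With streaming on and no nodes retrieved, the async path now returns a one-token "Empty Response" stream instead of raising. A sketch of the resulting behavior (assumes an LLM is configured via Settings; exact wiring may differ):

import asyncio

from llama_index.core import get_response_synthesizer

async def main() -> None:
    synth = get_response_synthesizer(streaming=True)
    response = await synth.asynthesize("any query", nodes=[])
    async for token in response.async_response_gen():
        print(token)  # "Empty Response"

asyncio.run(main())

The remaining hunks are in the Refine synthesizer, whose async methods previously raised for streaming.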
import logging
-from typing import Any, Callable, Generator, Optional, Sequence, Type, cast
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    Optional,
+    Sequence,
+    Type,
+    cast,
+    AsyncGenerator,
+)
from llama_index.core.bridge.pydantic import BaseModel, Field, ValidationError
from llama_index.core.callbacks.base import CallbackManager
@@ -351,7 +360,7 @@ class Refine(BaseSynthesizer):
else:
response = response or "Empty Response"
else:
-            response = cast(Generator, response)
+            response = cast(AsyncGenerator, response)
return response
async def _arefine_response_single(
@@ -411,7 +420,24 @@
f"Validation error on structured response: {e}", exc_info=True
)
            else:
-                raise ValueError("Streaming not supported for async")
+                if isinstance(response, Generator):
+                    response = "".join(response)
+
+                if isinstance(response, AsyncGenerator):
+                    _r = ""
+                    async for text in response:
+                        _r += text
+
+                    response = _r
+
+                refine_template = self._refine_template.partial_format(
+                    query_str=query_str, existing_answer=response
+                )
+
+                response = await self._llm.astream(
+                    refine_template,
+                    context_msg=cur_text_chunk,
+                    **response_kwargs,
+                )
if query_satisfied:
refine_template = self._refine_template.partial_format(
@@ -451,7 +477,12 @@
f"Validation error on structured response: {e}", exc_info=True
)
            elif response is None and self._streaming:
-                raise ValueError("Streaming not supported for async")
+                response = await self._llm.astream(
+                    text_qa_template,
+                    context_str=cur_text_chunk,
+                    **response_kwargs,
+                )
+                query_satisfied = True
else:
response = await self._arefine_response_single(
cast(RESPONSE_TEXT_TYPE, response),
......@@ -464,5 +495,5 @@ class Refine(BaseSynthesizer):
if isinstance(response, str):
response = response or "Empty Response"
else:
-            response = cast(Generator, response)
+            response = cast(AsyncGenerator, response)
return response
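Both former raise ValueError("Streaming not supported for async") sites now do real work: the first answer and every refine step call self._llm.astream(...), and before a refine prompt is built, any previously streamed answer, sync or async, is flattened back into a string so it can be interpolated as existing_answer. Restated as a standalone helper (the function name is assumed for illustration):

from typing import Any, AsyncGenerator, Generator

async def flatten_answer(response: Any) -> str:
    """Collapse a plain string or a (sync or async) token generator into text."""
    if isinstance(response, Generator):
        return "".join(response)
    if isinstance(response, AsyncGenerator):
        parts = []
        async for text in response:
            parts.append(text)
        return "".join(parts)
    return response

A consequence worth noting: each refine pass fully consumes the previous stream before re-streaming, so only the final refine step streams incrementally to the caller. The last hunk, in core types, widens RESPONSE_TEXT_TYPE to make this legal.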
@@ -21,7 +21,7 @@ Model = TypeVar("Model", bound=BaseModel)
TokenGen = Generator[str, None, None]
TokenAsyncGen = AsyncGenerator[str, None]
-RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen]
+RESPONSE_TEXT_TYPE = Union[BaseModel, str, TokenGen, TokenAsyncGen]
# TODO: move into a `core` folder
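Widening RESPONSE_TEXT_TYPE lets synthesizer internals pass an async token generator through the same code paths as the other response forms, which is what the cast(AsyncGenerator, response) changes above depend on. Illustration (the handler is hypothetical):

from llama_index.core.types import RESPONSE_TEXT_TYPE, TokenAsyncGen

async def tokens() -> TokenAsyncGen:
    yield "hi"

def handle(response: RESPONSE_TEXT_TYPE) -> None:
    ...  # may now be an async generator as well

handle(tokens())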