From 3c30d2a92d06bc07cfbf33a5f11328a35c494cd3 Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Wed, 13 Nov 2024 12:50:38 -0600 Subject: [PATCH] Fix async streaming with bedrock converse (#16942) --- .../llama_index/llms/bedrock_converse/base.py | 8 ++++---- .../llms/bedrock_converse/utils.py | 20 +++++++++++++++---- .../pyproject.toml | 2 +- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py index 493b3c2033..706772fec6 100644 --- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py +++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py @@ -454,7 +454,7 @@ class BedrockConverse(FunctionCallingLLM): all_kwargs = self._get_all_kwargs(**kwargs) # invoke LLM in AWS Bedrock Converse with retry - response = await converse_with_retry_async( + response_gen = await converse_with_retry_async( session=self._asession, config=self._config, messages=converse_messages, @@ -467,7 +467,7 @@ class BedrockConverse(FunctionCallingLLM): async def gen() -> ChatResponseAsyncGen: content = {} role = MessageRole.ASSISTANT - async for chunk in response["stream"]: + async for chunk in response_gen: if content_block_delta := chunk.get("contentBlockDelta"): content_delta = content_block_delta["delta"] content = join_two_dicts(content, content_delta) @@ -489,7 +489,7 @@ class BedrockConverse(FunctionCallingLLM): }, ), delta=content_delta.get("text", ""), - raw=response, + raw=chunk, ) elif content_block_start := chunk.get("contentBlockStart"): tool_use = content_block_start["toolUse"] @@ -511,7 +511,7 @@ class BedrockConverse(FunctionCallingLLM): "status": status, }, ), - raw=response, + raw=chunk, ) return gen() diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py index 433753e6f2..31be9c1d17 100644 --- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py @@ -357,15 +357,27 @@ async def converse_with_retry_async( converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"} ) + ## NOTE: Returning the generator directly from converse_stream doesn't work + # So, we have to use two separate functions for streaming and non-streaming + # This differs from the synchronous version, and is a bit of a hack + # Further investigation is needed + @retry_decorator async def _conversion_with_retry(**kwargs: Any) -> Any: - # the async boto3 client needs to be defined inside this async with, otherwise it will raise an error async with session.client("bedrock-runtime", config=config) as client: - if stream: - return await client.converse_stream(**kwargs) return await client.converse(**kwargs) - return await _conversion_with_retry(**converse_kwargs) + @retry_decorator + async def _conversion_stream_with_retry(**kwargs: Any) -> Any: + async with session.client("bedrock-runtime", config=config) as client: + response = await client.converse_stream(**kwargs) + async for event in response["stream"]: + yield event + + if stream: + return _conversion_stream_with_retry(**converse_kwargs) + else: + return await _conversion_with_retry(**converse_kwargs) def join_two_dicts(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]: diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml index 81b63353e3..00d62b04e4 100644 --- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-bedrock-converse" readme = "README.md" -version = "0.3.8" +version = "0.3.9" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -- GitLab