From 3c30d2a92d06bc07cfbf33a5f11328a35c494cd3 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Wed, 13 Nov 2024 12:50:38 -0600
Subject: [PATCH] Fix async streaming with bedrock converse (#16942)

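The async streaming path was broken: converse_stream was awaited inside
an "async with session.client(...)" block and its raw response was
returned, so by the time the caller iterated response["stream"] the
client's context manager had already exited, which appears to be why
returning the stream directly does not work. The fix splits the retry
helper in two and wraps the streaming case in an async generator that
yields events while the client is still open; each streamed chunk is
now also passed as "raw" instead of the full response object.

A minimal usage sketch of the fixed behavior (the model id and prompt
are illustrative only, not part of this patch):

    import asyncio

    from llama_index.core.llms import ChatMessage
    from llama_index.llms.bedrock_converse import BedrockConverse

    llm = BedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")

    async def main() -> None:
        # astream_chat returns an async generator of ChatResponse chunks;
        # before this fix, iterating it failed because the underlying
        # client had already been closed.
        gen = await llm.astream_chat([ChatMessage(role="user", content="Hi")])
        async for chunk in gen:
            print(chunk.delta, end="", flush=True)

    asyncio.run(main())
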
---
 .../llama_index/llms/bedrock_converse/base.py |  8 ++++----
 .../llms/bedrock_converse/utils.py            | 20 +++++++++++++++----
 .../pyproject.toml                            |  2 +-
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
index 493b3c2033..706772fec6 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
@@ -454,7 +454,7 @@ class BedrockConverse(FunctionCallingLLM):
         all_kwargs = self._get_all_kwargs(**kwargs)
 
         # invoke LLM in AWS Bedrock Converse with retry
-        response = await converse_with_retry_async(
+        response_gen = await converse_with_retry_async(
             session=self._asession,
             config=self._config,
             messages=converse_messages,
@@ -467,7 +467,7 @@ class BedrockConverse(FunctionCallingLLM):
         async def gen() -> ChatResponseAsyncGen:
             content = {}
             role = MessageRole.ASSISTANT
-            async for chunk in response["stream"]:
+            async for chunk in response_gen:
                 if content_block_delta := chunk.get("contentBlockDelta"):
                     content_delta = content_block_delta["delta"]
                     content = join_two_dicts(content, content_delta)
@@ -489,7 +489,7 @@ class BedrockConverse(FunctionCallingLLM):
                             },
                         ),
                         delta=content_delta.get("text", ""),
-                        raw=response,
+                        raw=chunk,
                     )
                 elif content_block_start := chunk.get("contentBlockStart"):
                     tool_use = content_block_start["toolUse"]
@@ -511,7 +511,7 @@ class BedrockConverse(FunctionCallingLLM):
                                 "status": status,
                             },
                         ),
-                        raw=response,
+                        raw=chunk,
                     )
 
         return gen()
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py
index 433753e6f2..31be9c1d17 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py
@@ -357,15 +357,27 @@ async def converse_with_retry_async(
         converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
     )
 
+    ## NOTE: Returning the stream from converse_stream directly doesn't work,
+    # most likely because the stream must be consumed while the client's async
+    # context manager is still open. So, unlike the synchronous version, streaming
+    # is handled by a separate async generator that keeps the client alive.
+
     @retry_decorator
     async def _conversion_with_retry(**kwargs: Any) -> Any:
-        # the async boto3 client needs to be defined inside this async with, otherwise it will raise an error
         async with session.client("bedrock-runtime", config=config) as client:
-            if stream:
-                return await client.converse_stream(**kwargs)
             return await client.converse(**kwargs)
 
-    return await _conversion_with_retry(**converse_kwargs)
+    @retry_decorator
+    async def _conversion_stream_with_retry(**kwargs: Any) -> Any:
+        async with session.client("bedrock-runtime", config=config) as client:
+            response = await client.converse_stream(**kwargs)
+            async for event in response["stream"]:
+                yield event
+
+    if stream:
+        return _conversion_stream_with_retry(**converse_kwargs)
+    else:
+        return await _conversion_with_retry(**converse_kwargs)
 
 
 def join_two_dicts(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml
index 81b63353e3..00d62b04e4 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock-converse"
 readme = "README.md"
-version = "0.3.8"
+version = "0.3.9"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-- 
GitLab