Skip to content
Snippets Groups Projects
Unverified commit 3c30d2a9, authored by Logan, committed by GitHub
Browse files

Fix async streaming with bedrock converse (#16942)

parent 286891ad
No related branches found
No related tags found
No related merge requests found
......@@ -454,7 +454,7 @@ class BedrockConverse(FunctionCallingLLM):
all_kwargs = self._get_all_kwargs(**kwargs)
# invoke LLM in AWS Bedrock Converse with retry
response = await converse_with_retry_async(
response_gen = await converse_with_retry_async(
session=self._asession,
config=self._config,
messages=converse_messages,
......@@ -467,7 +467,7 @@ class BedrockConverse(FunctionCallingLLM):
async def gen() -> ChatResponseAsyncGen:
content = {}
role = MessageRole.ASSISTANT
async for chunk in response["stream"]:
async for chunk in response_gen:
if content_block_delta := chunk.get("contentBlockDelta"):
content_delta = content_block_delta["delta"]
content = join_two_dicts(content, content_delta)
......@@ -489,7 +489,7 @@ class BedrockConverse(FunctionCallingLLM):
},
),
delta=content_delta.get("text", ""),
raw=response,
raw=chunk,
)
elif content_block_start := chunk.get("contentBlockStart"):
tool_use = content_block_start["toolUse"]
......@@ -511,7 +511,7 @@ class BedrockConverse(FunctionCallingLLM):
"status": status,
},
),
raw=response,
raw=chunk,
)
return gen()
......
......@@ -357,15 +357,27 @@ async def converse_with_retry_async(
converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
)
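# NOTE: Returning the stream from converse_stream directly to the caller doesn't work
# (presumably because the client's `async with` block has already exited by the time
# the caller iterates the stream -- to be confirmed).
# So we use two separate functions, one for streaming and one for non-streaming calls.
# This differs from the synchronous version and is a bit of a hack;
# further investigation is needed.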
@retry_decorator
async def _conversion_with_retry(**kwargs: Any) -> Any:
# the async boto3 client needs to be defined inside this async with, otherwise it will raise an error
async with session.client("bedrock-runtime", config=config) as client:
if stream:
return await client.converse_stream(**kwargs)
return await client.converse(**kwargs)
return await _conversion_with_retry(**converse_kwargs)
@retry_decorator
async def _conversion_stream_with_retry(**kwargs: Any) -> Any:
async with session.client("bedrock-runtime", config=config) as client:
response = await client.converse_stream(**kwargs)
async for event in response["stream"]:
yield event
if stream:
return _conversion_stream_with_retry(**converse_kwargs)
else:
return await _conversion_with_retry(**converse_kwargs)
def join_two_dicts(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
......
......@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-bedrock-converse"
readme = "README.md"
version = "0.3.8"
version = "0.3.9"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment