From 8ba53668a594ee4d3e53450d764495469c24e248 Mon Sep 17 00:00:00 2001
From: Theodore Cowan <t@theodore.me>
Date: Mon, 4 Dec 2023 17:50:19 -0500
Subject: [PATCH] Bedrock embedding query for cohere was not in the correct
 format (#9265)

---
 llama_index/embeddings/bedrock.py |  37 +++++++----
 poetry.lock                       | 100 +++++++++++++++++++++++++++---
 pyproject.toml                    |   1 +
 tests/embeddings/test_bedrock.py  |  75 ++++++++++++++++++++++
 tests/llms/test_bedrock.py        |  85 ++++++++++++++++---------
 5 files changed, 252 insertions(+), 46 deletions(-)
 create mode 100644 tests/embeddings/test_bedrock.py

diff --git a/llama_index/embeddings/bedrock.py b/llama_index/embeddings/bedrock.py
index 6cc4325647..f042622ac0 100644
--- a/llama_index/embeddings/bedrock.py
+++ b/llama_index/embeddings/bedrock.py
@@ -2,7 +2,7 @@ import json
 import os
 import warnings
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional
 
 from llama_index.bridge.pydantic import PrivateAttr
 from llama_index.callbacks.base import CallbackManager
@@ -194,7 +194,7 @@ class BedrockEmbedding(BaseEmbedding):
             callback_manager=callback_manager,
         )
 
-    def _get_embedding(self, payload: Any) -> Embedding:
+    def _get_embedding(self, payload: str, type: Literal["text", "query"]) -> Embedding:
         if self._client is None:
             self.set_credentials(self.model_name)
 
@@ -202,7 +202,7 @@ class BedrockEmbedding(BaseEmbedding):
             raise ValueError("Client not set")
 
         provider = self.model_name.split(".")[0]
-        request_body = self._get_request_body(provider, payload)
+        request_body = self._get_request_body(provider, payload, type)
 
         response = self._client.invoke_model(
             body=request_body,
@@ -218,12 +218,14 @@ class BedrockEmbedding(BaseEmbedding):
         return resp.get(identifiers.get("embeddings"))
 
     def _get_query_embedding(self, query: str) -> Embedding:
-        return self._get_embedding(query)
+        return self._get_embedding(query, "query")
 
     def _get_text_embedding(self, text: str) -> Embedding:
-        return self._get_embedding(text)
+        return self._get_embedding(text, "text")
 
-    def _get_request_body(self, provider: str, payload: Any) -> Any:
+    def _get_request_body(
+        self, provider: str, payload: str, type: Literal["text", "query"]
+    ) -> Any:
         """Build the request body as per the provider.
         Currently supported providers are amazon, cohere.
 
@@ -240,14 +242,27 @@ class BedrockEmbedding(BaseEmbedding):
             }
 
         """
+        print("provider: ", provider, PROVIDERS.AMAZON)
         if provider == PROVIDERS.AMAZON:
-            request_body = json.dumps({"inputText": str(payload)})
-        if provider == PROVIDERS.COHERE:
-            request_body = json.dumps(payload)
+            request_body = json.dumps({"inputText": payload})
+        elif provider == PROVIDERS.COHERE:
+            input_types = {
+                "text": "search_document",
+                "query": "search_query",
+            }
+            request_body = json.dumps(
+                {
+                    "texts": [payload],
+                    "input_type": input_types[type],
+                    "truncate": "NONE",
+                }
+            )
+        else:
+            raise ValueError("Provider not supported")
         return request_body
 
     async def _aget_query_embedding(self, query: str) -> Embedding:
-        return self._get_embedding(query)
+        return self._get_embedding(query, "query")
 
     async def _aget_text_embedding(self, text: str) -> Embedding:
-        return self._get_embedding(text)
+        return self._get_embedding(text, "text")
diff --git a/poetry.lock b/poetry.lock
index b31868f201..482329da99 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -587,6 +587,47 @@ numpy = [
     {version = ">=1.19.0", markers = "python_version >= \"3.9\""},
 ]
 
+[[package]]
+name = "boto3"
+version = "1.33.6"
+description = "The AWS SDK for Python"
+optional = false
+python-versions = ">= 3.7"
+files = [
+    {file = "boto3-1.33.6-py3-none-any.whl", hash = "sha256:b88f0f305186c5fd41f168e006baa45b7002a33029aec8e5bef373237a172fca"},
+    {file = "boto3-1.33.6.tar.gz", hash = "sha256:4f62fc1c7f3ea2d22917aa0aa07b86f119abd90bed3d815e4b52fb3d84773e15"},
+]
+
+[package.dependencies]
+botocore = ">=1.33.6,<1.34.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.8.2,<0.9.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.33.6"
+description = "Low-level, data-driven core of boto 3."
+optional = false
+python-versions = ">= 3.7"
+files = [
+    {file = "botocore-1.33.6-py3-none-any.whl", hash = "sha256:14282cd432c0683770eee932c43c12bb9ad5730e23755204ad102897c996693a"},
+    {file = "botocore-1.33.6.tar.gz", hash = "sha256:938056bab831829f90e09ecd70dd6b295afd52b1482f5582ee7a11d8243d9661"},
+]
+
+[package.dependencies]
+jmespath = ">=0.7.1,<2.0.0"
+python-dateutil = ">=2.1,<3.0.0"
+urllib3 = [
+    {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
+    {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""},
+]
+
+[package.extras]
+crt = ["awscrt (==0.19.17)"]
+
 [[package]]
 name = "cachetools"
 version = "5.3.2"
@@ -2112,6 +2153,17 @@ MarkupSafe = ">=2.0"
 [package.extras]
 i18n = ["Babel (>=2.7)"]
 
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+description = "JSON Matching Expressions"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
+    {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
+]
+
 [[package]]
 name = "joblib"
 version = "1.3.2"
@@ -3764,7 +3816,7 @@ files = [
 [package.dependencies]
 coloredlogs = "*"
 datasets = [
-    {version = "*", optional = true, markers = "extra != \"onnxruntime\""},
+    {version = "*"},
     {version = ">=1.2.1", optional = true, markers = "extra == \"onnxruntime\""},
 ]
 evaluate = {version = "*", optional = true, markers = "extra == \"onnxruntime\""}
@@ -5343,6 +5395,23 @@ files = [
     {file = "ruff-0.0.292.tar.gz", hash = "sha256:1093449e37dd1e9b813798f6ad70932b57cf614e5c2b5c51005bf67d55db33ac"},
 ]
 
+[[package]]
+name = "s3transfer"
+version = "0.8.2"
+description = "An Amazon S3 Transfer Manager"
+optional = false
+python-versions = ">= 3.7"
+files = [
+    {file = "s3transfer-0.8.2-py3-none-any.whl", hash = "sha256:c9e56cbe88b28d8e197cf841f1f0c130f246595e77ae5b5a05b69fe7cb83de76"},
+    {file = "s3transfer-0.8.2.tar.gz", hash = "sha256:368ac6876a9e9ed91f6bc86581e319be08188dc60d50e0d56308ed5765446283"},
+]
+
+[package.dependencies]
+botocore = ">=1.33.2,<2.0a.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
+
 [[package]]
 name = "safetensors"
 version = "0.4.1"
@@ -7102,17 +7171,34 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake
 
 [[package]]
 name = "urllib3"
-version = "2.1.0"
+version = "1.26.18"
 description = "HTTP library with thread-safe connection pooling, file post, and more."
 optional = false
-python-versions = ">=3.8"
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+files = [
+    {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
+    {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
+]
+
+[package.extras]
+brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
+
+[[package]]
+name = "urllib3"
+version = "2.0.7"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.7"
 files = [
-    {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"},
-    {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"},
+    {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
+    {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
 ]
 
 [package.extras]
 brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
 socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
 zstd = ["zstandard (>=0.18.0)"]
 
@@ -7574,4 +7660,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<3.12"
-content-hash = "003f5fcc15b7be0e4c2e691b503574ece7e7b11d21896982fe9a1a96e683a39d"
+content-hash = "6600bad921e4e108a7518ef1eb134e58858cf447c374f5e89af1ca47eca00eb3"
diff --git a/pyproject.toml b/pyproject.toml
index e3099b6d17..c158854e3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,6 +100,7 @@ query_tools = [
 
 [tool.poetry.group.dev.dependencies]
 black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
+boto3 = "1.33.6"  # needed for tests
 codespell = {extras = ["toml"], version = ">=v2.2.6"}
 google-generativeai = {python = ">=3.9,<3.12", version = "^0.2.1"}
 ipython = "8.10.0"
diff --git a/tests/embeddings/test_bedrock.py b/tests/embeddings/test_bedrock.py
new file mode 100644
index 0000000000..c69aa1735f
--- /dev/null
+++ b/tests/embeddings/test_bedrock.py
@@ -0,0 +1,75 @@
+import json
+from io import BytesIO
+from unittest import TestCase
+
+import boto3
+from botocore.response import StreamingBody
+from botocore.stub import Stubber
+from llama_index.embeddings.bedrock import BedrockEmbedding, Models
+
+
+class TestBedrockEmbedding(TestCase):
+    bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")
+    bedrock_stubber = Stubber(bedrock_client)
+
+    def test_get_text_embedding_titan(self) -> None:
+        mock_response = {
+            "embedding": [
+                0.017410278,
+                0.040924072,
+                -0.007507324,
+                0.09429932,
+                0.015304565,
+            ]
+        }
+
+        mock_stream = BytesIO(json.dumps(mock_response).encode())
+
+        self.bedrock_stubber.add_response(
+            "invoke_model",
+            {
+                "contentType": "application/json",
+                "body": StreamingBody(mock_stream, len(json.dumps(mock_response))),
+            },
+        )
+
+        bedrock_embedding = BedrockEmbedding(
+            model_name=Models.TITAN_EMBEDDING,
+            client=self.bedrock_client,
+        )
+
+        self.bedrock_stubber.activate()
+        embedding = bedrock_embedding.get_text_embedding(text="foo bar baz")
+        self.bedrock_stubber.deactivate()
+
+        self.bedrock_stubber.assert_no_pending_responses()
+        self.assertEqual(embedding, mock_response["embedding"])
+
+    def test_get_text_embedding_cohere(self) -> None:
+        mock_response = {
+            "embeddings": [
+                [0.017410278, 0.040924072, -0.007507324, 0.09429932, 0.015304565]
+            ]
+        }
+
+        mock_stream = BytesIO(json.dumps(mock_response).encode())
+
+        self.bedrock_stubber.add_response(
+            "invoke_model",
+            {
+                "contentType": "application/json",
+                "body": StreamingBody(mock_stream, len(json.dumps(mock_response))),
+            },
+        )
+
+        bedrock_embedding = BedrockEmbedding(
+            model_name=Models.COHERE_EMBED_ENGLISH_V3,
+            client=self.bedrock_client,
+        )
+
+        self.bedrock_stubber.activate()
+        embedding = bedrock_embedding.get_text_embedding(text="foo bar baz")
+        self.bedrock_stubber.deactivate()
+
+        self.bedrock_stubber.assert_no_pending_responses()
+        self.assertEqual(embedding, mock_response["embeddings"])
diff --git a/tests/llms/test_bedrock.py b/tests/llms/test_bedrock.py
index 6c2d8d36ab..f462ce5b9a 100644
--- a/tests/llms/test_bedrock.py
+++ b/tests/llms/test_bedrock.py
@@ -1,27 +1,13 @@
+import json
+from io import BytesIO
 from typing import Any, Generator
 
-import pytest
+from botocore.response import StreamingBody
+from botocore.stub import Stubber
+from llama_index.llms import Bedrock
 from llama_index.llms.base import ChatMessage
 from pytest import MonkeyPatch
 
-try:
-    import boto3
-except ImportError:
-    boto3 = None
-from llama_index.llms import Bedrock
-
-
-class MockStreamingBody:
-    def read(self) -> str:
-        return """{
-        "inputTextTokenCount": 3,
-        "results": [
-        {"tokenCount": 14,
-        "outputText": "\\n\\nThis is indeed a test",
-        "completionReason": "FINISH"
-        }]}
-        """
-
 
 class MockEventStream:
     def __iter__(self) -> Generator[dict, None, None]:
@@ -36,7 +22,23 @@ class MockEventStream:
             }
 
 
-def mock_completion_with_retry(*args: Any, **kwargs: Any) -> dict:
+def get_invoke_model_response() -> dict:
+    # response for titan model
+    raw_stream_bytes = json.dumps(
+        {
+            "inputTextTokenCount": 3,
+            "results": [
+                {
+                    "tokenCount": 14,
+                    "outputText": "\n\nThis is indeed a test",
+                    "completionReason": "FINISH",
+                }
+            ],
+        }
+    ).encode()
+    raw_stream = BytesIO(raw_stream_bytes)
+    content_length = len(raw_stream_bytes)
+
     return {
         "ResponseMetadata": {
             "HTTPHeaders": {
@@ -50,7 +52,10 @@ def mock_completion_with_retry(*args: Any, **kwargs: Any) -> dict:
             "RequestId": "667dq648-fbc3-4a7b-8f0e-4575f1f1f11d",
             "RetryAttempts": 0,
         },
-        "body": MockStreamingBody(),
+        "body": StreamingBody(
+            raw_stream=raw_stream,
+            content_length=content_length,
+        ),
         "contentType": "application/json",
     }
 
@@ -75,12 +80,29 @@ def mock_stream_completion_with_retry(*args: Any, **kwargs: Any) -> dict:
     }
 
 
-@pytest.mark.skipif(boto3 is None, reason="bedrock not installed")
-def test_model_basic(monkeypatch: MonkeyPatch) -> None:
-    monkeypatch.setattr(
-        "llama_index.llms.bedrock.completion_with_retry", mock_completion_with_retry
+def test_model_basic() -> None:
+    llm = Bedrock(
+        model="amazon.titan-text-express-v1",
+        profile_name=None,
+        aws_region_name="us-east-1",
+        aws_access_key_id="test",
     )
-    llm = Bedrock(model="amazon.titan-text-express-v1", profile_name=None)
+
+    bedrock_stubber = Stubber(llm._client)
+
+    # response for llm.complete()
+    bedrock_stubber.add_response(
+        "invoke_model",
+        get_invoke_model_response(),
+    )
+    # response for llm.chat()
+    bedrock_stubber.add_response(
+        "invoke_model",
+        get_invoke_model_response(),
+    )
+
+    bedrock_stubber.activate()
+
     test_prompt = "test prompt"
     response = llm.complete(test_prompt)
     assert response.text == "\n\nThis is indeed a test"
@@ -89,14 +111,21 @@ def test_model_basic(monkeypatch: MonkeyPatch) -> None:
     chat_response = llm.chat([message])
     assert chat_response.message.content == "\n\nThis is indeed a test"
 
+    bedrock_stubber.deactivate()
+
 
-@pytest.mark.skipif(boto3 is None, reason="bedrock not installed")
 def test_model_streaming(monkeypatch: MonkeyPatch) -> None:
+    # Cannot use Stubber to mock EventStream. See https://github.com/boto/botocore/issues/1621
     monkeypatch.setattr(
         "llama_index.llms.bedrock.completion_with_retry",
         mock_stream_completion_with_retry,
     )
-    llm = Bedrock(model="amazon.titan-text-express-v1", profile_name=None)
+    llm = Bedrock(
+        model="amazon.titan-text-express-v1",
+        profile_name=None,
+        aws_region_name="us-east-1",
+        aws_access_key_id="test",
+    )
     test_prompt = "test prompt"
     response_gen = llm.stream_complete(test_prompt)
     response = list(response_gen)
-- 
GitLab