From 8ba53668a594ee4d3e53450d764495469c24e248 Mon Sep 17 00:00:00 2001 From: Theodore Cowan <t@theodore.me> Date: Mon, 4 Dec 2023 17:50:19 -0500 Subject: [PATCH] Bedrock embedding query for cohere was not in the correct format (#9265) --- llama_index/embeddings/bedrock.py | 37 +++++++---- poetry.lock | 100 +++++++++++++++++++++++++++--- pyproject.toml | 1 + tests/embeddings/test_bedrock.py | 75 ++++++++++++++++++++++ tests/llms/test_bedrock.py | 85 ++++++++++++++++--------- 5 files changed, 252 insertions(+), 46 deletions(-) create mode 100644 tests/embeddings/test_bedrock.py diff --git a/llama_index/embeddings/bedrock.py b/llama_index/embeddings/bedrock.py index 6cc4325647..f042622ac0 100644 --- a/llama_index/embeddings/bedrock.py +++ b/llama_index/embeddings/bedrock.py @@ -2,7 +2,7 @@ import json import os import warnings from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks.base import CallbackManager @@ -194,7 +194,7 @@ class BedrockEmbedding(BaseEmbedding): callback_manager=callback_manager, ) - def _get_embedding(self, payload: Any) -> Embedding: + def _get_embedding(self, payload: str, type: Literal["text", "query"]) -> Embedding: if self._client is None: self.set_credentials(self.model_name) @@ -202,7 +202,7 @@ class BedrockEmbedding(BaseEmbedding): raise ValueError("Client not set") provider = self.model_name.split(".")[0] - request_body = self._get_request_body(provider, payload) + request_body = self._get_request_body(provider, payload, type) response = self._client.invoke_model( body=request_body, @@ -218,12 +218,14 @@ class BedrockEmbedding(BaseEmbedding): return resp.get(identifiers.get("embeddings")) def _get_query_embedding(self, query: str) -> Embedding: - return self._get_embedding(query) + return self._get_embedding(query, "query") def _get_text_embedding(self, text: str) -> Embedding: - return 
self._get_embedding(text) + return self._get_embedding(text, "text") - def _get_request_body(self, provider: str, payload: Any) -> Any: + def _get_request_body( + self, provider: str, payload: str, type: Literal["text", "query"] + ) -> Any: """Build the request body as per the provider. Currently supported providers are amazon, cohere. @@ -240,14 +242,26 @@ class BedrockEmbedding(BaseEmbedding): } """ if provider == PROVIDERS.AMAZON: - request_body = json.dumps({"inputText": str(payload)}) - if provider == PROVIDERS.COHERE: - request_body = json.dumps(payload) + request_body = json.dumps({"inputText": payload}) + elif provider == PROVIDERS.COHERE: + input_types = { + "text": "search_document", + "query": "search_query", + } + request_body = json.dumps( + { + "texts": [payload], + "input_type": input_types[type], + "truncate": "NONE", + } + ) + else: + raise ValueError("Provider not supported") return request_body async def _aget_query_embedding(self, query: str) -> Embedding: - return self._get_embedding(query) + return self._get_embedding(query, "query") async def _aget_text_embedding(self, text: str) -> Embedding: - return self._get_embedding(text) + return self._get_embedding(text, "text") diff --git a/poetry.lock b/poetry.lock index b31868f201..482329da99 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "accelerate" @@ -587,6 +587,47 @@ numpy = [ {version = ">=1.19.0", markers = "python_version >= \"3.9\""}, ] +[[package]] +name = "boto3" +version = "1.33.6" +description = "The AWS SDK for Python" +optional = false +python-versions = ">= 3.7" +files = [ + {file = "boto3-1.33.6-py3-none-any.whl", hash = "sha256:b88f0f305186c5fd41f168e006baa45b7002a33029aec8e5bef373237a172fca"}, + {file = "boto3-1.33.6.tar.gz", hash = "sha256:4f62fc1c7f3ea2d22917aa0aa07b86f119abd90bed3d815e4b52fb3d84773e15"}, +] + +[package.dependencies] +botocore = ">=1.33.6,<1.34.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.8.2,<0.9.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.33.6" +description = "Low-level, data-driven core of boto 3." +optional = false +python-versions = ">= 3.7" +files = [ + {file = "botocore-1.33.6-py3-none-any.whl", hash = "sha256:14282cd432c0683770eee932c43c12bb9ad5730e23755204ad102897c996693a"}, + {file = "botocore-1.33.6.tar.gz", hash = "sha256:938056bab831829f90e09ecd70dd6b295afd52b1482f5582ee7a11d8243d9661"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.19.17)"] + [[package]] name = "cachetools" version = "5.3.2" @@ -2112,6 +2153,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "joblib" version = "1.3.2" 
@@ -3764,7 +3816,7 @@ files = [ [package.dependencies] coloredlogs = "*" datasets = [ - {version = "*", optional = true, markers = "extra != \"onnxruntime\""}, + {version = "*"}, {version = ">=1.2.1", optional = true, markers = "extra == \"onnxruntime\""}, ] evaluate = {version = "*", optional = true, markers = "extra == \"onnxruntime\""} @@ -5343,6 +5395,23 @@ files = [ {file = "ruff-0.0.292.tar.gz", hash = "sha256:1093449e37dd1e9b813798f6ad70932b57cf614e5c2b5c51005bf67d55db33ac"}, ] +[[package]] +name = "s3transfer" +version = "0.8.2" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">= 3.7" +files = [ + {file = "s3transfer-0.8.2-py3-none-any.whl", hash = "sha256:c9e56cbe88b28d8e197cf841f1f0c130f246595e77ae5b5a05b69fe7cb83de76"}, + {file = "s3transfer-0.8.2.tar.gz", hash = "sha256:368ac6876a9e9ed91f6bc86581e319be08188dc60d50e0d56308ed5765446283"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "safetensors" version = "0.4.1" @@ -7102,17 +7171,34 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake [[package]] name = "urllib3" -version = "2.1.0" +version = "1.26.18" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false -python-versions = ">=3.8" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, + {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + +[[package]] +name = "urllib3" +version = "2.0.7" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.7" files = [ - {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, - {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, + {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"}, + {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -7574,4 +7660,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.12" -content-hash = "003f5fcc15b7be0e4c2e691b503574ece7e7b11d21896982fe9a1a96e683a39d" +content-hash = "6600bad921e4e108a7518ef1eb134e58858cf447c374f5e89af1ca47eca00eb3" diff --git a/pyproject.toml b/pyproject.toml 
index e3099b6d17..c158854e3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ query_tools = [ [tool.poetry.group.dev.dependencies] black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} +boto3 = "1.33.6" # needed for tests codespell = {extras = ["toml"], version = ">=v2.2.6"} google-generativeai = {python = ">=3.9,<3.12", version = "^0.2.1"} ipython = "8.10.0" diff --git a/tests/embeddings/test_bedrock.py b/tests/embeddings/test_bedrock.py new file mode 100644 index 0000000000..c69aa1735f --- /dev/null +++ b/tests/embeddings/test_bedrock.py @@ -0,0 +1,75 @@ +import json +from io import BytesIO +from unittest import TestCase + +import boto3 +from botocore.response import StreamingBody +from botocore.stub import Stubber +from llama_index.embeddings.bedrock import BedrockEmbedding, Models + + +class TestBedrockEmbedding(TestCase): + bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1") + bedrock_stubber = Stubber(bedrock_client) + + def test_get_text_embedding_titan(self) -> None: + mock_response = { + "embedding": [ + 0.017410278, + 0.040924072, + -0.007507324, + 0.09429932, + 0.015304565, + ] + } + + mock_stream = BytesIO(json.dumps(mock_response).encode()) + + self.bedrock_stubber.add_response( + "invoke_model", + { + "contentType": "application/json", + "body": StreamingBody(mock_stream, len(json.dumps(mock_response))), + }, + ) + + bedrock_embedding = BedrockEmbedding( + model_name=Models.TITAN_EMBEDDING, + client=self.bedrock_client, + ) + + self.bedrock_stubber.activate() + embedding = bedrock_embedding.get_text_embedding(text="foo bar baz") + self.bedrock_stubber.deactivate() + + self.bedrock_stubber.assert_no_pending_responses() + self.assertEqual(embedding, mock_response["embedding"]) + + def test_get_text_embedding_cohere(self) -> None: + mock_response = { + "embeddings": [ + [0.017410278, 0.040924072, -0.007507324, 0.09429932, 0.015304565] + ] + } + + mock_stream = BytesIO(json.dumps(mock_response).encode()) + 
+ self.bedrock_stubber.add_response( + "invoke_model", + { + "contentType": "application/json", + "body": StreamingBody(mock_stream, len(json.dumps(mock_response))), + }, + ) + + bedrock_embedding = BedrockEmbedding( + model_name=Models.COHERE_EMBED_ENGLISH_V3, + client=self.bedrock_client, + ) + + self.bedrock_stubber.activate() + embedding = bedrock_embedding.get_text_embedding(text="foo bar baz") + self.bedrock_stubber.deactivate() + + self.bedrock_stubber.assert_no_pending_responses() + self.assertEqual(embedding, mock_response["embeddings"]) diff --git a/tests/llms/test_bedrock.py b/tests/llms/test_bedrock.py index 6c2d8d36ab..f462ce5b9a 100644 --- a/tests/llms/test_bedrock.py +++ b/tests/llms/test_bedrock.py @@ -1,27 +1,13 @@ +import json +from io import BytesIO from typing import Any, Generator -import pytest +from botocore.response import StreamingBody +from botocore.stub import Stubber +from llama_index.llms import Bedrock from llama_index.llms.base import ChatMessage from pytest import MonkeyPatch -try: - import boto3 -except ImportError: - boto3 = None -from llama_index.llms import Bedrock - - -class MockStreamingBody: - def read(self) -> str: - return """{ - "inputTextTokenCount": 3, - "results": [ - {"tokenCount": 14, - "outputText": "\\n\\nThis is indeed a test", - "completionReason": "FINISH" - }]} - """ - class MockEventStream: def __iter__(self) -> Generator[dict, None, None]: @@ -36,7 +22,23 @@ class MockEventStream: } -def mock_completion_with_retry(*args: Any, **kwargs: Any) -> dict: +def get_invoke_model_response() -> dict: + # response for titan model + raw_stream_bytes = json.dumps( + { + "inputTextTokenCount": 3, + "results": [ + { + "tokenCount": 14, + "outputText": "\n\nThis is indeed a test", + "completionReason": "FINISH", + } + ], + } + ).encode() + raw_stream = BytesIO(raw_stream_bytes) + content_length = len(raw_stream_bytes) + return { "ResponseMetadata": { "HTTPHeaders": { @@ -50,7 +52,10 @@ def mock_completion_with_retry(*args: 
Any, **kwargs: Any) -> dict: "RequestId": "667dq648-fbc3-4a7b-8f0e-4575f1f1f11d", "RetryAttempts": 0, }, - "body": MockStreamingBody(), + "body": StreamingBody( + raw_stream=raw_stream, + content_length=content_length, + ), "contentType": "application/json", } @@ -75,12 +80,29 @@ def mock_stream_completion_with_retry(*args: Any, **kwargs: Any) -> dict: } -@pytest.mark.skipif(boto3 is None, reason="bedrock not installed") -def test_model_basic(monkeypatch: MonkeyPatch) -> None: - monkeypatch.setattr( - "llama_index.llms.bedrock.completion_with_retry", mock_completion_with_retry +def test_model_basic() -> None: + llm = Bedrock( + model="amazon.titan-text-express-v1", + profile_name=None, + aws_region_name="us-east-1", + aws_access_key_id="test", ) - llm = Bedrock(model="amazon.titan-text-express-v1", profile_name=None) + + bedrock_stubber = Stubber(llm._client) + + # response for llm.complete() + bedrock_stubber.add_response( + "invoke_model", + get_invoke_model_response(), + ) + # response for llm.chat() + bedrock_stubber.add_response( + "invoke_model", + get_invoke_model_response(), + ) + + bedrock_stubber.activate() + test_prompt = "test prompt" response = llm.complete(test_prompt) assert response.text == "\n\nThis is indeed a test" @@ -89,14 +111,21 @@ def test_model_basic(monkeypatch: MonkeyPatch) -> None: chat_response = llm.chat([message]) assert chat_response.message.content == "\n\nThis is indeed a test" + bedrock_stubber.deactivate() + -@pytest.mark.skipif(boto3 is None, reason="bedrock not installed") def test_model_streaming(monkeypatch: MonkeyPatch) -> None: + # Cannot use Stubber to mock EventStream. 
See https://github.com/boto/botocore/issues/1621 monkeypatch.setattr( "llama_index.llms.bedrock.completion_with_retry", mock_stream_completion_with_retry, ) - llm = Bedrock(model="amazon.titan-text-express-v1", profile_name=None) + llm = Bedrock( + model="amazon.titan-text-express-v1", + profile_name=None, + aws_region_name="us-east-1", + aws_access_key_id="test", + ) test_prompt = "test prompt" response_gen = llm.stream_complete(test_prompt) response = list(response_gen) -- GitLab