diff --git a/poetry.lock b/poetry.lock index 53b7b5cd7df3ef496e4c97d39e3beea6cd662bc8..8b7f9ccdac9b216db117ee7f88a66f7765c2d7e2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -259,6 +259,47 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "boto3" +version = "1.34.98" +description = "The AWS SDK for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.98-py3-none-any.whl", hash = "sha256:030e43b8efe22b4cf10b9f3ef9e30cd4cf9ef9784b26efe9a4583339f2b2bcec"}, + {file = "boto3-1.34.98.tar.gz", hash = "sha256:28c10956033fa79e64529f48c3b62db86d5e4b77024a7343764b6bde6b553543"}, +] + +[package.dependencies] +botocore = ">=1.34.98,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.98" +description = "Low-level, data-driven core of boto 3." +optional = true +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.98-py3-none-any.whl", hash = "sha256:631c0031d8ce922b5752ab395ead896a0281b0dc74745a754d0351a27c5d83de"}, + {file = "botocore-1.34.98.tar.gz", hash = "sha256:4cee65df02f4b0be08ad1401965cc89efafebc50ef0727d2d17083c7f1ed2831"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.20.9)"] + [[package]] name = "cachetools" version = "5.3.3" @@ -1096,12 +1137,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -1857,6 +1898,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "joblib" version = "1.4.0" @@ -3508,6 +3560,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3853,6 +3906,23 @@ files = [ {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, ] +[[package]] +name = "s3transfer" +version = "0.10.1" +description = "An Amazon S3 Transfer Manager" +optional = true +python-versions = ">= 3.8" +files = [ + {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, + {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "safetensors" version = "0.4.2" @@ -4560,6 +4630,20 @@ files = [ {file = "types_PyYAML-6.0.12.20240311-py3-none-any.whl", hash = "sha256:b845b06a1c7e54b8e5b4c683043de0d9caf205e7434b3edc678ff2411979b8f6"}, ] +[[package]] +name = "types-requests" +version = "2.31.0.6" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.7" +files = [ + {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, + {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, +] + +[package.dependencies] +types-urllib3 = "*" + [[package]] name = "types-requests" version = "2.31.0.20240406" @@ -4574,6 +4658,17 @@ files = [ [package.dependencies] urllib3 = ">=2" +[[package]] +name = "types-urllib3" +version = "1.26.25.14" +description = "Typing stubs for urllib3" +optional = false +python-versions = "*" +files = [ + {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, + {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, +] + [[package]] name = "typing-extensions" version = "4.11.0" @@ -4585,6 +4680,22 @@ files = [ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] +[[package]] +name = "urllib3" +version = "1.26.18" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, + {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.1" @@ -4756,6 +4867,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] +bedrock = ["boto3"] fastembed = ["fastembed"] google = ["google-cloud-aiplatform"] hybrid = ["pinecone-text"] @@ -4769,4 +4881,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "9f308d2dd1c067185f9d84721b25d81e7d1e72a239059863bad1f4439a7a26cc" +content-hash = "be798556d4ad5d05ba0682534dcfab1c06e3ff1c33bcf3c24d178b665c81dde8" diff --git a/pyproject.toml b/pyproject.toml index 6a2ee15d32bff2e24fe2bde1006e0d2e112131db..ed313654cdac2967857ff74f4c30b47539954041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ matplotlib = { version = "^3.8.3", optional = true} qdrant-client = {version = "^1.8.0", optional = true} google-cloud-aiplatform = {version = "^1.45.0", optional = true} requests-mock = "^1.12.1" +boto3 = { version = "^1.34.98", optional = true } [tool.poetry.extras] hybrid = ["pinecone-text"] @@ -49,6 +50,7 @@ processing = ["matplotlib"] mistralai = ["mistralai"] qdrant = ["qdrant-client"] google = ["google-cloud-aiplatform"] +bedrock = ["boto3"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 5efc730398a45fc3a9de5f234a6d43a0e37911be..8598bbc58ba586cacf4b81d1ea440d7cbb3c7b0b 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,6 +1,7 @@ from typing import List, Optional from semantic_router.encoders.base import BaseEncoder +from semantic_router.encoders.bedrock import BedrockEncoder from semantic_router.encoders.bm25 import BM25Encoder from semantic_router.encoders.clip import CLIPEncoder from semantic_router.encoders.cohere import CohereEncoder @@ -29,6 +30,7 @@ __all__ = [ "VitEncoder", "CLIPEncoder", "GoogleEncoder", + "BedrockEncoder", ] @@ -67,6 +69,8 @@ class AutoEncoder: self.model = CLIPEncoder(name=name) elif self.type == EncoderType.GOOGLE: self.model = GoogleEncoder(name=name) + elif self.type == EncoderType.BEDROCK: + self.model = BedrockEncoder(name=name) else: raise ValueError(f"Encoder type '{type}' not supported") diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py new file mode 100644 index 0000000000000000000000000000000000000000..bb27572aaf1d498fb77f83135a5ac850d5bb5b10 --- /dev/null +++ b/semantic_router/encoders/bedrock.py @@ -0,0 +1,101 @@ +import json +from typing import List, Optional, Any + +import boto3 + +from semantic_router.encoders import BaseEncoder +from semantic_router.utils.defaults import EncoderDefault + + +class BedrockEncoder(BaseEncoder): + client: Any = None + type: str = "bedrock" + input_type: Optional[str] = "search_query" + session: Optional[Any] = (None,) + region: Optional[str] = None + + def __init__( + self, + name: Optional[str] = None, + session: Optional[Any] = None, + region: Optional[str] = None, + score_threshold: float = 0.3, + input_type: Optional[str] = "search_query", + ): + if name is None: + name = EncoderDefault.BEDROCK.value["embedding_model"] + super().__init__( + name=name, + score_threshold=score_threshold, + input_type=input_type, + ) + self.input_type = input_type + self.session = session or boto3.Session() + if self.session.get_credentials() is None: + raise ValueError("Could not get AWS session") + self.region = region or self.session.region_name + if self.region is None: + raise ValueError("No AWS region provided") + try: + self.client = self.session.client( + service_name="bedrock-runtime", region_name=str(self.region) + ) + except Exception as e: + raise ValueError(f"Bedrock client failed to initialise. Error: {e}") from e + + def __call__(self, docs: List[str]) -> List[List[float]]: + if self.client is None: + raise ValueError("Bedrock client is not initialised.") + try: + embeddings = [] + if "amazon" in self.name: + for doc in docs: + doc = json.dumps( + { + "inputText": doc, + } + ) + response = self.client.invoke_model( + body=doc, + modelId=self.name, + accept="*/*", + contentType="application/json", + ) + + response_body = json.loads(response.get("body").read()) + + embedding = response_body.get("embedding") + embeddings.append(embedding) + elif "cohere" in self.name: + MAX_WORDS = 400 + for doc in docs: + words = doc.split() + if len(words) > MAX_WORDS: + chunks = [ + " ".join(words[i : i + MAX_WORDS]) + for i in range(0, len(words), MAX_WORDS) + ] + else: + chunks = [doc] + + for chunk in chunks: + chunk = json.dumps( + {"texts": [chunk], "input_type": self.input_type} + ) + + response = self.client.invoke_model( + body=chunk, + modelId=self.name, + accept="*/*", + contentType="application/json", + ) + + response_body = json.loads(response.get("body").read()) + + chunk_embeddings = response_body.get("embeddings") + embeddings.extend(chunk_embeddings) + else: + raise ValueError("Unknown model name") + return embeddings + except Exception as e: + raise ValueError(f"Bedrock call failed. Error: {e}") from e diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 20b6ef825872ca329a90fc5230a49e877987f6d9..256370b07aef950ded9fdb3fca6491a3bb35799f 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -15,6 +15,7 @@ class EncoderType(Enum): VIT = "vit" CLIP = "clip" GOOGLE = "google" + BEDROCK = "bedrock" class EncoderInfo(BaseModel): diff --git a/semantic_router/utils/defaults.py b/semantic_router/utils/defaults.py index 3c9cbb2dd1010f5b861c49fcafad389c591fe9cb..75331c06581ad4692bc24f1633ba5a609ba28e47 100644 --- a/semantic_router/utils/defaults.py +++ b/semantic_router/utils/defaults.py @@ -31,3 +31,8 @@ class EncoderDefault(Enum): "GOOGLE_EMBEDDING_MODEL", "textembedding-gecko@003" ), } + BEDROCK = { + "embedding_model": os.environ.get( + "BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1" + ) + } diff --git a/tests/unit/encoders/test_bedrock.py b/tests/unit/encoders/test_bedrock.py new file mode 100644 index 0000000000000000000000000000000000000000..2076d36ea66a7fd7e2af44750ec7b2b93c3c2037 --- /dev/null +++ b/tests/unit/encoders/test_bedrock.py @@ -0,0 +1,90 @@ +import pytest +import json +from io import BytesIO +from semantic_router.encoders import BedrockEncoder + + +@pytest.fixture +def bedrock_encoder(mocker): + mocker.patch("boto3.Session") + mocker.patch("boto3.Session.client") + return BedrockEncoder() + + +class TestBedrockEncoder: + def test_initialisation_with_default_values(self, bedrock_encoder): + assert bedrock_encoder.client is not None, "Client should be initialised" + assert bedrock_encoder.type == "bedrock", "Default type not set correctly" + assert ( + bedrock_encoder.input_type == "search_query" + ), "Default input type not set correctly" + assert bedrock_encoder.session is not None, "Session should be initialised" + assert bedrock_encoder.region is not None, "Region should be initialised" + + def test_initialisation_with_custom_values(self, mocker): + mocker.patch("boto3.Session") + mocker.patch("boto3.Session.client") + name = "custom_model" + session = mocker.Mock() + region = "us-west-2" + score_threshold = 0.5 + input_type = "custom_input" + bedrock_encoder = BedrockEncoder( + name=name, + session=session, + region=region, + score_threshold=score_threshold, + input_type=input_type, + ) + assert bedrock_encoder.name == name, "Custom name not set correctly" + assert bedrock_encoder.session == session, "Custom session not set correctly" + assert bedrock_encoder.region == region, "Custom region not set correctly" + assert ( + bedrock_encoder.score_threshold == score_threshold + ), "Custom score threshold not set correctly" + assert ( + bedrock_encoder.input_type == input_type + ), "Custom input type not set correctly" + + def test_call_method(self, bedrock_encoder): + response_content = json.dumps({"embedding": [0.1, 0.2, 0.3]}) + response_body = BytesIO(response_content.encode("utf-8")) + + mock_response = {"body": response_body} + bedrock_encoder.client.invoke_model.return_value = mock_response + + result = bedrock_encoder(["test"]) + + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(item, list) for item in result + ), "Each item in result should be a list" + assert result == [[0.1, 0.2, 0.3]], "Embedding should be [0.1, 0.2, 0.3]" + + def test_returns_list_of_embeddings_for_valid_input(self, bedrock_encoder): + response_content = json.dumps({"embedding": [0.1, 0.2, 0.3]}) + response_body = BytesIO(response_content.encode("utf-8")) + + mock_response = {"body": response_body} + + bedrock_encoder.client.invoke_model.return_value = mock_response + + result = bedrock_encoder(["test"]) + + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(item, list) for item in result + ), "Each item in result should be a list" + assert result == [[0.1, 0.2, 0.3]], "Embedding should be [0.1, 0.2, 0.3]" + + def test_raises_value_error_if_client_is_not_initialised(self, mocker): + mocker.patch("boto3.Session.client", return_value=None) + with pytest.raises(ValueError): + BedrockEncoder() + + def test_raises_value_error_if_call_to_bedrock_fails(self, bedrock_encoder): + bedrock_encoder.client.invoke_model.side_effect = Exception( + "Bedrock call failed." + ) + with pytest.raises(ValueError): + bedrock_encoder(["test"])