From 50199b19d0114e8c84e05f77219a979d1aed7739 Mon Sep 17 00:00:00 2001 From: Kurtis Massey <55586356+kurtismassey@users.noreply.github.com> Date: Tue, 21 May 2024 22:48:13 +0100 Subject: [PATCH] Retry handling, tests and lint --- poetry.lock | 12 +- pyproject.toml | 5 +- semantic_router/encoders/bedrock.py | 198 +++++++++++++--------------- tests/unit/encoders/test_bedrock.py | 133 ++++++++++++++++--- 4 files changed, 213 insertions(+), 135 deletions(-) diff --git a/poetry.lock b/poetry.lock index e5cc68f9..53e1951c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -131,13 +131,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.108" +version = "1.34.110" description = "Low-level, data-driven core of boto 3." optional = true python-versions = ">=3.8" files = [ - {file = "botocore-1.34.108-py3-none-any.whl", hash = "sha256:b1b9d00804267669c5fcc36489269f7e9c43580c30f0885fbf669cf73cec720b"}, - {file = "botocore-1.34.108.tar.gz", hash = "sha256:384c9408c447631475dc41fdc9bf2e0f30c29c420d96bfe8b468bdc2bace3e13"}, + {file = "botocore-1.34.110-py3-none-any.whl", hash = "sha256:1edf3a825ec0a5edf238b2d42ad23305de11d5a71bb27d6f9a58b7e8862df1b6"}, + {file = "botocore-1.34.110.tar.gz", hash = "sha256:b2c98c40ecf0b1facb9e61ceb7dfa28e61ae2456490554a16c8dbf99f20d6a18"}, ] [package.dependencies] @@ -4420,7 +4420,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -bedrock = ["boto3"] +bedrock = ["boto3", "botocore"] fastembed = ["fastembed"] google = ["google-cloud-aiplatform"] hybrid = ["pinecone-text"] @@ -4434,4 +4434,4 @@ vision = ["pillow", "torch", "torchvision", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "954b14578e3ee2ab4236a3533e313035039b9e6c01bca91f52c7c81ff836a089" +content-hash = "b7effbc6291a3b3ffcbb88efa3c1db8d1167c4fa959349b640be64f99e9c8618" diff --git a/pyproject.toml b/pyproject.toml index 9b4c3903..7f61dd0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ qdrant-client = {version = "^1.8.0", optional = true} google-cloud-aiplatform = {version = "^1.45.0", optional = true} requests-mock = "^1.12.1" boto3 = { version = "^1.34.98", optional = true } +botocore = {version = "^1.34.110", optional = true} [tool.poetry.extras] hybrid = ["pinecone-text"] @@ -51,7 +52,7 @@ processing = ["matplotlib"] mistralai = ["mistralai"] qdrant = ["qdrant-client"] google = ["google-cloud-aiplatform"] -bedrock = ["boto3"] +bedrock = ["boto3", "botocore"] [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.0" @@ -76,4 +77,4 @@ build-backend = "poetry.core.masonry.api" line-length = 88 [tool.mypy] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index fad82978..40b43411 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -19,6 +19,7 @@ Classes: import json from typing import List, Optional, Any import os +from time import sleep import tiktoken from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault @@ -68,7 +69,6 @@ class BedrockEncoder(BaseEncoder): Raises: ValueError: If the Bedrock Platform client fails to initialize. """ - super().__init__(name=name, score_threshold=score_threshold) self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id) self.secret_access_key = self.get_env_variable( @@ -78,9 +78,7 @@ class BedrockEncoder(BaseEncoder): self.region = self.get_env_variable( "AWS_DEFAULT_REGION", region, default="us-west-1" ) - self.input_type = input_type - try: self.client = self._initialize_client( self.access_key_id, @@ -88,7 +86,6 @@ class BedrockEncoder(BaseEncoder): self.session_token, self.region, ) - except Exception as e: raise ValueError(f"Bedrock client failed to initialise. Error: {e}") from e @@ -118,17 +115,13 @@ class BedrockEncoder(BaseEncoder): "You can install them with: " "`pip install boto3`" ) - access_key_id = access_key_id or os.getenv("AWS_ACCESS_KEY_ID") aws_secret_key = secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY") region = region or os.getenv("AWS_DEFAULT_REGION", "us-west-2") - if access_key_id is None: raise ValueError("AWS access key ID cannot be 'None'.") - if aws_secret_key is None: raise ValueError("AWS secret access key cannot be 'None'.") - session = boto3.Session( aws_access_key_id=access_key_id, aws_secret_access_key=secret_access_key, @@ -143,7 +136,6 @@ class BedrockEncoder(BaseEncoder): raise ValueError( f"The Bedrock client failed to initialize. Error: {err}" ) from err - return bedrock_client def __call__(self, docs: List[str]) -> List[List[float]]: @@ -160,110 +152,104 @@ class BedrockEncoder(BaseEncoder): ValueError: If the Bedrock Platform client is not initialized or if the API call fails. """ - from botocore.exceptions import ClientError - + try: + from botocore.exceptions import ClientError + except ImportError: + raise ImportError( + "Please install Amazon's Botocore client library to use the BedrockEncoder. " + "You can install them with: " + "`pip install botocore`" + ) if self.client is None: raise ValueError("Bedrock client is not initialised.") - try: - embeddings = [] - - def chunk_strings(strings, MAX_WORDS=20): - """ - Breaks up a list of strings into smaller chunks. - - Args: - strings (list): A list of strings to be chunked. - max_chunk_size (int): The maximum size of each chunk. Default is 75. - - Returns: - list: A list of lists, where each inner list contains a chunk of strings. - """ - encoding = tiktoken.get_encoding("cl100k_base") - chunked_strings = [] - current_chunk = [] - - for text in strings: - encoded_text = encoding.encode(text) - - if len(encoded_text) > MAX_WORDS: - current_chunk = [ - encoding.decode(encoded_text[i : i + MAX_WORDS]) - for i in range(0, len(encoded_text), MAX_WORDS) - ] - else: - current_chunk = [encoding.decode(encoded_text)] - - chunked_strings.append(current_chunk) - return chunked_strings - - if self.name and "amazon" in self.name: - for doc in docs: - embedding_body = json.dumps( - { - "inputText": doc, - } - ) - response = self.client.invoke_model( - body=embedding_body, - modelId=self.name, - accept="application/json", - contentType="application/json", - ) - - response_body = json.loads(response.get("body").read()) - embeddings.append(response_body.get("embedding")) - elif self.name and "cohere" in self.name: - chunked_docs = chunk_strings(docs) - for chunk in chunked_docs: - chunk = json.dumps({"texts": chunk, "input_type": self.input_type}) - - response = self.client.invoke_model( - body=chunk, - modelId=self.name, - accept="*/*", - contentType="application/json", - ) - - response_body = json.loads(response.get("body").read()) - - chunk_embeddings = response_body.get("embeddings") - embeddings.extend(chunk_embeddings) - else: - raise ValueError("Unknown model name") - return embeddings - except ClientError as error: - if error.response["Error"]["Code"] == "ExpiredTokenException": - logger.warning("Session token has expired. Retrying initialisation.") - try: - self.session_token = os.getenv("AWS_SESSION_TOKEN") - self.client = self._initialize_client( - self.access_key_id, - self.secret_access_key, - self.session_token, - self.region, - ) - except Exception as e: - raise ValueError( - f"Bedrock client failed to reinitialise. Error: {e}" - ) from e - except Exception as e: - raise ValueError(f"Bedrock call failed. Error: {e}") from e - - @staticmethod - def get_env_variable(var_name, provided_value, default=None): - """Retrieves environment variable or uses a provided value. + max_attempts = 3 + for attempt in range(max_attempts): + try: + embeddings = [] + if self.name and "amazon" in self.name: + for doc in docs: + embedding_body = json.dumps( + { + "inputText": doc, + } + ) + response = self.client.invoke_model( + body=embedding_body, + modelId=self.name, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read()) + embeddings.append(response_body.get("embedding")) + elif self.name and "cohere" in self.name: + chunked_docs = self.chunk_strings(docs) + for chunk in chunked_docs: + chunk = json.dumps( + {"texts": chunk, "input_type": self.input_type} + ) + response = self.client.invoke_model( + body=chunk, + modelId=self.name, + accept="*/*", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read()) + chunk_embeddings = response_body.get("embeddings") + embeddings.extend(chunk_embeddings) + else: + raise ValueError("Unknown model name") + return embeddings + except ClientError as error: + if attempt < max_attempts - 1: + if error.response["Error"]["Code"] == "ExpiredTokenException": + logger.warning( + "Session token has expired. Retrying initialisation." + ) + try: + self.session_token = os.getenv("AWS_SESSION_TOKEN") + self.client = self._initialize_client( + self.access_key_id, + self.secret_access_key, + self.session_token, + self.region, + ) + except Exception as e: + raise ValueError( + f"Bedrock client failed to reinitialise. Error: {e}" + ) from e + sleep(2**attempt) + logger.warning(f"Retrying in {2**attempt} seconds...") + raise ValueError( + f"Retries exhausted, Bedrock call failed. Error: {error}" + ) from error + except Exception as e: + raise ValueError(f"Bedrock call failed. Error: {e}") from e + raise ValueError("Bedrock call to return embeddings.") + + def chunk_strings(self, strings, MAX_WORDS=20): + """ + Breaks up a list of strings into smaller chunks. Args: - var_name (str): The name of the environment variable. - provided_value (Optional[str]): The provided value to use if not None. - default (Optional[str]): The default value if the environment variable is not set. + strings (list): A list of strings to be chunked. + max_chunk_size (int): The maximum size of each chunk. Default is 20. Returns: - str: The value of the environment variable or the provided/default value. - - Raises: - ValueError: If no value is provided and the environment variable is not set. + list: A list of lists, where each inner list contains a chunk of strings. """ + encoding = tiktoken.get_encoding("cl100k_base") + chunked_strings = [] + for text in strings: + encoded_text = encoding.encode(text) + chunks = [ + encoding.decode(encoded_text[i : i + MAX_WORDS]) + for i in range(0, len(encoded_text), MAX_WORDS) + ] + chunked_strings.append(chunks) + return chunked_strings + + @staticmethod + def get_env_variable(var_name, provided_value, default=None): if provided_value is not None: return provided_value value = os.getenv(var_name, default) diff --git a/tests/unit/encoders/test_bedrock.py b/tests/unit/encoders/test_bedrock.py index 43955d45..1a9cf835 100644 --- a/tests/unit/encoders/test_bedrock.py +++ b/tests/unit/encoders/test_bedrock.py @@ -1,3 +1,4 @@ +import os import pytest import json from io import BytesIO @@ -15,6 +16,18 @@ def bedrock_encoder(mocker): ) +@pytest.fixture +def bedrock_encoder_with_cohere(mocker): + mocker.patch("semantic_router.encoders.bedrock.BedrockEncoder._initialize_client") + return BedrockEncoder( + name="cohere_model", + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) + + class TestBedrockEncoder: def test_initialisation_with_default_values(self, bedrock_encoder): assert ( @@ -23,9 +36,6 @@ class TestBedrockEncoder: assert bedrock_encoder.region == "us-west-2", "Region should be initialised" def test_initialisation_with_custom_values(self, mocker): - # mocker.patch( - # "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client" - # ) name = "custom_model" score_threshold = 0.5 input_type = "custom_input" @@ -47,6 +57,45 @@ class TestBedrockEncoder: bedrock_encoder.input_type == input_type ), "Custom input type not set correctly" + def test_initialisation_with_session_token(self, mocker): + mocker.patch( + "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client" + ) + bedrock_encoder = BedrockEncoder( + access_key_id="fake_id", + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) + assert ( + bedrock_encoder.session_token == "fake_token" + ), "Session token not set correctly" + + def test_initialisation_with_missing_access_key(self, mocker): + mocker.patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "env_id"}) + mocker.patch( + "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client" + ) + bedrock_encoder = BedrockEncoder( + access_key_id=None, + secret_access_key="fake_secret", + session_token="fake_token", + region="us-west-2", + ) + assert ( + bedrock_encoder.access_key_id == "env_id" + ), "Access key ID not set correctly from environment variable" + + def test_initialisation_missing_env_variables(self, mocker): + mocker.patch.dict(os.environ, {}, clear=True) + with pytest.raises(ValueError): + BedrockEncoder( + access_key_id=None, + secret_access_key=None, + session_token=None, + region=None, + ) + def test_call_method(self, bedrock_encoder): response_content = json.dumps({"embedding": [0.1, 0.2, 0.3]}) response_body = BytesIO(response_content.encode("utf-8")) @@ -59,18 +108,38 @@ class TestBedrockEncoder: ), "Each item in result should be a list" assert result == [[0.1, 0.2, 0.3]], "Embedding should be [0.1, 0.2, 0.3]" - def test_raises_value_error_if_client_is_not_initialised(self, mocker): + def test_call_with_expired_token(self, mocker, bedrock_encoder): + from botocore.exceptions import ClientError + + error_response = {"Error": {"Code": "ExpiredTokenException"}} mocker.patch( "semantic_router.encoders.bedrock.BedrockEncoder._initialize_client", - side_effect=Exception("Client initialisation failed"), + return_value=None, ) + + def invoke_model_side_effect(*args, **kwargs): + if not invoke_model_side_effect.expired_token_raised: + invoke_model_side_effect.expired_token_raised = True + raise ClientError(error_response, "invoke_model") + else: + return { + "body": BytesIO( + json.dumps({"embedding": [0.1, 0.2, 0.3]}).encode("utf-8") + ) + } + + invoke_model_side_effect.expired_token_raised = False + bedrock_encoder.client.invoke_model.side_effect = invoke_model_side_effect + with pytest.raises(ValueError): - BedrockEncoder( - access_key_id="fake_id", - secret_access_key="fake_secret", - session_token="fake_token", - region="us-west-2", - ) + bedrock_encoder(["test"]) + + bedrock_encoder._initialize_client.assert_called_once_with( + bedrock_encoder.access_key_id, + bedrock_encoder.secret_access_key, + None, + bedrock_encoder.region, + ) def test_raises_value_error_if_call_to_bedrock_fails(self, bedrock_encoder): bedrock_encoder.client.invoke_model.side_effect = Exception( @@ -79,17 +148,39 @@ class TestBedrockEncoder: with pytest.raises(ValueError): bedrock_encoder(["test"]) + def test_call_with_unknown_model_name(self, bedrock_encoder): + bedrock_encoder.name = "unknown_model" + with pytest.raises(ValueError): + bedrock_encoder(["test"]) -@pytest.fixture -def bedrock_encoder_with_cohere(mocker): - mocker.patch("semantic_router.encoders.bedrock.BedrockEncoder._initialize_client") - return BedrockEncoder( - name="cohere_model", - access_key_id="fake_id", - secret_access_key="fake_secret", - session_token="fake_token", - region="us-west-2", - ) + def test_chunking_functionality(self, bedrock_encoder): + docs = ["This is a long text that needs to be chunked properly."] + chunked_docs = bedrock_encoder.chunk_strings(docs, MAX_WORDS=5) + assert isinstance(chunked_docs, list), "Chunked result should be a list" + assert ( + len(chunked_docs[0]) > 1 + ), "Document should be chunked into multiple parts" + assert all( + isinstance(chunk, str) for chunk in chunked_docs[0] + ), "Chunks should be strings" + + def test_get_env_variable(self): + var_name = "TEST_ENV_VAR" + default_value = "default" + os.environ[var_name] = "env_value" + assert BedrockEncoder.get_env_variable(var_name, None) == "env_value" + assert ( + BedrockEncoder.get_env_variable(var_name, None, default_value) + == "env_value" + ) + assert ( + BedrockEncoder.get_env_variable("NON_EXISTENT_VAR", None, default_value) + == default_value + ) + + def test_get_env_variable_missing(self): + with pytest.raises(ValueError): + BedrockEncoder.get_env_variable("MISSING_VAR", None) class TestBedrockEncoderWithCohere: -- GitLab