diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index ce04719be7e5fa2117a706938c811803e746b1bc..fad82978bede1579c8eb2dd6d5168e801b58ac96 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -22,6 +22,7 @@ import os import tiktoken from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault +from semantic_router.utils.logger import logger class BedrockEncoder(BaseEncoder): @@ -69,12 +70,14 @@ class BedrockEncoder(BaseEncoder): """ super().__init__(name=name, score_threshold=score_threshold) - self.access_key_id = self.get_env_variable("access_key_id", access_key_id) + self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id) self.secret_access_key = self.get_env_variable( - "secret_access_key", secret_access_key + "AWS_SECRET_ACCESS_KEY", secret_access_key ) self.session_token = self.get_env_variable("AWS_SESSION_TOKEN", session_token) - self.region = self.get_env_variable("AWS_REGION", region, default="us-west-1") + self.region = self.get_env_variable( + "AWS_DEFAULT_REGION", region, default="us-west-1" + ) self.input_type = input_type @@ -116,9 +119,9 @@ class BedrockEncoder(BaseEncoder): "`pip install boto3`" ) - access_key_id = access_key_id or os.getenv("access_key_id") - aws_secret_key = secret_access_key or os.getenv("secret_access_key") - region = region or os.getenv("AWS_REGION", "us-west-2") + access_key_id = access_key_id or os.getenv("AWS_ACCESS_KEY_ID") + aws_secret_key = secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY") + region = region or os.getenv("AWS_DEFAULT_REGION", "us-west-2") if access_key_id is None: raise ValueError("AWS access key ID cannot be 'None'.") @@ -126,12 +129,14 @@ class BedrockEncoder(BaseEncoder): if aws_secret_key is None: raise ValueError("AWS secret access key cannot be 'None'.") + session = boto3.Session( + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + ) try: - bedrock_client = boto3.client( + bedrock_client = session.client( "bedrock-runtime", - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token, region_name=region, ) except Exception as err: @@ -155,6 +160,8 @@ class BedrockEncoder(BaseEncoder): ValueError: If the Bedrock Platform client is not initialized or if the API call fails. """ + from botocore.exceptions import ClientError + if self.client is None: raise ValueError("Bedrock client is not initialised.") try: @@ -224,6 +231,21 @@ class BedrockEncoder(BaseEncoder): else: raise ValueError("Unknown model name") return embeddings + except ClientError as error: + if error.response["Error"]["Code"] == "ExpiredTokenException": + logger.warning("Session token has expired. Retrying initialisation.") + try: + self.session_token = os.getenv("AWS_SESSION_TOKEN") + self.client = self._initialize_client( + self.access_key_id, + self.secret_access_key, + self.session_token, + self.region, + ) + except Exception as e: + raise ValueError( + f"Bedrock client failed to reinitialise. Error: {e}" + ) from e except Exception as e: raise ValueError(f"Bedrock call failed. Error: {e}") from e @@ -246,5 +268,7 @@ class BedrockEncoder(BaseEncoder): return provided_value value = os.getenv(var_name, default) if value is None: + if var_name == "AWS_SESSION_TOKEN": + return None raise ValueError(f"No {var_name} provided") return value