diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index ff4c0d0c2832c83d592faa4a3ff9a5f228bde290..7adf60c717252b71da0d9f604aec43d2dc5226b9 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -138,7 +138,9 @@ class BedrockEncoder(DenseEncoder): ) from err return bedrock_client - def __call__(self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None) -> List[List[float]]: + def __call__( + self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None + ) -> List[List[float]]: """Generates embeddings for the given documents. Args: @@ -169,27 +171,29 @@ class BedrockEncoder(DenseEncoder): embeddings = [] if self.name and "amazon" in self.name: for doc in docs: - + embedding_body = {} if isinstance(doc, dict): - embedding_body['inputText'] = doc.get('text') - embedding_body['inputImage'] = doc.get('image') # expects a base64-encoded image + embedding_body["inputText"] = doc.get("text") + embedding_body["inputImage"] = doc.get( + "image" + ) # expects a base64-encoded image else: - embedding_body['inputText'] = doc + embedding_body["inputText"] = doc # Add model-specific inference parameters if model_kwargs: embedding_body = embedding_body | model_kwargs - + # Clean up null values embedding_body = {k: v for k, v in embedding_body.items() if v} - + # Format payload - embedding_body = json.dumps(embedding_body) + embedding_body_payload: str = json.dumps(embedding_body) response = self.client.invoke_model( - body=embedding_body, + body=embedding_body_payload, modelId=self.name, accept="application/json", contentType="application/json", @@ -199,10 +203,7 @@ class BedrockEncoder(DenseEncoder): elif self.name and "cohere" in self.name: chunked_docs = self.chunk_strings(docs) for chunk in chunked_docs: - chunk = { - 'texts': chunk, - 'input_type': self.input_type - } + chunk = {"texts": chunk, "input_type": self.input_type} # Add model-specific inference parameters # Note: if specified, input_type will be overwritten by model_kwargs