Skip to content
Snippets Groups Projects
Unverified Commit 19037308 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Checking cohere embeddings types for mypy.

The newer version of cohere can return incorrect types.
parent 5ffc38ac
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
from typing import List, Optional from typing import List, Optional
import cohere import cohere
from cohere.types.embed_response import EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType
from semantic_router.encoders import BaseEncoder from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.defaults import EncoderDefault
...@@ -44,6 +45,13 @@ class CohereEncoder(BaseEncoder): ...@@ -44,6 +45,13 @@ class CohereEncoder(BaseEncoder):
embeds = self.client.embed( embeds = self.client.embed(
texts=docs, input_type=self.input_type, model=self.name texts=docs, input_type=self.input_type, model=self.name
) )
return embeds.embeddings # Check the type of response and handle accordingly
# Only EmbedResponse_EmbeddingsFloats has embeddings of type List[List[float]]
if isinstance(embeds, EmbedResponse_EmbeddingsFloats):
return embeds.embeddings
elif isinstance(embeds, EmbedResponse_EmbeddingsByType):
raise NotImplementedError("Handling of EmbedByTypeResponseEmbeddings is not implemented.")
else:
raise ValueError("Unexpected response type from Cohere API")
except Exception as e: except Exception as e:
raise ValueError(f"Cohere API call failed. Error: {e}") from e raise ValueError(f"Cohere API call failed. Error: {e}") from e
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment