diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 6f90e61e68646c7100b8f94225c4c3d7201aff0e..59dd1c6358282e22c6c263d2615968ca4983384e 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -2,6 +2,7 @@ import os from typing import List, Optional import cohere +from cohere.types.embed_response import EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault @@ -44,6 +45,13 @@ class CohereEncoder(BaseEncoder): embeds = self.client.embed( texts=docs, input_type=self.input_type, model=self.name ) - return embeds.embeddings + # Check the type of response and handle accordingly + # Only EmbedResponse_EmbeddingsFloats has embeddings of type List[List[float]] + if isinstance(embeds, EmbedResponse_EmbeddingsFloats): + return embeds.embeddings + elif isinstance(embeds, EmbedResponse_EmbeddingsByType): + raise NotImplementedError("Handling of EmbedByTypeResponseEmbeddings is not implemented.") + else: + raise ValueError("Unexpected response type from Cohere API") except Exception as e: raise ValueError(f"Cohere API call failed. Error: {e}") from e