Skip to content
Snippets Groups Projects
Unverified Commit 19037308 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Checking cohere embeddings types for mypy.

The newer version of cohere can return incorrect types.
parent 5ffc38ac
Branches
Tags
No related merge requests found
......@@ -2,6 +2,7 @@ import os
from typing import List, Optional
import cohere
from cohere.types.embed_response import EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
......@@ -44,6 +45,13 @@ class CohereEncoder(BaseEncoder):
embeds = self.client.embed(
texts=docs, input_type=self.input_type, model=self.name
)
return embeds.embeddings
# Check the type of response and handle accordingly
# Only EmbedResponse_EmbeddingsFloats has embeddings of type List[List[float]]
if isinstance(embeds, EmbedResponse_EmbeddingsFloats):
return embeds.embeddings
elif isinstance(embeds, EmbedResponse_EmbeddingsByType):
raise NotImplementedError("Handling of EmbedByTypeResponseEmbeddings is not implemented.")
else:
raise ValueError("Unexpected response type from Cohere API")
except Exception as e:
raise ValueError(f"Cohere API call failed. Error: {e}") from e
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment