From 190373087d7ae09fc40dbd77c4b113760d407602 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Wed, 15 May 2024 00:59:05 +0400 Subject: [PATCH] Checking cohere embeddings types for mypy. The newer version of cohere can return incorrect types. --- semantic_router/encoders/cohere.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index 6f90e61e..59dd1c63 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -2,6 +2,7 @@ import os from typing import List, Optional import cohere +from cohere.types.embed_response import EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault @@ -44,6 +45,13 @@ class CohereEncoder(BaseEncoder): embeds = self.client.embed( texts=docs, input_type=self.input_type, model=self.name ) - return embeds.embeddings + # Check the type of response and handle accordingly + # Only EmbedResponse_EmbeddingsFloats has embeddings of type List[List[float]] + if isinstance(embeds, EmbedResponse_EmbeddingsFloats): + return embeds.embeddings + elif isinstance(embeds, EmbedResponse_EmbeddingsByType): + raise NotImplementedError("Handling of EmbedByTypeResponseEmbeddings is not implemented.") + else: + raise ValueError("Unexpected response type from Cohere API") except Exception as e: raise ValueError(f"Cohere API call failed. Error: {e}") from e -- GitLab