From 190373087d7ae09fc40dbd77c4b113760d407602 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 15 May 2024 00:59:05 +0400
Subject: [PATCH] Checking cohere embeddings types for mypy.

The newer version of cohere can return incorrect types.
---
 semantic_router/encoders/cohere.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py
index 6f90e61e..59dd1c63 100644
--- a/semantic_router/encoders/cohere.py
+++ b/semantic_router/encoders/cohere.py
@@ -2,6 +2,7 @@ import os
 from typing import List, Optional
 
 import cohere
+from cohere.types.embed_response import EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType 
 
 from semantic_router.encoders import BaseEncoder
 from semantic_router.utils.defaults import EncoderDefault
@@ -44,6 +45,13 @@ class CohereEncoder(BaseEncoder):
             embeds = self.client.embed(
                 texts=docs, input_type=self.input_type, model=self.name
             )
-            return embeds.embeddings
+            # Check the type of response and handle accordingly
+            # Only EmbedResponse_EmbeddingsFloats has embeddings of type List[List[float]]
+            if isinstance(embeds, EmbedResponse_EmbeddingsFloats):
+                return embeds.embeddings
+            elif isinstance(embeds, EmbedResponse_EmbeddingsByType):
+                raise NotImplementedError("Handling of EmbedByTypeResponseEmbeddings is not implemented.")
+            else:
+                raise ValueError("Unexpected response type from Cohere API")
         except Exception as e:
             raise ValueError(f"Cohere API call failed. Error: {e}") from e
-- 
GitLab