From 69c2e9b73c69469e75d1e0c7d9fb305d8acf43e0 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Wed, 1 Jan 2025 14:11:54 +0400
Subject: [PATCH] fix: allow types to work between pinecone and hybrid

---
 semantic_router/index/pinecone.py | 119 +++++++++++++++++-------------
 semantic_router/routers/hybrid.py |   4 +-
 2 files changed, 70 insertions(+), 53 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 7e7def05..61247323 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -18,11 +18,59 @@ from semantic_router.utils.logger import logger
 def clean_route_name(route_name: str) -> str:
     return route_name.strip().replace(" ", "-")
 
+def build_records(
+    embeddings: List[List[float]],
+    routes: List[str],
+    utterances: List[str],
+    function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
+    metadata_list: List[Dict[str, Any]] = [],
+    sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None,
+) -> List[Dict]:
+    if function_schemas is None:
+        function_schemas = [{}] * len(embeddings)
+    if sparse_embeddings is None:
+        vectors_to_upsert = [
+            PineconeRecord(
+                values=vector,
+                route=route,
+                utterance=utterance,
+                function_schema=json.dumps(function_schema),
+                metadata=metadata,
+            ).to_dict()
+            for vector, route, utterance, function_schema, metadata in zip(
+                embeddings,
+                routes,
+                utterances,
+                function_schemas,
+                metadata_list,
+            )
+        ]
+    else:
+        vectors_to_upsert = [
+            PineconeRecord(
+                values=vector,
+                sparse_values=sparse_emb.to_pinecone(),
+                route=route,
+                utterance=utterance,
+                function_schema=json.dumps(function_schema),
+                metadata=metadata,
+            ).to_dict()
+            for vector, route, utterance, function_schema, metadata, sparse_emb in zip(
+                embeddings,
+                routes,
+                utterances,
+                function_schemas,
+                metadata_list,
+                sparse_embeddings,
+            )
+        ]
+    return vectors_to_upsert
+            
 
 class PineconeRecord(BaseModel):
     id: str = ""
     values: List[float]
-    sparse_values: Optional[dict[int, float]] = None
+    sparse_values: Optional[dict[str, list]] = None
     route: str
     utterance: str
     function_schema: str = "{}"
@@ -49,10 +97,7 @@ class PineconeRecord(BaseModel):
             "metadata": self.metadata,
         }
         if self.sparse_values:
-            d["sparse_values"] = {
-                "indices": list(self.sparse_values.keys()),
-                "values": list(self.sparse_values.values()),
-            }
+            d["sparse_values"] = self.sparse_values
         return d
 
 
@@ -255,34 +300,20 @@ class PineconeIndex(BaseIndex):
         function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = 100,
-        sparse_embeddings: Optional[Optional[List[dict[int, float]]]] = None,
+        sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None,
     ):
         """Add vectors to Pinecone in batches."""
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
-        if function_schemas is None:
-            function_schemas = [{}] * len(embeddings)
-        if sparse_embeddings is None:
-            sparse_embeddings = [{}] * len(embeddings)
-        vectors_to_upsert = [
-            PineconeRecord(
-                values=vector,
-                sparse_values=sparse_dict,
-                route=route,
-                utterance=utterance,
-                function_schema=json.dumps(function_schema),
-                metadata=metadata,
-            ).to_dict()
-            for vector, route, utterance, function_schema, metadata, sparse_dict in zip(
-                embeddings,
-                routes,
-                utterances,
-                function_schemas,
-                metadata_list,
-                sparse_embeddings,
-            )
-        ]
+        vectors_to_upsert = build_records(
+            embeddings=embeddings,
+            routes=routes,
+            utterances=utterances,
+            function_schemas=function_schemas,
+            metadata_list=metadata_list,
+            sparse_embeddings=sparse_embeddings,
+        )
 
         for i in range(0, len(vectors_to_upsert), batch_size):
             batch = vectors_to_upsert[i : i + batch_size]
@@ -296,34 +327,20 @@ class PineconeIndex(BaseIndex):
         function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = 100,
-        sparse_embeddings: Optional[Optional[List[dict[int, float]]]] = None,
+        sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None,
     ):
         """Add vectors to Pinecone in batches."""
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = await self._init_async_index(force_create=True)
-        if function_schemas is None:
-            function_schemas = [{}] * len(embeddings)
-        if sparse_embeddings is None:
-            sparse_embeddings = [{}] * len(embeddings)
-        vectors_to_upsert = [
-            PineconeRecord(
-                values=vector,
-                sparse_values=sparse_dict,
-                route=route,
-                utterance=utterance,
-                function_schema=json.dumps(function_schema),
-                metadata=metadata,
-            ).to_dict()
-            for vector, route, utterance, function_schema, metadata, sparse_dict in zip(
-                embeddings,
-                routes,
-                utterances,
-                function_schemas,
-                metadata_list,
-                sparse_embeddings,
-            )
-        ]
+        vectors_to_upsert = build_records(
+            embeddings=embeddings,
+            routes=routes,
+            utterances=utterances,
+            function_schemas=function_schemas,
+            metadata_list=metadata_list,
+            sparse_embeddings=sparse_embeddings,
+        )
 
         for i in range(0, len(vectors_to_upsert), batch_size):
             batch = vectors_to_upsert[i : i + batch_size]
diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index 54901d5e..0bb0574b 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -92,7 +92,7 @@ class HybridRouter(BaseRouter):
             utterances=all_utterances,
             function_schemas=all_function_schemas,
             metadata_list=all_metadata,
-            sparse_embeddings=sparse_emb,  # type: ignore
+            sparse_embeddings=sparse_emb,
         )
 
         self.routes.extend(routes)
@@ -129,7 +129,7 @@ class HybridRouter(BaseRouter):
                     utt.function_schemas for utt in strategy["remote"]["upsert"]  # type: ignore
                 ],
                 metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
-                sparse_embeddings=sparse_emb,  # type: ignore
+                sparse_embeddings=sparse_emb,
             )
         if strategy["local"]["delete"]:
             self._local_delete(utterances=strategy["local"]["delete"])
-- 
GitLab