From 73af1ed5f7ec5551ba1b55f3845f43e81a392382 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Thu, 2 Jan 2025 12:54:23 +0400
Subject: [PATCH] fix: sparse embedding index type support

---
 semantic_router/index/base.py         |  3 ++
 semantic_router/index/hybrid_local.py |  1 +
 semantic_router/index/local.py        |  1 +
 semantic_router/index/pinecone.py     |  5 +-
 semantic_router/index/postgres.py     |  1 +
 semantic_router/index/qdrant.py       |  1 +
 tests/unit/test_router.py             | 14 ++++--
 tests/unit/test_sync.py               | 70 ++++++++++++++++-----------
 8 files changed, 63 insertions(+), 33 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index d98ae19e..50702bdf 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -38,6 +38,7 @@ class BaseIndex(BaseModel):
         utterances: List[Any],
         function_schemas: Optional[List[Dict[str, Any]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
+        **kwargs,
     ):
         """Add embeddings to the index.
         This method should be implemented by subclasses.
@@ -51,6 +52,7 @@ class BaseIndex(BaseModel):
         utterances: List[str],
         function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
+        **kwargs,
     ):
         """Add vectors to the index asynchronously.
         This method should be implemented by subclasses.
@@ -62,6 +64,7 @@ class BaseIndex(BaseModel):
             utterances=utterances,
             function_schemas=function_schemas,
             metadata_list=metadata_list,
+            **kwargs,
         )
 
     def get_utterances(self) -> List[Utterance]:
diff --git a/semantic_router/index/hybrid_local.py b/semantic_router/index/hybrid_local.py
index d4096edb..4175eac9 100644
--- a/semantic_router/index/hybrid_local.py
+++ b/semantic_router/index/hybrid_local.py
@@ -25,6 +25,7 @@ class HybridLocalIndex(LocalIndex):
         function_schemas: Optional[List[Dict[str, Any]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
         sparse_embeddings: Optional[List[SparseEmbedding]] = None,
+        **kwargs,
     ):
         if sparse_embeddings is None:
             raise ValueError("Sparse embeddings are required for HybridLocalIndex.")
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 76d44d82..c4f14fc4 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -26,6 +26,7 @@ class LocalIndex(BaseIndex):
         utterances: List[str],
         function_schemas: Optional[List[Dict[str, Any]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
+        **kwargs,
     ):
         embeds = np.array(embeddings)  # type: ignore
         routes_arr = np.array(routes)
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 61247323..da24e226 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -18,6 +18,7 @@ from semantic_router.utils.logger import logger
 def clean_route_name(route_name: str) -> str:
     return route_name.strip().replace(" ", "-")
 
+
 def build_records(
     embeddings: List[List[float]],
     routes: List[str],
@@ -65,7 +66,7 @@ def build_records(
             )
         ]
     return vectors_to_upsert
-            
+
 
 class PineconeRecord(BaseModel):
     id: str = ""
@@ -301,6 +302,7 @@ class PineconeIndex(BaseIndex):
         metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = 100,
         sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None,
+        **kwargs,
     ):
         """Add vectors to Pinecone in batches."""
         if self.index is None:
@@ -328,6 +330,7 @@ class PineconeIndex(BaseIndex):
         metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = 100,
         sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None,
+        **kwargs,
     ):
         """Add vectors to Pinecone in batches."""
         if self.index is None:
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 71ea32e8..76d60d2b 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -273,6 +273,7 @@ class PostgresIndex(BaseIndex):
         utterances: List[str],
         function_schemas: Optional[List[Dict[str, Any]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
+        **kwargs,
     ) -> None:
         """
         Adds vectors to the index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 1b23753b..51846629 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -170,6 +170,7 @@ class QdrantIndex(BaseIndex):
         function_schemas: Optional[List[Dict[str, Any]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
+        **kwargs,
     ):
         self.dimensions = self.dimensions or len(embeddings[0])
         self._init_collection()
diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py
index 1f743f1c..4698e476 100644
--- a/tests/unit/test_router.py
+++ b/tests/unit/test_router.py
@@ -10,7 +10,7 @@ from semantic_router.encoders import DenseEncoder, CohereEncoder, OpenAIEncoder
 from semantic_router.index.local import LocalIndex
 from semantic_router.index.pinecone import PineconeIndex
 from semantic_router.index.qdrant import QdrantIndex
-from semantic_router.routers import RouterConfig, SemanticRouter
+from semantic_router.routers import RouterConfig, SemanticRouter, HybridRouter
 from semantic_router.llms.base import BaseLLM
 from semantic_router.route import Route
 from platform import python_version
@@ -201,12 +201,20 @@ def get_test_encoders():
     return encoders
 
 
+def get_test_routers():
+    """Return router classes to test; HybridRouter only if pinecone_text is found."""
+    if importlib.util.find_spec("pinecone_text") is None:
+        return [SemanticRouter]
+    return [SemanticRouter, HybridRouter]
+
+
 @pytest.mark.parametrize(
-    "index_cls,encoder_cls",
+    "index_cls,encoder_cls,router_cls",
     [
-        (index, encoder)
+        (index, encoder, router)
         for index in get_test_indexes()
         for encoder in get_test_encoders()
+        for router in get_test_routers()
     ],
 )
 class TestIndexEncoders:
diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py
index ed23a3d6..baa9d486 100644
--- a/tests/unit/test_sync.py
+++ b/tests/unit/test_sync.py
@@ -237,11 +237,7 @@ def get_test_routers():
 
 @pytest.mark.parametrize(
     "index_cls,router_cls",
-    [
-        (index, router)
-        for index in get_test_indexes()
-        for router in get_test_routers()
-    ],
+    [(index, router) for index in get_test_indexes() for router in get_test_routers()],
 )
 class TestSemanticRouter:
     @pytest.mark.skipif(
@@ -260,7 +256,9 @@ class TestSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_second_initialization_sync(self, openai_encoder, routes, index_cls, router_cls):
+    def test_second_initialization_sync(
+        self, openai_encoder, routes, index_cls, router_cls
+    ):
         index = init_index(index_cls)
         route_layer = router_cls(
             encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
@@ -279,9 +277,7 @@ class TestSemanticRouter:
         _ = router_cls(
             encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
         )
-        route_layer = router_cls(
-            encoder=openai_encoder, routes=routes_2, index=index
-        )
+        route_layer = router_cls(encoder=openai_encoder, routes=routes_2, index=index)
         if index_cls is PineconeIndex:
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
         assert route_layer.is_synced() is False
@@ -289,14 +285,14 @@ class TestSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    def test_utterance_diff(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         index = init_index(index_cls)
         _ = router_cls(
             encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
         )
-        route_layer_2 = router_cls(
-            encoder=openai_encoder, routes=routes_2, index=index
-        )
+        route_layer_2 = router_cls(encoder=openai_encoder, routes=routes_2, index=index)
         if index_cls is PineconeIndex:
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
         diff = route_layer_2.get_utterance_diff(include_metadata=True)
@@ -312,7 +308,9 @@ class TestSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_auto_sync_local(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    def test_auto_sync_local(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         if index_cls is PineconeIndex:
             # TEST LOCAL
             pinecone_index = init_index(index_cls)
@@ -337,7 +335,9 @@ class TestSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_auto_sync_remote(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    def test_auto_sync_remote(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         if index_cls is PineconeIndex:
             # TEST REMOTE
             pinecone_index = init_index(index_cls)
@@ -462,7 +462,9 @@ class TestSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    def test_auto_sync_merge(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         if index_cls is PineconeIndex:
             # TEST MERGE
             pinecone_index = init_index(index_cls)
@@ -540,7 +542,9 @@ class TestSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls, router_cls):
+    def test_sync_lock_auto_releases(
+        self, openai_encoder, routes, index_cls, router_cls
+    ):
         """Test that sync lock is automatically released after sync operations"""
         index = init_index(index_cls)
         route_layer = router_cls(
@@ -585,7 +589,9 @@ class TestAsyncSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     @pytest.mark.asyncio
-    async def test_second_initialization_sync(self, openai_encoder, routes, index_cls, router_cls):
+    async def test_second_initialization_sync(
+        self, openai_encoder, routes, index_cls, router_cls
+    ):
         index = init_index(index_cls, init_async_index=True)
         route_layer = router_cls(
             encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
@@ -605,9 +611,7 @@ class TestAsyncSemanticRouter:
         _ = router_cls(
             encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
         )
-        route_layer = router_cls(
-            encoder=openai_encoder, routes=routes_2, index=index
-        )
+        route_layer = router_cls(encoder=openai_encoder, routes=routes_2, index=index)
         if index_cls is PineconeIndex:
             await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
         assert await route_layer.async_is_synced() is False
@@ -616,14 +620,14 @@ class TestAsyncSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     @pytest.mark.asyncio
-    async def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    async def test_utterance_diff(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         index = init_index(index_cls, init_async_index=True)
         _ = router_cls(
             encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
         )
-        route_layer_2 = router_cls(
-            encoder=openai_encoder, routes=routes_2, index=index
-        )
+        route_layer_2 = router_cls(encoder=openai_encoder, routes=routes_2, index=index)
         if index_cls is PineconeIndex:
             await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
         diff = await route_layer_2.aget_utterance_diff(include_metadata=True)
@@ -640,7 +644,9 @@ class TestAsyncSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     @pytest.mark.asyncio
-    async def test_auto_sync_local(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    async def test_auto_sync_local(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         if index_cls is PineconeIndex:
             # TEST LOCAL
             pinecone_index = init_index(index_cls, init_async_index=True)
@@ -666,7 +672,9 @@ class TestAsyncSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     @pytest.mark.asyncio
-    async def test_auto_sync_remote(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    async def test_auto_sync_remote(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         if index_cls is PineconeIndex:
             # TEST REMOTE
             pinecone_index = init_index(index_cls, init_async_index=True)
@@ -795,7 +803,9 @@ class TestAsyncSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     @pytest.mark.asyncio
-    async def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls, router_cls):
+    async def test_auto_sync_merge(
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
+    ):
         if index_cls is PineconeIndex:
             # TEST MERGE
             pinecone_index = init_index(index_cls, init_async_index=True)
@@ -875,7 +885,9 @@ class TestAsyncSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     @pytest.mark.asyncio
-    async def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls, router_cls):
+    async def test_sync_lock_auto_releases(
+        self, openai_encoder, routes, index_cls, router_cls
+    ):
         """Test that sync lock is automatically released after sync operations"""
         index = init_index(index_cls, init_async_index=True)
         route_layer = router_cls(
-- 
GitLab