From 4788c614894b8d3575fbc31ba68f00d000e90237 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Fri, 3 Jan 2025 14:44:57 +0400
Subject: [PATCH] fix: async usage and tests
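
Remove the redundant `_async_batch_upsert` helper from `PineconeIndex`
and batch through the existing `_async_upsert` method instead. On the
test side, mock the encoders' async `acall` alongside `__call__`,
parametrize the async test class over routers as well as indexes, await
`async_is_synced()` where it was previously asserted without awaiting,
start the sync-lock tests from an out-of-sync router, and clean up by
deleting the whole Pinecone test index rather than a single namespace.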

---
 semantic_router/index/pinecone.py | 16 ++------
 tests/unit/test_sync.py           | 67 ++++++++++++++++++++++++++-----
 2 files changed, 60 insertions(+), 23 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index a92c57bd..e70d8b62 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -282,17 +282,6 @@ class PineconeIndex(BaseIndex):
         else:
             raise ValueError("Index is None, could not upsert.")
 
-    async def _async_batch_upsert(self, batch: List[Dict]):
-        """Helper method for upserting a single batch of records asynchronously.
-
-        :param batch: The batch of records to upsert.
-        :type batch: List[Dict]
-        """
-        if self.index is not None:
-            await self.index.upsert(vectors=batch, namespace=self.namespace)
-        else:
-            raise ValueError("Index is None, could not upsert.")
-
     def add(
         self,
         embeddings: List[List[float]],
@@ -347,7 +336,10 @@ class PineconeIndex(BaseIndex):
 
         for i in range(0, len(vectors_to_upsert), batch_size):
             batch = vectors_to_upsert[i : i + batch_size]
-            await self._async_batch_upsert(batch)
+            await self._async_upsert(
+                vectors=batch,
+                namespace=self.namespace or "",
+            )
 
     def _remove_and_sync(self, routes_to_delete: dict):
         for route, utterances in routes_to_delete.items():
diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py
index baa9d486..390cab49 100644
--- a/tests/unit/test_sync.py
+++ b/tests/unit/test_sync.py
@@ -148,12 +148,28 @@ def base_encoder():
 @pytest.fixture
 def cohere_encoder(mocker):
     mocker.patch.object(CohereEncoder, "__call__", side_effect=mock_encoder_call)
+
+    # Mock async call
+    async def async_mock_encoder_call(docs=None, utterances=None):
+        # Handle either docs or utterances parameter
+        texts = docs if docs is not None else utterances
+        return mock_encoder_call(texts)
+
+    mocker.patch.object(CohereEncoder, "acall", side_effect=async_mock_encoder_call)
     return CohereEncoder(name="test-cohere-encoder", cohere_api_key="test_api_key")
 
 
 @pytest.fixture
 def openai_encoder(mocker):
     mocker.patch.object(OpenAIEncoder, "__call__", side_effect=mock_encoder_call)
+
+    # Mock async call
+    async def async_mock_encoder_call(docs=None, utterances=None):
+        # Handle either docs or utterances parameter
+        texts = docs if docs is not None else utterances
+        return mock_encoder_call(texts)
+
+    mocker.patch.object(OpenAIEncoder, "acall", side_effect=async_mock_encoder_call)
     return OpenAIEncoder(name="text-embedding-3-small", openai_api_key="test_api_key")
 
 
@@ -508,17 +524,23 @@ class TestSemanticRouter:
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
     def test_sync_lock_prevents_concurrent_sync(
-        self, openai_encoder, routes, index_cls, router_cls
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
     ):
         """Test that sync lock prevents concurrent synchronization operations"""
         index = init_index(index_cls)
+        route_layer = router_cls(
+            encoder=openai_encoder,
+            routes=routes_2,
+            index=index,
+            auto_sync="local",
+        )
+        # initialize an out-of-sync router
         route_layer = router_cls(
             encoder=openai_encoder,
             routes=routes,
             index=index,
             auto_sync=None,
         )
-
         # Acquire sync lock
         route_layer.index.lock(value=True)
         if index_cls is PineconeIndex:
@@ -565,11 +587,15 @@ class TestSemanticRouter:
             time.sleep(PINECONE_SLEEP)
         assert route_layer.is_synced()
 
-        # clear index
-        route_layer.index.index.delete(namespace="", delete_all=True)
+        # clear index if using Pinecone
+        if index_cls is PineconeIndex:
+            route_layer.index.client.delete_index(route_layer.index.index_name)
 
 
-@pytest.mark.parametrize("index_cls", get_test_indexes())
+@pytest.mark.parametrize(
+    "index_cls,router_cls",
+    [(index, router) for index in get_test_indexes() for router in get_test_routers()],
+)
 class TestAsyncSemanticRouter:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
@@ -762,7 +788,7 @@ class TestAsyncSemanticRouter:
             )
             await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
             # confirm local and remote are synced
-            assert route_layer.async_is_synced()
+            assert await route_layer.async_is_synced()
             # now confirm utterances are correct
             local_utterances = await route_layer.index.aget_utterances()
             # we sort to ensure order is the same
@@ -850,10 +876,17 @@ class TestAsyncSemanticRouter:
     )
     @pytest.mark.asyncio
     async def test_sync_lock_prevents_concurrent_sync(
-        self, openai_encoder, routes, index_cls, router_cls
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
     ):
         """Test that sync lock prevents concurrent synchronization operations"""
         index = init_index(index_cls, init_async_index=True)
+        route_layer = router_cls(
+            encoder=openai_encoder,
+            routes=routes_2,
+            index=index,
+            auto_sync="local",
+        )
+        # initialize an out-of-sync router
         route_layer = router_cls(
             encoder=openai_encoder,
             routes=routes,
@@ -886,27 +919,39 @@ class TestAsyncSemanticRouter:
     )
     @pytest.mark.asyncio
     async def test_sync_lock_auto_releases(
-        self, openai_encoder, routes, index_cls, router_cls
+        self, openai_encoder, routes, routes_2, index_cls, router_cls
     ):
         """Test that sync lock is automatically released after sync operations"""
         index = init_index(index_cls, init_async_index=True)
+        print(f"1. {index.namespace=}")
+        route_layer = router_cls(
+            encoder=openai_encoder,
+            routes=routes_2,
+            index=index,
+            auto_sync="local",
+        )
+        print(f"2. {route_layer.index.namespace=}")
         route_layer = router_cls(
             encoder=openai_encoder,
             routes=routes,
             index=index,
             auto_sync=None,
         )
-
+        if index_cls is PineconeIndex:
+            await asyncio.sleep(PINECONE_SLEEP)
         # Initial sync should acquire and release lock
         await route_layer.async_sync("local")
         if index_cls is PineconeIndex:
             await asyncio.sleep(PINECONE_SLEEP)
+        print(f"3. {route_layer.index.namespace=}")
 
         # Lock should be released, allowing another sync
         await route_layer.async_sync("local")  # Should not raise exception
         if index_cls is PineconeIndex:
             await asyncio.sleep(PINECONE_SLEEP)
         assert await route_layer.async_is_synced()
+        print(f"4. {route_layer.index.namespace=}")
 
-        # clear index
-        route_layer.index.index.delete(namespace="", delete_all=True)
+        # clear index if using Pinecone
+        if index_cls is PineconeIndex:
+            route_layer.index.client.delete_index(route_layer.index.index_name)
-- 
GitLab