From 4788c614894b8d3575fbc31ba68f00d000e90237 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Fri, 3 Jan 2025 14:44:57 +0400 Subject: [PATCH] fix: async usage and tests --- semantic_router/index/pinecone.py | 16 ++------ tests/unit/test_sync.py | 67 ++++++++++++++++++++++++++----- 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index a92c57bd..e70d8b62 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -282,17 +282,6 @@ class PineconeIndex(BaseIndex): else: raise ValueError("Index is None, could not upsert.") - async def _async_batch_upsert(self, batch: List[Dict]): - """Helper method for upserting a single batch of records asynchronously. - - :param batch: The batch of records to upsert. - :type batch: List[Dict] - """ - if self.index is not None: - await self.index.upsert(vectors=batch, namespace=self.namespace) - else: - raise ValueError("Index is None, could not upsert.") - def add( self, embeddings: List[List[float]], @@ -347,7 +336,10 @@ class PineconeIndex(BaseIndex): for i in range(0, len(vectors_to_upsert), batch_size): batch = vectors_to_upsert[i : i + batch_size] - await self._async_batch_upsert(batch) + await self._async_upsert( + vectors=batch, + namespace=self.namespace or "", + ) def _remove_and_sync(self, routes_to_delete: dict): for route, utterances in routes_to_delete.items(): diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index baa9d486..390cab49 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -148,12 +148,28 @@ def base_encoder(): @pytest.fixture def cohere_encoder(mocker): mocker.patch.object(CohereEncoder, "__call__", side_effect=mock_encoder_call) + + # Mock async call + async def async_mock_encoder_call(docs=None, utterances=None): + # Handle either docs or utterances parameter + texts = docs if docs is not None else utterances + return mock_encoder_call(texts) + + mocker.patch.object(CohereEncoder, "acall", side_effect=async_mock_encoder_call) return CohereEncoder(name="test-cohere-encoder", cohere_api_key="test_api_key") @pytest.fixture def openai_encoder(mocker): mocker.patch.object(OpenAIEncoder, "__call__", side_effect=mock_encoder_call) + + # Mock async call + async def async_mock_encoder_call(docs=None, utterances=None): + # Handle either docs or utterances parameter + texts = docs if docs is not None else utterances + return mock_encoder_call(texts) + + mocker.patch.object(OpenAIEncoder, "acall", side_effect=async_mock_encoder_call) return OpenAIEncoder(name="text-embedding-3-small", openai_api_key="test_api_key") @@ -508,17 +524,23 @@ class TestSemanticRouter: os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) def test_sync_lock_prevents_concurrent_sync( - self, openai_encoder, routes, index_cls, router_cls + self, openai_encoder, routes, routes_2, index_cls, router_cls ): """Test that sync lock prevents concurrent synchronization operations""" index = init_index(index_cls) + route_layer = router_cls( + encoder=openai_encoder, + routes=routes_2, + index=index, + auto_sync="local", + ) + # initialize an out of sync router route_layer = router_cls( encoder=openai_encoder, routes=routes, index=index, auto_sync=None, ) - # Acquire sync lock route_layer.index.lock(value=True) if index_cls is PineconeIndex: @@ -565,11 +587,15 @@ class TestSemanticRouter: time.sleep(PINECONE_SLEEP) assert route_layer.is_synced() - # clear index - route_layer.index.index.delete(namespace="", delete_all=True) + # clear index if pinecone + if index_cls is PineconeIndex: + route_layer.index.client.delete_index(route_layer.index.index_name) -@pytest.mark.parametrize("index_cls", get_test_indexes()) +@pytest.mark.parametrize( + "index_cls,router_cls", + [(index, router) for index in get_test_indexes() for router in get_test_routers()], +) class TestAsyncSemanticRouter: @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" @@ -762,7 +788,7 @@ class TestAsyncSemanticRouter: ) await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated # confirm local and remote are synced - assert route_layer.async_is_synced() + assert await route_layer.async_is_synced() # now confirm utterances are correct local_utterances = await route_layer.index.aget_utterances() # we sort to ensure order is the same @@ -850,10 +876,17 @@ class TestAsyncSemanticRouter: ) @pytest.mark.asyncio async def test_sync_lock_prevents_concurrent_sync( - self, openai_encoder, routes, index_cls, router_cls + self, openai_encoder, routes, routes_2, index_cls, router_cls ): """Test that sync lock prevents concurrent synchronization operations""" index = init_index(index_cls, init_async_index=True) + route_layer = router_cls( + encoder=openai_encoder, + routes=routes_2, + index=index, + auto_sync="local", + ) + # initialize an out of sync router route_layer = router_cls( encoder=openai_encoder, routes=routes, @@ -886,27 +919,39 @@ class TestAsyncSemanticRouter: ) @pytest.mark.asyncio async def test_sync_lock_auto_releases( - self, openai_encoder, routes, index_cls, router_cls + self, openai_encoder, routes, routes_2, index_cls, router_cls ): """Test that sync lock is automatically released after sync operations""" index = init_index(index_cls, init_async_index=True) + print(f"1. {index.namespace=}") + route_layer = router_cls( + encoder=openai_encoder, + routes=routes_2, + index=index, + auto_sync="local", + ) + print(f"2. {route_layer.index.namespace=}") route_layer = router_cls( encoder=openai_encoder, routes=routes, index=index, auto_sync=None, ) - + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) # Initial sync should acquire and release lock await route_layer.async_sync("local") if index_cls is PineconeIndex: await asyncio.sleep(PINECONE_SLEEP) + print(f"3. {route_layer.index.namespace=}") # Lock should be released, allowing another sync await route_layer.async_sync("local") # Should not raise exception if index_cls is PineconeIndex: await asyncio.sleep(PINECONE_SLEEP) assert await route_layer.async_is_synced() + print(f"4. {route_layer.index.namespace=}") - # clear index - route_layer.index.index.delete(namespace="", delete_all=True) + # clear index if pinecone + if index_cls is PineconeIndex: + route_layer.index.client.delete_index(route_layer.index.index_name) -- GitLab