From 248edf8bc8c411d09cdd442beb8dd9c4bf04dcc7 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Thu, 14 Nov 2024 12:37:11 +0100 Subject: [PATCH] fix: ongoing sync work --- semantic_router/index/pinecone.py | 4 +++ semantic_router/layer.py | 47 ++++++++++++++++++++++--- semantic_router/schema.py | 24 ++++++++----- tests/unit/test_sync.py | 58 +++++++++++++++++++++---------- 4 files changed, 101 insertions(+), 32 deletions(-) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 2231a657..85efd809 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -484,6 +484,8 @@ class PineconeIndex(BaseIndex): batch_size: int = 100, ): """Add vectors to Pinecone in batches.""" + temp = '\n'.join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)]) + logger.warning("TEMP | add:\n"+temp) if self.index is None: self.dimensions = self.dimensions or len(embeddings[0]) self.index = self._init_index(force_create=True) @@ -506,6 +508,8 @@ class PineconeIndex(BaseIndex): self._batch_upsert(batch) def _remove_and_sync(self, routes_to_delete: dict): + temp = '\n'.join([f"{route}: {utterances}" for route, utterances in routes_to_delete.items()]) + logger.warning("TEMP | _remove_and_sync:\n"+temp) for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) ids_to_delete = [ diff --git a/semantic_router/layer.py b/semantic_router/layer.py index b1bc9142..64186e90 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -521,7 +521,7 @@ class RouteLayer: :param utterances: The utterances to add to the local RouteLayer. :type utterances: List[Utterance] """ - new_routes = {} + new_routes = {route.name: route for route in self.routes} for utt_obj in utterances: if utt_obj.route not in new_routes.keys(): new_routes[utt_obj.route] = Route( @@ -531,8 +531,13 @@ class RouteLayer: metadata=utt_obj.metadata ) else: - new_routes[utt_obj.route].utterances.append(utt_obj.utterance) - self.routes.extend(list(new_routes.values())) + if utt_obj.utterance not in new_routes[utt_obj.route].utterances: + new_routes[utt_obj.route].utterances.append(utt_obj.utterance) + new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas + new_routes[utt_obj.route].metadata = utt_obj.metadata + temp = '\n'.join([f"{name}: {r.utterances}" for name, r in new_routes.items()]) + logger.warning("TEMP | _local_upsert:\n"+temp) + self.routes = list(new_routes.values()) def _local_delete(self, utterances: List[Utterance]): """Deletes routes from the local RouteLayer. @@ -540,8 +545,40 @@ class RouteLayer: :param utterances: The utterances to delete from the local RouteLayer. :type utterances: List[Utterance] """ - route_names = set([utt.route for utt in utterances]) - self.routes = [route for route in self.routes if route.name not in route_names] + # create dictionary of route names to utterances + route_dict = {} + for utt in utterances: + route_dict.setdefault(utt.route, []).append(utt.utterance) + temp = '\n'.join([f"{r}: {u}" for r, u in route_dict.items()]) + logger.warning("TEMP | _local_delete:\n"+temp) + # iterate over current routes and delete specific utterance if found + new_routes = [] + for route in self.routes: + if route.name in route_dict.keys(): + # drop utterances that are in route_dict deletion list + new_utterances = list(set(route.utterances) - set(route_dict[route.name])) + if len(new_utterances) == 0: + # the route is now empty, so we skip it + continue + else: + new_routes.append( + Route( + name=route.name, + utterances=new_utterances, + # use existing function schemas and metadata + function_schemas=route.function_schemas, + metadata=route.metadata + ) + ) + logger.warning(f"TEMP | _local_delete OLD | {route.name}: {route.utterances}") + logger.warning(f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}") + else: + # the route is not in the route_dict, so we keep it as is + new_routes.append(route) + temp = '\n'.join([f"{r}: {u}" for r, u in route_dict.items()]) + logger.warning("TEMP | _local_delete:\n"+temp) + + self.routes = new_routes def _retrieve_top_route( diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 1e2dd488..012ef616 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -255,7 +255,7 @@ class UtteranceDiff(BaseModel): elif sync_mode == "local": return { "remote": { - "upsert": local_only, + "upsert": local_only,# + remote_updates, "delete": remote_only }, "local": { @@ -281,11 +281,11 @@ class UtteranceDiff(BaseModel): # if we see route: utterance exists in local, we do not pull it in #Â from remote local_route_utt_strs = set([utt.to_str() for utt in local_only]) - # get remote utterances that belong to routes_in_both + # get remote utterances that are in local remote_to_keep = [utt for utt in remote_only if ( utt.route in local_route_names and utt.to_str() not in local_route_utt_strs )] - #Â get remote utterances that do NOT belong to routes_in_both + #Â get remote utterances that are NOT in local remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names] return { "remote": { @@ -298,12 +298,18 @@ class UtteranceDiff(BaseModel): } } elif sync_mode == "merge-force-local": # merge-to-remote merge-join-remote - # get set of route names that exist in both local and remote - routes_in_both = set([utt.route for utt in local_and_remote]) - # get local utterances that belong to routes_in_both - local_to_keep = [utt for utt in local_only if utt.route in routes_in_both] - # get local utterances that do NOT belong to routes_in_both - local_to_delete = [utt for utt in local_only if utt.route not in routes_in_both] + # get set of route names that exist in remote (we keep these if + # they are in local) + remote_route_names = set([utt.route for utt in remote_only]) + # if we see route: utterance exists in remote, we do not pull it in + #Â from local + remote_route_utt_strs = set([utt.to_str() for utt in remote_only]) + # get local utterances that are in remote + local_to_keep = [utt for utt in local_only if ( + utt.route in remote_route_names and utt.to_str() not in remote_route_utt_strs + )] + # get local utterances that are NOT in remote + local_to_delete = [utt for utt in local_only if utt.route not in remote_route_names] return { "remote": { "upsert": local_to_keep, diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index ad397dc7..7a20d7a5 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -317,7 +317,10 @@ class TestRouteLayer: #Â we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ - Utterance(route='Route 1', utterance='Hello'), + Utterance( + route='Route 1', utterance='Hello', + metadata={'type': 'default'} + ), Utterance( route='Route 1', utterance='Hi', metadata={'type': 'default'} @@ -334,28 +337,36 @@ class TestRouteLayer: @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_merge_force_local(self, openai_encoder, routes, index_cls): + def test_auto_sync_merge_force_local(self, openai_encoder, routes, routes_2, index_cls): if index_cls is PineconeIndex: # TEST MERGE FORCE LOCAL pinecone_index = init_index(index_cls) + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes_2, index=pinecone_index, + auto_sync="local" + ) + time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( encoder=openai_encoder, routes=routes, index=pinecone_index, - auto_sync="merge-force-local" + auto_sync="merge-force-remote" ) - time.sleep(PINECONE_SLEEP) # allow for index to be populated - assert route_layer.index.get_utterances() == [ - Utterance( - route="Route 1", utterance="Hello", - metadata={"type": "default"} - ), + # confirm local and remote are synced + assert route_layer.is_synced() + #Â now confirm utterances are correct + local_utterances = route_layer.index.get_utterances() + #Â we sort to ensure order is the same + local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) + assert local_utterances == [ + Utterance(route='Route 1', utterance='Hello'), Utterance( - route="Route 1", utterance="Hi", - metadata={"type": "default"} + route='Route 1', utterance='Hi', + metadata={'type': 'default'} ), - Utterance(route="Route 2", utterance="Bye"), - Utterance(route="Route 2", utterance="Au revoir"), - Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route='Route 2', utterance='Au revoir'), + Utterance(route='Route 2', utterance='Bye'), + Utterance(route='Route 2', utterance='Goodbye'), + Utterance(route='Route 2', utterance='Hi') ], "The routes in the index should match the local routes" # clear index @@ -364,17 +375,27 @@ class TestRouteLayer: @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_merge(self, openai_encoder, routes_4, index_cls): + def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls): if index_cls is PineconeIndex: # TEST MERGE pinecone_index = init_index(index_cls) route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_4, index=pinecone_index, + encoder=openai_encoder, routes=routes, index=pinecone_index, + auto_sync="local" + ) + time.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes_2, index=pinecone_index, auto_sync="merge" ) - time.sleep(PINECONE_SLEEP) # allow for index to be populated - assert route_layer.index.get_utterances() == [ + # confirm local and remote are synced + assert route_layer.is_synced() + #Â now confirm utterances are correct + local_utterances = route_layer.index.get_utterances() + #Â we sort to ensure order is the same + local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) + assert local_utterances == [ ("Route 1", "Hello", None, {"type": "default"}), ("Route 1", "Hi", None, {"type": "default"}), ("Route 1", "Goodbye", None, {"type": "default"}), @@ -382,6 +403,7 @@ class TestRouteLayer: ("Route 2", "Asparagus", None, {}), ("Route 2", "Au revoir", None, {}), ("Route 2", "Goodbye", None, {}), + ("Route 3", "Boo", None, {}), ], "The routes in the index should match the local routes" # clear index -- GitLab