From 0063476e922389e65b9e916cd53169399959c330 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Wed, 13 Nov 2024 14:57:03 +0100 Subject: [PATCH] fix: merge-force-remote --- semantic_router/layer.py | 1 + semantic_router/schema.py | 18 ++++++++++++------ tests/unit/test_sync.py | 23 ++++++++++++++++++----- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 78de158e..b1bc9142 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -497,6 +497,7 @@ class RouteLayer: data_to_delete.setdefault( utt_obj.route, [] ).append(utt_obj.utterance) + # TODO: switch to remove without sync?? self.index._remove_and_sync(data_to_delete) if strategy["remote"]["upsert"]: utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index fbf1da50..1e2dd488 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -274,13 +274,19 @@ class UtteranceDiff(BaseModel): "delete": local_only } } - elif sync_mode == "merge-force-remote": - # get set of route names that exist in both local and remote - routes_in_both = set([utt.route for utt in local_and_remote]) + elif sync_mode == "merge-force-remote": # merge-to-local merge-join-local + # get set of route names that exist in local (we keep these if + # they are in remote) + local_route_names = set([utt.route for utt in local_only]) + # if we see route: utterance exists in local, we do not pull it in + # from remote + local_route_utt_strs = set([utt.to_str() for utt in local_only]) # get remote utterances that belong to routes_in_both - remote_to_keep = [utt for utt in remote_only if utt.route in routes_in_both] + remote_to_keep = [utt for utt in remote_only if ( + utt.route in local_route_names and utt.to_str() not in local_route_utt_strs + )] # get remote utterances that do NOT belong to routes_in_both - remote_to_delete = [utt for 
utt in remote_only if utt.route not in routes_in_both] + remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names] return { "remote": { "upsert": local_only, @@ -291,7 +297,7 @@ class UtteranceDiff(BaseModel): "delete": [] } } - elif sync_mode == "merge-force-local": + elif sync_mode == "merge-force-local": # merge-to-remote merge-join-remote # get set of route names that exist in both local and remote routes_in_both = set([utt.route for utt in local_and_remote]) # get local utterances that belong to routes_in_both diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index b7acebe9..ad397dc7 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -25,7 +25,7 @@ def mock_encoder_call(utterances): "Au revoir": [1.3, 1.4, 1.5], "Asparagus": [-2.0, 1.0, 0.0], } - return [mock_responses.get(u, [0.0, 0.0, 0.0]) for u in utterances] + return [mock_responses.get(u, [0.3, 0.1, 0.2]) for u in utterances] TEST_ID = ( @@ -119,6 +119,7 @@ def routes(): return [ Route(name="Route 1", utterances=["Hello", "Hi"], metadata={"type": "default"}), Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]), + Route(name="Route 3", utterances=["Boo"]), ] @@ -243,6 +244,7 @@ class TestRouteLayer: assert "- Route 2: Hi | None | {}" in diff assert "+ Route 2: Bye | None | {}" in diff assert "+ Route 2: Goodbye | None | {}" in diff + assert "+ Route 3: Boo | None | {}" in diff @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" @@ -300,6 +302,7 @@ class TestRouteLayer: pinecone_index = init_index(index_cls) route_layer = RouteLayer( encoder=openai_encoder, routes=routes, index=pinecone_index, + auto_sync="local" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( @@ -307,12 +310,22 @@ class TestRouteLayer: auto_sync="merge-force-remote" ) time.sleep(PINECONE_SLEEP) # allow for index to be populated - assert route_layer.index.get_utterances() == [ + 
# confirm local and remote are synced + assert route_layer.is_synced() + # now confirm utterances are correct + local_utterances = route_layer.index.get_utterances() + # we sort to ensure order is the same + local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) + assert local_utterances == [ + Utterance(route='Route 1', utterance='Hello'), Utterance( - route="Route 1", utterance="Hello", - metadata={"type": "default"} + route='Route 1', utterance='Hi', + metadata={'type': 'default'} ), - Utterance(route="Route 2", utterance="Hi"), + Utterance(route='Route 2', utterance='Au revoir'), + Utterance(route='Route 2', utterance='Bye'), + Utterance(route='Route 2', utterance='Goodbye'), + Utterance(route='Route 2', utterance='Hi') ], "The routes in the index should match the local routes" # clear index -- GitLab