From 0063476e922389e65b9e916cd53169399959c330 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Wed, 13 Nov 2024 14:57:03 +0100
Subject: [PATCH] fix: merge-force-remote

---
 semantic_router/layer.py  |  1 +
 semantic_router/schema.py | 18 ++++++++++++------
 tests/unit/test_sync.py   | 23 ++++++++++++++++++-----
 3 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 78de158e..b1bc9142 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -497,6 +497,7 @@ class RouteLayer:
                 data_to_delete.setdefault(
                     utt_obj.route, []
                 ).append(utt_obj.utterance)
+            # TODO: switch to remove without sync??
             self.index._remove_and_sync(data_to_delete)
         if strategy["remote"]["upsert"]:
             utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index fbf1da50..1e2dd488 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -274,13 +274,19 @@ class UtteranceDiff(BaseModel):
                     "delete": local_only
                 }
             }
-        elif sync_mode == "merge-force-remote":
-            # get set of route names that exist in both local and remote
-            routes_in_both = set([utt.route for utt in local_and_remote])
+        elif sync_mode == "merge-force-remote":  # merge-to-local merge-join-local
+            # get set of route names that exist in local (we keep these if
+            # they are in remote)
+            local_route_names = set([utt.route for utt in local_only])
+            # if we see route: utterance exists in local, we do not pull it in
+            # from remote
+            local_route_utt_strs = set([utt.to_str() for utt in local_only])
             # get remote utterances that belong to routes_in_both
-            remote_to_keep = [utt for utt in remote_only if utt.route in routes_in_both]
+            remote_to_keep = [utt for utt in remote_only if (
+                utt.route in local_route_names and utt.to_str() not in local_route_utt_strs
+            )]
             # get remote utterances that do NOT belong to routes_in_both
-            remote_to_delete = [utt for utt in remote_only if utt.route not in routes_in_both]
+            remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names]
             return {
                 "remote": {
                     "upsert": local_only,
@@ -291,7 +297,7 @@ class UtteranceDiff(BaseModel):
                     "delete": []
                 }
             }
-        elif sync_mode == "merge-force-local":
+        elif sync_mode == "merge-force-local":  # merge-to-remote merge-join-remote
             # get set of route names that exist in both local and remote
             routes_in_both = set([utt.route for utt in local_and_remote])
             # get local utterances that belong to routes_in_both
diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py
index b7acebe9..ad397dc7 100644
--- a/tests/unit/test_sync.py
+++ b/tests/unit/test_sync.py
@@ -25,7 +25,7 @@ def mock_encoder_call(utterances):
         "Au revoir": [1.3, 1.4, 1.5],
         "Asparagus": [-2.0, 1.0, 0.0],
     }
-    return [mock_responses.get(u, [0.0, 0.0, 0.0]) for u in utterances]
+    return [mock_responses.get(u, [0.3, 0.1, 0.2]) for u in utterances]
 
 
 TEST_ID = (
@@ -119,6 +119,7 @@ def routes():
     return [
         Route(name="Route 1", utterances=["Hello", "Hi"], metadata={"type": "default"}),
         Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]),
+        Route(name="Route 3", utterances=["Boo"]),
     ]
 
 
@@ -243,6 +244,7 @@ class TestRouteLayer:
         assert "- Route 2: Hi | None | {}" in diff
         assert "+ Route 2: Bye | None | {}" in diff
         assert "+ Route 2: Goodbye | None | {}" in diff
+        assert "+ Route 3: Boo | None | {}" in diff
 
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
@@ -300,6 +302,7 @@ class TestRouteLayer:
             pinecone_index = init_index(index_cls)
             route_layer = RouteLayer(
                 encoder=openai_encoder, routes=routes, index=pinecone_index,
+                auto_sync="local"
             )
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
             route_layer = RouteLayer(
@@ -307,12 +310,22 @@ class TestRouteLayer:
                 auto_sync="merge-force-remote"
             )
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
-            assert route_layer.index.get_utterances() == [
+            # confirm local and remote are synced
+            assert route_layer.is_synced()
+            # now confirm utterances are correct
+            local_utterances = route_layer.index.get_utterances()
+            # we sort to ensure order is the same
+            local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
+            assert local_utterances == [
+                Utterance(route='Route 1', utterance='Hello'),
                 Utterance(
-                    route="Route 1", utterance="Hello",
-                    metadata={"type": "default"}
+                    route='Route 1', utterance='Hi',
+                    metadata={'type': 'default'}
                 ),
-                Utterance(route="Route 2", utterance="Hi"),
+                Utterance(route='Route 2', utterance='Au revoir'),
+                Utterance(route='Route 2', utterance='Bye'),
+                Utterance(route='Route 2', utterance='Goodbye'),
+                Utterance(route='Route 2', utterance='Hi')
             ], "The routes in the index should match the local routes"
 
             # clear index
-- 
GitLab