From 248edf8bc8c411d09cdd442beb8dd9c4bf04dcc7 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Thu, 14 Nov 2024 12:37:11 +0100
Subject: [PATCH] fix: ongoing sync work

---
 semantic_router/index/pinecone.py |  4 +++
 semantic_router/layer.py          | 47 ++++++++++++++++++++++---
 semantic_router/schema.py         | 24 ++++++++-----
 tests/unit/test_sync.py           | 58 +++++++++++++++++++++----------
 4 files changed, 101 insertions(+), 32 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 2231a657..85efd809 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -484,6 +484,8 @@ class PineconeIndex(BaseIndex):
         batch_size: int = 100,
     ):
         """Add vectors to Pinecone in batches."""
+        temp = '\n'.join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)])
+        logger.warning("TEMP | add:\n"+temp)
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
@@ -506,6 +508,8 @@ class PineconeIndex(BaseIndex):
             self._batch_upsert(batch)
 
     def _remove_and_sync(self, routes_to_delete: dict):
+        temp = '\n'.join([f"{route}: {utterances}" for route, utterances in routes_to_delete.items()])
+        logger.warning("TEMP | _remove_and_sync:\n"+temp)
         for route, utterances in routes_to_delete.items():
             remote_routes = self._get_routes_with_ids(route_name=route)
             ids_to_delete = [
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index b1bc9142..64186e90 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -521,7 +521,7 @@ class RouteLayer:
         :param utterances: The utterances to add to the local RouteLayer.
         :type utterances: List[Utterance]
         """
-        new_routes = {}
+        new_routes = {route.name: route for route in self.routes}
         for utt_obj in utterances:
             if utt_obj.route not in new_routes.keys():
                 new_routes[utt_obj.route] = Route(
@@ -531,8 +531,13 @@ class RouteLayer:
                     metadata=utt_obj.metadata
                 )
             else:
-                new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
-        self.routes.extend(list(new_routes.values()))
+                if utt_obj.utterance not in new_routes[utt_obj.route].utterances:
+                    new_routes[utt_obj.route].utterances.append(utt_obj.utterance)
+                new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas
+                new_routes[utt_obj.route].metadata = utt_obj.metadata
+        temp = '\n'.join([f"{name}: {r.utterances}" for name, r in new_routes.items()])
+        logger.warning("TEMP | _local_upsert:\n"+temp)
+        self.routes = list(new_routes.values())
 
     def _local_delete(self, utterances: List[Utterance]):
         """Deletes routes from the local RouteLayer.
@@ -540,8 +545,40 @@ class RouteLayer:
         :param utterances: The utterances to delete from the local RouteLayer.
         :type utterances: List[Utterance]
         """
-        route_names = set([utt.route for utt in utterances])
-        self.routes = [route for route in self.routes if route.name not in route_names]
+        # create dictionary of route names to utterances
+        route_dict = {}
+        for utt in utterances:
+            route_dict.setdefault(utt.route, []).append(utt.utterance)
+        temp = '\n'.join([f"{r}: {u}" for r, u in route_dict.items()])
+        logger.warning("TEMP | _local_delete:\n"+temp)
+        # iterate over current routes and delete specific utterance if found
+        new_routes = []
+        for route in self.routes:
+            if route.name in route_dict.keys():
+                # drop utterances that are in route_dict deletion list
+                new_utterances = list(set(route.utterances) - set(route_dict[route.name]))
+                if len(new_utterances) == 0:
+                    # the route is now empty, so we skip it
+                    continue
+                else:
+                    new_routes.append(
+                        Route(
+                            name=route.name,
+                            utterances=new_utterances,
+                            # use existing function schemas and metadata
+                            function_schemas=route.function_schemas,
+                            metadata=route.metadata
+                        )
+                    )
+                logger.warning(f"TEMP | _local_delete OLD | {route.name}: {route.utterances}")
+                logger.warning(f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}")
+            else:
+                # the route is not in the route_dict, so we keep it as is
+                new_routes.append(route)
+        temp = '\n'.join([f"{r}: {u}" for r, u in route_dict.items()])
+        logger.warning("TEMP | _local_delete:\n"+temp)
+        
+        self.routes = new_routes
 
 
     def _retrieve_top_route(
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 1e2dd488..012ef616 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -255,7 +255,7 @@ class UtteranceDiff(BaseModel):
         elif sync_mode == "local":
             return {
                 "remote": {
-                    "upsert": local_only,
+                    "upsert": local_only,# + remote_updates,
                     "delete": remote_only
                 },
                 "local": {
@@ -281,11 +281,11 @@ class UtteranceDiff(BaseModel):
             # if we see route: utterance exists in local, we do not pull it in
             # from remote
             local_route_utt_strs = set([utt.to_str() for utt in local_only])
-            # get remote utterances that belong to routes_in_both
+            # get remote utterances whose route is in local_route_names and whose to_str() is not in local_route_utt_strs
             remote_to_keep = [utt for utt in remote_only if (
                 utt.route in local_route_names and utt.to_str() not in local_route_utt_strs
             )]
-            # get remote utterances that do NOT belong to routes_in_both
+            # get remote utterances whose route is NOT in local_route_names
             remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names]
             return {
                 "remote": {
@@ -298,12 +298,18 @@ class UtteranceDiff(BaseModel):
                 }
             }
         elif sync_mode == "merge-force-local":  # merge-to-remote merge-join-remote
-            # get set of route names that exist in both local and remote
-            routes_in_both = set([utt.route for utt in local_and_remote])
-            # get local utterances that belong to routes_in_both
-            local_to_keep = [utt for utt in local_only if utt.route in routes_in_both]
-            # get local utterances that do NOT belong to routes_in_both
-            local_to_delete = [utt for utt in local_only if utt.route not in routes_in_both]
+            # get set of route names that exist in remote (we keep these if
+            # they are in local)
+            remote_route_names = set([utt.route for utt in remote_only])
+            # if we see route: utterance exists in remote, we do not pull it in
+            # from local
+            remote_route_utt_strs = set([utt.to_str() for utt in remote_only])
+            # get local utterances whose route is in remote_route_names and whose to_str() is not in remote_route_utt_strs
+            local_to_keep = [utt for utt in local_only if (
+                utt.route in remote_route_names and utt.to_str() not in remote_route_utt_strs
+            )]
+            # get local utterances whose route is NOT in remote_route_names
+            local_to_delete = [utt for utt in local_only if utt.route not in remote_route_names]
             return {
                 "remote": {
                     "upsert": local_to_keep,
diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py
index ad397dc7..7a20d7a5 100644
--- a/tests/unit/test_sync.py
+++ b/tests/unit/test_sync.py
@@ -317,7 +317,10 @@ class TestRouteLayer:
             # we sort to ensure order is the same
             local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
             assert local_utterances == [
-                Utterance(route='Route 1', utterance='Hello'),
+                Utterance(
+                    route='Route 1', utterance='Hello',
+                    metadata={'type': 'default'}
+                ),
                 Utterance(
                     route='Route 1', utterance='Hi',
                     metadata={'type': 'default'}
@@ -334,28 +337,36 @@ class TestRouteLayer:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_auto_sync_merge_force_local(self, openai_encoder, routes, index_cls):
+    def test_auto_sync_merge_force_local(self, openai_encoder, routes, routes_2, index_cls):
         if index_cls is PineconeIndex:
             # TEST MERGE FORCE LOCAL
             pinecone_index = init_index(index_cls)
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes_2, index=pinecone_index,
+                auto_sync="local"
+            )
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
             route_layer = RouteLayer(
                 encoder=openai_encoder, routes=routes, index=pinecone_index,
-                auto_sync="merge-force-local"
+                auto_sync="merge-force-remote"
             )
-
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
-            assert route_layer.index.get_utterances() == [
-                Utterance(
-                    route="Route 1", utterance="Hello",
-                    metadata={"type": "default"}
-                ),
+            # confirm local and remote are synced
+            assert route_layer.is_synced()
+            # now confirm utterances are correct
+            local_utterances = route_layer.index.get_utterances()
+            # we sort to ensure order is the same
+            local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
+            assert local_utterances == [
+                Utterance(route='Route 1', utterance='Hello'),
                 Utterance(
-                    route="Route 1", utterance="Hi",
-                    metadata={"type": "default"}
+                    route='Route 1', utterance='Hi',
+                    metadata={'type': 'default'}
                 ),
-                Utterance(route="Route 2", utterance="Bye"),
-                Utterance(route="Route 2", utterance="Au revoir"),
-                Utterance(route="Route 2", utterance="Goodbye"),
+                Utterance(route='Route 2', utterance='Au revoir'),
+                Utterance(route='Route 2', utterance='Bye'),
+                Utterance(route='Route 2', utterance='Goodbye'),
+                Utterance(route='Route 2', utterance='Hi')
             ], "The routes in the index should match the local routes"
 
             # clear index
@@ -364,17 +375,27 @@ class TestRouteLayer:
     @pytest.mark.skipif(
         os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
     )
-    def test_auto_sync_merge(self, openai_encoder, routes_4, index_cls):
+    def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls):
         if index_cls is PineconeIndex:
             # TEST MERGE
             pinecone_index = init_index(index_cls)
             route_layer = RouteLayer(
-                encoder=openai_encoder, routes=routes_4, index=pinecone_index,
+                encoder=openai_encoder, routes=routes, index=pinecone_index,
+                auto_sync="local"
+            )
+            time.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            route_layer = RouteLayer(
+                encoder=openai_encoder, routes=routes_2, index=pinecone_index,
                 auto_sync="merge"
             )
-
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
-            assert route_layer.index.get_utterances() == [
+            # confirm local and remote are synced
+            assert route_layer.is_synced()
+            # now confirm utterances are correct
+            local_utterances = route_layer.index.get_utterances()
+            # we sort to ensure order is the same
+            local_utterances.sort(key=lambda x: x.to_str(include_metadata=True))
+            assert local_utterances == [
                 ("Route 1", "Hello", None, {"type": "default"}),
                 ("Route 1", "Hi", None, {"type": "default"}),
                 ("Route 1", "Goodbye", None, {"type": "default"}),
@@ -382,6 +403,7 @@ class TestRouteLayer:
                 ("Route 2", "Asparagus", None, {}),
                 ("Route 2", "Au revoir", None, {}),
                 ("Route 2", "Goodbye", None, {}),
+                ("Route 3", "Boo", None, {}),
             ], "The routes in the index should match the local routes"
 
             # clear index
-- 
GitLab