From 09c3a74c1d3ca665707794e9ae41610d90656805 Mon Sep 17 00:00:00 2001
From: Vits <vittorio.mayellaro.dev@gmail.com>
Date: Tue, 16 Jul 2024 22:47:56 +0200
Subject: [PATCH] Fixed

---
 semantic_router/index/pinecone.py | 56 ++++++++++++++++++++++++++++---
 semantic_router/layer.py          | 24 ++++++++++++-
 2 files changed, 75 insertions(+), 5 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index a578eb01..64312b3f 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel, Field
 
 from semantic_router.index.base import BaseIndex
 from semantic_router.utils.logger import logger
+from semantic_router.route import Route
 
 
 def clean_route_name(route_name: str) -> str:
@@ -203,6 +204,28 @@ class PineconeIndex(BaseIndex):
 
     def _sync_index(self, local_routes: dict):
         remote_routes = self.get_routes()
+        if not local_routes["routes"]:
+            if self.sync != "remote":
+                raise ValueError(
+                    "Local routes must be provided to sync the index if the sync setting is not 'remote'."
+                )
+            else:
+                if not remote_routes:
+                    raise ValueError("No routes found in the index.")
+        if (
+            (self.sync in ["remote", "merge-force-remote"] and not remote_routes)
+            or (
+                self.sync in ["error", "local", "merge-force-local"]
+                and not local_routes["routes"]
+            )
+            or (
+                self.sync == "merge"
+                and not remote_routes
+                and not local_routes["routes"]
+            )
+        ):
+            raise ValueError("No routes found in the index.")
+
         remote_dict: dict = {route: set() for route, _ in remote_routes}
         for route, utterance in remote_routes:
             remote_dict[route].add(utterance)
@@ -215,6 +238,7 @@ class PineconeIndex(BaseIndex):
 
         routes_to_add = []
         routes_to_delete = []
+        layer_routes = {}
 
         for route in all_routes:
             local_utterances = local_dict.get(route, set())
@@ -226,8 +250,11 @@ class PineconeIndex(BaseIndex):
                         f"Synchronization error: Differences found in route '{route}'"
                     )
                 utterances_to_include: set = set()
+                layer_routes[route] = list(local_utterances)
             elif self.sync == "remote":
                 utterances_to_include = set()
+                if remote_utterances:
+                    layer_routes[route] = list(remote_utterances)
             elif self.sync == "local":
                 utterances_to_include = local_utterances - remote_utterances
                 routes_to_delete.extend(
@@ -237,11 +264,16 @@ class PineconeIndex(BaseIndex):
                         if utterance not in local_utterances
                     ]
                 )
+                layer_routes[route] = list(local_utterances)
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
                     utterances_to_include = local_utterances
+                    if local_utterances:
+                        layer_routes[route] = list(local_utterances)
                 else:
                     utterances_to_include = set()
+                    if remote_utterances:
+                        layer_routes[route] = list(remote_utterances)
             elif self.sync == "merge-force-local":
                 if route in local_dict:
                     utterances_to_include = local_utterances - remote_utterances
@@ -252,10 +284,15 @@ class PineconeIndex(BaseIndex):
                             if utterance not in local_utterances
                         ]
                     )
+                    if local_utterances:
+                        layer_routes[route] = local_utterances
                 else:
                     utterances_to_include = set()
+                    if remote_utterances:
+                        layer_routes[route] = list(remote_utterances)
             elif self.sync == "merge":
                 utterances_to_include = local_utterances - remote_utterances
+                layer_routes[route] = list(remote_utterances.union(local_utterances))
             else:
                 raise ValueError("Invalid sync mode specified")
 
@@ -272,7 +309,7 @@ class PineconeIndex(BaseIndex):
                     ]
                 )
 
-        return routes_to_add, routes_to_delete
+        return routes_to_add, routes_to_delete, layer_routes
 
     def _batch_upsert(self, batch: List[Dict]):
         """Helper method for upserting a single batch of records."""
@@ -308,8 +345,8 @@ class PineconeIndex(BaseIndex):
         routes: List[str],
         utterances: List[str],
         batch_size: int = 100,
-    ):
-        """Add vectors to Pinecone in batches."""
+    ) -> List[Route]:
+        """Add vectors to Pinecone in batches and return the overall updated list of Route objects."""
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
@@ -320,7 +357,15 @@ class PineconeIndex(BaseIndex):
             "embeddings": embeddings,
         }
         if self.sync is not None:
-            data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
+            data_to_upsert, data_to_delete, layer_routes_dict = self._sync_index(
+                local_routes=local_routes
+            )
+
+            layer_routes = [
+                Route(name=route, utterances=layer_routes_dict[route])
+                for route in layer_routes_dict.keys()
+            ]
+
             routes_to_delete: dict = {}
             for route, utterance in data_to_delete:
                 routes_to_delete.setdefault(route, []).append(utterance)
@@ -335,6 +380,7 @@ class PineconeIndex(BaseIndex):
                 ]
                 if ids_to_delete and self.index:
                     self.index.delete(ids=ids_to_delete)
+
         else:
             data_to_upsert = [
                 (vector, route, utterance)
@@ -350,6 +396,8 @@ class PineconeIndex(BaseIndex):
             batch = vectors_to_upsert[i : i + batch_size]
             self._batch_upsert(batch)
 
+        return layer_routes
+
     def _get_route_ids(self, route_name: str):
         clean_route = clean_route_name(route_name)
         ids, _ = self._get_all(prefix=f"{clean_route}#")
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 5c2d7228..20a325b7 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -220,6 +220,19 @@ class RouteLayer:
         if len(self.routes) > 0:
             # initialize index now
             self._add_routes(routes=self.routes)
+        elif self.index.sync in ["merge", "remote", "merge-force-remote"]:
+            dummy_embedding = self.encoder(["dummy"])
+
+            layer_routes = self.index._add_and_sync(
+                embeddings=dummy_embedding,
+                routes=[],
+                utterances=[],
+            )
+            self._set_layer_routes(layer_routes)
+        else:
+            raise ValueError(
+                "No routes provided for RouteLayer. Please provide routes or set sync to 'remote' if you want to use only remote routes."
+            )
 
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
         matching_routes = [route for route in self.routes if route.name == top_class]
@@ -380,6 +393,14 @@ class RouteLayer:
         )
         return self._pass_threshold(scores, threshold)
 
+    def _set_layer_routes(self, new_routes: List[Route]):
+        """
+        Set and override the current routes with a new list of routes.
+
+        :param new_routes: List of Route objects to set as the current routes.
+        """
+        self.routes = new_routes
+
     def __str__(self):
         return (
             f"RouteLayer(encoder={self.encoder}, "
@@ -466,11 +487,12 @@ class RouteLayer:
         # create route array
         route_names = [route.name for route in routes for _ in route.utterances]
         # add everything to the index
-        self.index._add_and_sync(
+        layer_routes = self.index._add_and_sync(
             embeddings=embedded_utterances,
             routes=route_names,
             utterances=all_utterances,
         )
+        self._set_layer_routes(layer_routes)
 
     def _encode(self, text: str) -> Any:
         """Given some text, encode it."""
-- 
GitLab