From 4d6f6a1e2bcf0b135e2591db47dd8139e107cdfa Mon Sep 17 00:00:00 2001
From: Vits <vittorio.mayellaro.dev@gmail.com>
Date: Tue, 30 Jul 2024 19:26:48 +0200
Subject: [PATCH] Implemented/Modified remove and sync methods for pinecone

---
 semantic_router/index/base.py     |  13 ++--
 semantic_router/index/local.py    |  14 ++--
 semantic_router/index/pinecone.py | 111 +++++++++++++-----------------
 semantic_router/index/qdrant.py   |  15 ++--
 semantic_router/layer.py          |  63 +++++++++++++++--
 5 files changed, 122 insertions(+), 94 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 76388d1d..9eb99532 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -32,15 +32,13 @@ class BaseIndex(BaseModel):
         This method should be implemented by subclasses.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
-
-    def _add_and_sync(
-        self,
-        embeddings: List[List[float]],
-        routes: List[str],
-        utterances: List[Any],
-    ):
+
+    def _remove_and_sync(self, routes_to_delete: dict):
         """
-        Add embeddings to the index and manage index syncing if necessary.
+        Remove embeddings from the index as part of a route syncing process.
         This method should be implemented by subclasses.
+
+        :param routes_to_delete: Mapping of route name to the list of
+            utterances to delete for that route.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
@@ -91,7 +86,7 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-    def _sync_index(self, local_routes: dict):
+    def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int):
         """
         Synchronize the local index with the remote index based on the specified mode.
         Modes:
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 5426ec76..dbc41f1a 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -42,15 +42,36 @@ class LocalIndex(BaseIndex):
             self.routes = np.concatenate([self.routes, routes_arr])
             self.utterances = np.concatenate([self.utterances, utterances_arr])
 
-    def _add_and_sync(
-        self,
-        embeddings: List[List[float]],
-        routes: List[str],
-        utterances: List[str],
-    ):
+    def _remove_and_sync(self, routes_to_delete: dict):
+        """
+        Sync-removal hook; not implemented for LocalIndex.
+
+        Nothing is removed; a warning is logged when a sync mode is set.
+
+        :param routes_to_delete: Mapping of route name to utterances (ignored).
+        """
+        if self.sync is not None:
+            logger.warning("Sync remove is not implemented for LocalIndex.")
+
+    def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int):
+        """
+        Sync hook; LocalIndex performs no real synchronization.
+
+        Returns the 3-tuple that RouteLayer unpacks, scheduling every local
+        utterance for addition and nothing for deletion.
+
+        :param local_route_names: Route name of each local utterance.
+        :param local_utterances: Local utterances, parallel to the names.
+        :param dimensions: Embedding dimensionality (unused here).
+        :return: ``(routes_to_add, routes_to_delete, layer_routes)``.
+        """
         if self.sync is not None:
-            logger.warning("Sync add is not implemented for LocalIndex.")
-        self.add(embeddings, routes, utterances)
+            logger.error("Sync is not implemented for LocalIndex.")
+        routes_to_add = list(zip(local_route_names, local_utterances))
+        layer_routes: dict = {}
+        for route, utterance in routes_to_add:
+            layer_routes.setdefault(route, []).append(utterance)
+        return routes_to_add, [], layer_routes
 
     def get_routes(self) -> List[Tuple]:
         """
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index a578eb01..e88fa148 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel, Field
 
 from semantic_router.index.base import BaseIndex
 from semantic_router.utils.logger import logger
+from semantic_router.route import Route
 
 
 def clean_route_name(route_name: str) -> str:
@@ -201,33 +202,49 @@ class PineconeIndex(BaseIndex):
             logger.warning("Index could not be initialized.")
         self.host = index_stats["host"] if index_stats else None
 
-    def _sync_index(self, local_routes: dict):
+    def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int):
+        if self.index is None:
+            self.dimensions = self.dimensions or dimensions
+            self.index = self._init_index(force_create=True)
+
         remote_routes = self.get_routes()
+
         remote_dict: dict = {route: set() for route, _ in remote_routes}
         for route, utterance in remote_routes:
             remote_dict[route].add(utterance)
 
-        local_dict: dict = {route: set() for route in local_routes["routes"]}
-        for route, utterance in zip(local_routes["routes"], local_routes["utterances"]):
+        local_dict: dict = {route: set() for route in local_route_names}
+        for route, utterance in zip(local_route_names, local_utterances):
             local_dict[route].add(utterance)
 
+        logger.info(f"Local routes: {local_dict}")
+        logger.info(f"Remote routes: {remote_dict}")
+
         all_routes = set(remote_dict.keys()).union(local_dict.keys())
 
         routes_to_add = []
         routes_to_delete = []
+        layer_routes = {}
 
         for route in all_routes:
             local_utterances = local_dict.get(route, set())
             remote_utterances = remote_dict.get(route, set())
 
+            if not local_utterances and not remote_utterances:
+                continue
+
             if self.sync == "error":
                 if local_utterances != remote_utterances:
                     raise ValueError(
                         f"Synchronization error: Differences found in route '{route}'"
                     )
                 utterances_to_include: set = set()
+                if local_utterances:
+                    layer_routes[route] = list(local_utterances)
             elif self.sync == "remote":
                 utterances_to_include = set()
+                if remote_utterances:
+                    layer_routes[route] = list(remote_utterances)
             elif self.sync == "local":
                 utterances_to_include = local_utterances - remote_utterances
                 routes_to_delete.extend(
@@ -237,11 +254,17 @@ class PineconeIndex(BaseIndex):
                         if utterance not in local_utterances
                     ]
                 )
+                if local_utterances:
+                    layer_routes[route] = list(local_utterances)
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
                     utterances_to_include = local_utterances
+                    if local_utterances:
+                        layer_routes[route] = list(local_utterances)
                 else:
                     utterances_to_include = set()
+                    if remote_utterances:
+                        layer_routes[route] = list(remote_utterances)
             elif self.sync == "merge-force-local":
                 if route in local_dict:
                     utterances_to_include = local_utterances - remote_utterances
@@ -252,27 +275,27 @@ class PineconeIndex(BaseIndex):
                             if utterance not in local_utterances
                         ]
                     )
+                    if local_utterances:
+                        layer_routes[route] = list(local_utterances)
                 else:
                     utterances_to_include = set()
+                    if remote_utterances:
+                        layer_routes[route] = list(remote_utterances)
             elif self.sync == "merge":
                 utterances_to_include = local_utterances - remote_utterances
+                if local_utterances or remote_utterances:
+                    layer_routes[route] = list(
+                        remote_utterances.union(local_utterances)
+                    )
             else:
                 raise ValueError("Invalid sync mode specified")
 
             for utterance in utterances_to_include:
-                indices = [
-                    i
-                    for i, x in enumerate(local_routes["utterances"])
-                    if x == utterance and local_routes["routes"][i] == route
-                ]
-                routes_to_add.extend(
-                    [
-                        (local_routes["embeddings"][idx], route, utterance)
-                        for idx in indices
-                    ]
-                )
+                routes_to_add.append((route, utterance))
+
+        logger.info(f"Layer routes: {layer_routes}")
 
-        return routes_to_add, routes_to_delete
+        return routes_to_add, routes_to_delete, layer_routes
 
     def _batch_upsert(self, batch: List[Dict]):
         """Helper method for upserting a single batch of records."""
@@ -301,54 +324,25 @@ class PineconeIndex(BaseIndex):
         for i in range(0, len(vectors_to_upsert), batch_size):
             batch = vectors_to_upsert[i : i + batch_size]
             self._batch_upsert(batch)
-
-    def _add_and_sync(
-        self,
-        embeddings: List[List[float]],
-        routes: List[str],
-        utterances: List[str],
-        batch_size: int = 100,
-    ):
-        """Add vectors to Pinecone in batches."""
-        if self.index is None:
-            self.dimensions = self.dimensions or len(embeddings[0])
-            self.index = self._init_index(force_create=True)
-
-        local_routes = {
-            "routes": routes,
-            "utterances": utterances,
-            "embeddings": embeddings,
-        }
-        if self.sync is not None:
-            data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
-            routes_to_delete: dict = {}
-            for route, utterance in data_to_delete:
-                routes_to_delete.setdefault(route, []).append(utterance)
-
-            for route, utterances in routes_to_delete.items():
-                remote_routes = self._get_routes_with_ids(route_name=route)
-                ids_to_delete = [
-                    r["id"]
-                    for r in remote_routes
-                    if (r["route"], r["utterance"])
-                    in zip([route] * len(utterances), utterances)
-                ]
-                if ids_to_delete and self.index:
-                    self.index.delete(ids=ids_to_delete)
-        else:
-            data_to_upsert = [
-                (vector, route, utterance)
-                for vector, route, utterance in zip(embeddings, routes, utterances)
-            ]
-
-        vectors_to_upsert = [
-            PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
-            for vector, route, utterance in data_to_upsert
-        ]
-
-        for i in range(0, len(vectors_to_upsert), batch_size):
-            batch = vectors_to_upsert[i : i + batch_size]
-            self._batch_upsert(batch)
+
+    def _remove_and_sync(self, routes_to_delete: dict):
+        """
+        Delete the given utterances from the Pinecone index.
+
+        :param routes_to_delete: Mapping of route name to the utterances that
+            must be removed from that route.
+        """
+        for route, utterances in routes_to_delete.items():
+            # Set membership replaces rebuilding a zip() for every remote record.
+            utterances_to_delete = set(utterances)
+            remote_routes = self._get_routes_with_ids(route_name=route)
+            ids_to_delete = [
+                r["id"]
+                for r in remote_routes
+                if r["route"] == route and r["utterance"] in utterances_to_delete
+            ]
+            if ids_to_delete and self.index:
+                self.index.delete(ids=ids_to_delete)
 
     def _get_route_ids(self, route_name: str):
         clean_route = clean_route_name(route_name)
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 0fff2314..a77e6f88 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -160,16 +160,36 @@ class QdrantIndex(BaseIndex):
                 **self.config,
             )
 
-    def _add_and_sync(
-        self,
-        embeddings: List[List[float]],
-        routes: List[str],
-        utterances: List[str],
-        batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
-    ):
+    def _remove_and_sync(self, routes_to_delete: dict):
+        """
+        Sync-removal hook; not implemented for QdrantIndex.
+
+        Nothing is removed; an error is logged when a sync mode is set.
+
+        :param routes_to_delete: Mapping of route name to utterances (ignored).
+        """
+        if self.sync is not None:
+            logger.error("Sync remove is not implemented for QdrantIndex.")
+
+    def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int):
+        """
+        Sync hook; QdrantIndex performs no real synchronization.
+
+        Returns the 3-tuple that RouteLayer unpacks, scheduling every local
+        utterance for addition and nothing for deletion.
+
+        :param local_route_names: Route name of each local utterance.
+        :param local_utterances: Local utterances, parallel to the names.
+        :param dimensions: Embedding dimensionality (unused here).
+        :return: ``(routes_to_add, routes_to_delete, layer_routes)``.
+        """
         if self.sync is not None:
-            logger.warning("Sync add is not implemented for QdrantIndex")
-        self.add(embeddings, routes, utterances, batch_size)
+            logger.error("Sync is not implemented for QdrantIndex.")
+        routes_to_add = list(zip(local_route_names, local_utterances))
+        layer_routes: dict = {}
+        for route, utterance in routes_to_add:
+            layer_routes.setdefault(route, []).append(utterance)
+        return routes_to_add, [], layer_routes
 
     def add(
         self,
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 81f690b5..8888460d 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -217,8 +217,10 @@ class RouteLayer:
             if route.score_threshold is None:
                 route.score_threshold = self.score_threshold
-        # if routes list has been passed, we initialize index now
-        if len(self.routes) > 0:
-            # initialize index now
+        if self.index.sync:
+            # a sync mode is set: reconcile local routes (possibly none) with the index
+            self._add_and_sync_routes(routes=self.routes)
+        elif len(self.routes) > 0:
+            # no sync mode: initialize the index only if a routes list was passed
             self._add_routes(routes=self.routes)
 
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
@@ -385,6 +390,14 @@ class RouteLayer:
         )
         return self._pass_threshold(scores, threshold)
 
+    def _set_layer_routes(self, new_routes: List[Route]):
+        """
+        Replace the layer's current routes with ``new_routes`` (plain assignment).
+
+        :param new_routes: List of Route objects that becomes ``self.routes``.
+        """
+        self.routes = new_routes
+
     def __str__(self):
         return (
             f"RouteLayer(encoder={self.encoder}, "
@@ -464,19 +477,82 @@ class RouteLayer:
 
     def _add_routes(self, routes: List[Route]):
+        """
+        Encode every utterance of ``routes`` and add the results to the index.
+
+        :param routes: Routes whose utterances are embedded and indexed.
+        """
         # create embeddings for all routes
-        all_utterances = [
-            utterance for route in routes for utterance in route.utterances
-        ]
+        route_names, all_utterances = self._extract_routes_details(routes)
         embedded_utterances = self.encoder(all_utterances)
-        # create route array
-        route_names = [route.name for route in routes for _ in route.utterances]
         # add everything to the index
-        self.index._add_and_sync(
+        self.index.add(
             embeddings=embedded_utterances,
             routes=route_names,
             utterances=all_utterances,
         )
 
+    def _add_and_sync_routes(self, routes: List[Route]):
+        """
+        Reconcile ``routes`` with the index according to its sync mode.
+
+        The index's ``_sync_index`` reports which (route, utterance) pairs to add
+        and delete and which routes the layer should hold; this method applies
+        those changes and then replaces ``self.routes``.
+
+        :param routes: Locally defined routes to synchronize with the index.
+        """
+        local_route_names, local_utterances = self._extract_routes_details(routes)
+        routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index(
+            local_route_names=local_route_names,
+            local_utterances=local_utterances,
+            # encode a throwaway string to discover the embedding dimension
+            dimensions=len(self.encoder(["dummy"])[0]),
+        )
+
+        logger.info(f"ROUTES TO ADD: {(routes_to_add)}")
+        logger.info(f"ROUTES TO DELETE: {(routes_to_delete)}")
+
+        # Reuse the caller's Route objects where possible so that settings such
+        # as score_threshold are not lost when the layer routes are replaced.
+        local_routes = {route.name: route for route in routes}
+        layer_routes = []
+        for name, utterances in layer_routes_dict.items():
+            route = local_routes.get(name)
+            if route is None:
+                route = Route(name=name, utterances=list(utterances))
+            else:
+                route.utterances = list(utterances)
+            layer_routes.append(route)
+
+        data_to_delete: dict = {}
+        for route_name, utterance in routes_to_delete:
+            data_to_delete.setdefault(route_name, []).append(utterance)
+        if data_to_delete:
+            self.index._remove_and_sync(data_to_delete)
+
+        # Skip encoding and upserting when there is nothing new to add.
+        if routes_to_add:
+            route_names_to_add = [name for name, _ in routes_to_add]
+            utterances_to_add = [utterance for _, utterance in routes_to_add]
+            self.index.add(
+                embeddings=self.encoder(utterances_to_add),
+                routes=route_names_to_add,
+                utterances=utterances_to_add,
+            )
+
+        self._set_layer_routes(layer_routes)
+
+    def _extract_routes_details(self, routes: List[Route]) -> Tuple[List[str], List[str]]:
+        """
+        Flatten routes into parallel lists of route names and utterances.
+
+        :param routes: Routes to flatten.
+        :return: ``(route_names, utterances)`` with one entry per utterance.
+        """
+        route_names = [route.name for route in routes for _ in route.utterances]
+        utterances = [utterance for route in routes for utterance in route.utterances]
+        return route_names, utterances
+
     def _encode(self, text: str) -> Any:
         """Given some text, encode it."""
         # create query vector
-- 
GitLab