From 54df32bb81096a489314e17c9294e05d2c523fc4 Mon Sep 17 00:00:00 2001
From: Vits <vittorio.mayellaro.dev@gmail.com>
Date: Thu, 11 Jul 2024 23:59:41 +0200
Subject: [PATCH] Added _add_and_sync to replace add for index syncing when
 adding routes at startup

---
 semantic_router/index/base.py     | 15 ++++++++--
 semantic_router/index/local.py    | 12 +++++++-
 semantic_router/index/pinecone.py | 47 +++++++++++++++++++++++--------
 semantic_router/index/qdrant.py   | 14 +++++++--
 semantic_router/layer.py          |  3 +-
 5 files changed, 71 insertions(+), 20 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index e53ca44f..9e226da7 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -18,14 +18,13 @@ class BaseIndex(BaseModel):
     utterances: Optional[np.ndarray] = None
     dimensions: Union[int, None] = None
     type: str = "base"
-    sync: str = "merge-force-local"
+    sync: Union[str, None] = None
 
     def add(
         self,
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[Any],
-        sync: bool = False,
     ):
         """
         Add embeddings to the index.
@@ -33,6 +32,18 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    def _add_and_sync(
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[Any],
+    ):
+        """
+        Add embeddings to the index and manage index syncing if necessary.
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     def delete(self, route_name: str):
         """
         Deletes route by route name.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index b1108873..7e32f3a8 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -4,6 +4,7 @@ import numpy as np
 
 from semantic_router.index.base import BaseIndex
 from semantic_router.linear import similarity_matrix, top_scores
+from semantic_router.utils.logger import logger
 
 
 class LocalIndex(BaseIndex):
@@ -25,7 +26,6 @@ class LocalIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
-        sync: bool = False,
     ):
         embeds = np.array(embeddings)  # type: ignore
         routes_arr = np.array(routes)
@@ -42,6 +42,16 @@ class LocalIndex(BaseIndex):
             self.routes = np.concatenate([self.routes, routes_arr])
             self.utterances = np.concatenate([self.utterances, utterances_arr])
 
+    def _add_and_sync(
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[str],
+    ):
+        if self.sync is not None:
+            logger.warning("Sync add is not implemented for LocalIndex.")
+        self.add(embeddings, routes, utterances)
+
     def get_routes(self) -> List[Tuple]:
         """
         Gets a list of route and utterance objects currently stored in the index.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 27038824..a9c93ccd 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -65,7 +65,7 @@ class PineconeIndex(BaseIndex):
         host: str = "",
         namespace: Optional[str] = "",
         base_url: Optional[str] = "https://api.pinecone.io",
-        sync: str = "merge-force-local",
+        sync: str = "local",
     ):
         super().__init__()
         self.index_name = index_name
@@ -282,7 +282,6 @@ class PineconeIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
-        sync: bool = False,
         batch_size: int = 100,
     ):
         """Add vectors to Pinecone in batches."""
@@ -290,14 +289,34 @@ class PineconeIndex(BaseIndex):
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
 
-        if sync:
-            local_routes = {
-                "routes": routes,
-                "utterances": utterances,
-                "embeddings": embeddings,
-            }
-            data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
+        vectors_to_upsert = [
+            PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
+            for vector, route, utterance in zip(embeddings, routes, utterances)
+        ]
+
+        for i in range(0, len(vectors_to_upsert), batch_size):
+            batch = vectors_to_upsert[i : i + batch_size]
+            self._batch_upsert(batch)
+
+    def _add_and_sync(
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[str],
+        batch_size: int = 100,
+    ):
+        """Add vectors to Pinecone in batches."""
+        if self.index is None:
+            self.dimensions = self.dimensions or len(embeddings[0])
+            self.index = self._init_index(force_create=True)
 
+        local_routes = {
+            "routes": routes,
+            "utterances": utterances,
+            "embeddings": embeddings,
+        }
+        if self.sync is not None:
+            data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
             routes_to_delete: dict = {}
             for route, utterance in data_to_delete:
                 routes_to_delete.setdefault(route, []).append(utterance)
@@ -312,9 +331,11 @@ class PineconeIndex(BaseIndex):
                 ]
                 if ids_to_delete and self.index:
                     self.index.delete(ids=ids_to_delete)
-
         else:
-            data_to_upsert = zip(embeddings, routes, utterances)
+            data_to_upsert = [
+                (vector, route, utterance)
+                for vector, route, utterance in zip(embeddings, routes, utterances)
+            ]
 
         vectors_to_upsert = [
             PineconeRecord(values=vector, route=route, utterance=utterance).to_dict()
@@ -389,7 +410,9 @@ class PineconeIndex(BaseIndex):
                         if self.index
                         else {}
                     )
-                    metadata.extend([x["metadata"] for x in res_meta["vectors"].values()])
+                    metadata.extend(
+                        [x["metadata"] for x in res_meta["vectors"].values()]
+                    )
                 # extract metadata only
 
             # Check if there's a next page token; if not, break the loop
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 4bf8d893..0fff2314 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -160,16 +160,24 @@ class QdrantIndex(BaseIndex):
                 **self.config,
             )
 
-    def add(
+    def _add_and_sync(
         self,
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
-        sync: bool = False,
         batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
     ):
-        if sync:
+        if self.sync is not None:
             logger.warning("Sync add is not implemented for QdrantIndex")
+        self.add(embeddings, routes, utterances, batch_size)
+
+    def add(
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[str],
+        batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE,
+    ):
         self.dimensions = self.dimensions or len(embeddings[0])
         self._init_collection()
 
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 3ac1596c..5c2d7228 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -466,11 +466,10 @@ class RouteLayer:
         # create route array
         route_names = [route.name for route in routes for _ in route.utterances]
         # add everything to the index
-        self.index.add(
+        self.index._add_and_sync(
             embeddings=embedded_utterances,
             routes=route_names,
             utterances=all_utterances,
-            sync=True,
         )
 
     def _encode(self, text: str) -> Any:
-- 
GitLab