From 5dbff6c15c3a68b7336dd876b63ebc32ff0e9c4c Mon Sep 17 00:00:00 2001
From: Vits <vittorio.mayellaro.dev@gmail.com>
Date: Tue, 9 Jul 2024 23:04:52 +0200
Subject: [PATCH] Linting and formatting

---
 semantic_router/index/base.py     | 10 ++--
 semantic_router/index/local.py    |  6 ++-
 semantic_router/index/pinecone.py | 82 ++++++++++++++++++++++++-------
 3 files changed, 75 insertions(+), 23 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 6d0969fc..e53ca44f 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -21,7 +21,11 @@ class BaseIndex(BaseModel):
     sync: str = "merge-force-local"
 
     def add(
-        self, embeddings: List[List[float]], routes: List[str], utterances: List[Any]
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[Any],
+        sync: bool = False,
     ):
         """
         Add embeddings to the index.
@@ -74,7 +78,7 @@ class BaseIndex(BaseModel):
         This method should be implemented by subclasses.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
-    
+
     def _sync_index(self, local_routes: dict):
         """
         Synchronize the local index with the remote index based on the specified mode.
@@ -85,7 +89,7 @@ class BaseIndex(BaseModel):
         - "merge-force-remote": Merge both local and remote taking only remote routes utterances when a route with same route name is present both locally and remotely.
         - "merge-force-local": Merge both local and remote taking only local routes utterances when a route with same route name is present both locally and remotely.
         - "merge": Merge both local and remote, merging also local and remote utterances when a route with same route name is present both locally and remotely.
-        
+
         This method should be implemented by subclasses.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index df9e02c1..b1108873 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -21,7 +21,11 @@ class LocalIndex(BaseIndex):
         arbitrary_types_allowed = True
 
     def add(
-        self, embeddings: List[List[float]], routes: List[str], utterances: List[str]
+        self,
+        embeddings: List[List[float]],
+        routes: List[str],
+        utterances: List[str],
+        sync: bool = False,
     ):
         embeds = np.array(embeddings)  # type: ignore
         routes_arr = np.array(routes)
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 48c186ab..dc86004a 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -199,12 +199,12 @@ class PineconeIndex(BaseIndex):
 
     def _sync_index(self, local_routes: dict):
         remote_routes = self.get_routes()
-        remote_dict = {route: set() for route, _ in remote_routes}
+        remote_dict: dict = {route: set() for route, _ in remote_routes}
         for route, utterance in remote_routes:
             remote_dict[route].add(utterance)
 
-        local_dict = {route: set() for route in local_routes['routes']}
-        for route, utterance in zip(local_routes['routes'], local_routes['utterances']):
+        local_dict: dict = {route: set() for route in local_routes["routes"]}
+        for route, utterance in zip(local_routes["routes"], local_routes["utterances"]):
             local_dict[route].add(utterance)
 
         all_routes = set(remote_dict.keys()).union(local_dict.keys())
@@ -218,13 +218,21 @@ class PineconeIndex(BaseIndex):
 
             if self.sync == "error":
                 if local_utterances != remote_utterances:
-                    raise ValueError(f"Synchronization error: Differences found in route '{route}'")
-                utterances_to_include = set()
+                    raise ValueError(
+                        f"Synchronization error: Differences found in route '{route}'"
+                    )
+                utterances_to_include: set = set()
             elif self.sync == "remote":
                 utterances_to_include = set()
             elif self.sync == "local":
                 utterances_to_include = local_utterances - remote_utterances
-                routes_to_delete.extend([(route, utterance) for utterance in remote_utterances if utterance not in local_utterances])
+                routes_to_delete.extend(
+                    [
+                        (route, utterance)
+                        for utterance in remote_utterances
+                        if utterance not in local_utterances
+                    ]
+                )
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
                     utterances_to_include = local_utterances
@@ -233,7 +241,13 @@ class PineconeIndex(BaseIndex):
             elif self.sync == "merge-force-local":
                 if route in local_dict:
                     utterances_to_include = local_utterances - remote_utterances
-                    routes_to_delete.extend([(route, utterance) for utterance in remote_utterances if utterance not in local_utterances])
+                    routes_to_delete.extend(
+                        [
+                            (route, utterance)
+                            for utterance in remote_utterances
+                            if utterance not in local_utterances
+                        ]
+                    )
                 else:
                     utterances_to_include = set()
             elif self.sync == "merge":
@@ -242,12 +256,20 @@ class PineconeIndex(BaseIndex):
                 raise ValueError("Invalid sync mode specified")
 
             for utterance in utterances_to_include:
-                indices = [i for i, x in enumerate(local_routes['utterances']) if x == utterance and local_routes['routes'][i] == route]
-                routes_to_add.extend([(local_routes['embeddings'][idx], route, utterance) for idx in indices])
+                indices = [
+                    i
+                    for i, x in enumerate(local_routes["utterances"])
+                    if x == utterance and local_routes["routes"][i] == route
+                ]
+                routes_to_add.extend(
+                    [
+                        (local_routes["embeddings"][idx], route, utterance)
+                        for idx in indices
+                    ]
+                )
 
         return routes_to_add, routes_to_delete
 
-
     def _batch_upsert(self, batch: List[Dict]):
         """Helper method for upserting a single batch of records."""
         if self.index is not None:
@@ -260,8 +282,8 @@ class PineconeIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
-        batch_size: int = 100,
         sync: bool = False,
+        batch_size: int = 100,
     ):
         """Add vectors to Pinecone in batches."""
         if self.index is None:
@@ -269,19 +291,28 @@ class PineconeIndex(BaseIndex):
             self.index = self._init_index(force_create=True)
 
         if sync:
-            local_routes = {"routes": routes, "utterances": utterances, "embeddings": embeddings}
+            local_routes = {
+                "routes": routes,
+                "utterances": utterances,
+                "embeddings": embeddings,
+            }
             data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
 
-            routes_to_delete = {}
+            routes_to_delete: dict = {}
             for route, utterance in data_to_delete:
                 routes_to_delete.setdefault(route, []).append(utterance)
 
             for route, utterances in routes_to_delete.items():
                 remote_routes = self._get_routes_with_ids(route_name=route)
-                ids_to_delete = [r["id"] for r in remote_routes if (r["route"], r["utterance"]) in zip([route]*len(utterances), utterances)]
-                if ids_to_delete:
+                ids_to_delete = [
+                    r["id"]
+                    for r in remote_routes
+                    if (r["route"], r["utterance"])
+                    in zip([route] * len(utterances), utterances)
+                ]
+                if ids_to_delete and self.index:
                     self.index.delete(ids=ids_to_delete)
-                
+
         else:
             data_to_upsert = zip(embeddings, routes, utterances)
 
@@ -298,14 +329,27 @@ class PineconeIndex(BaseIndex):
         clean_route = clean_route_name(route_name)
         ids, _ = self._get_all(prefix=f"{clean_route}#")
         return ids
-    
+
     def _get_routes_with_ids(self, route_name: str):
         clean_route = clean_route_name(route_name)
         ids, _ = self._get_all(prefix=f"{clean_route}#")
         route_tuples = []
         for id in ids:
-            res_meta = self.index.fetch(ids=[id], namespace=self.namespace)
-            route_tuples.extend([{"id": id, "route": x["metadata"]["sr_route"], "utterance": x["metadata"]["sr_utterance"]} for x in res_meta["vectors"].values()])
+            res_meta = (
+                self.index.fetch(ids=[id], namespace=self.namespace)
+                if self.index
+                else {}
+            )
+            route_tuples.extend(
+                [
+                    {
+                        "id": id,
+                        "route": x["metadata"]["sr_route"],
+                        "utterance": x["metadata"]["sr_utterance"],
+                    }
+                    for x in res_meta["vectors"].values()
+                ]
+            )
         return route_tuples
 
     def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
-- 
GitLab