From 5dbff6c15c3a68b7336dd876b63ebc32ff0e9c4c Mon Sep 17 00:00:00 2001 From: Vits <vittorio.mayellaro.dev@gmail.com> Date: Tue, 9 Jul 2024 23:04:52 +0200 Subject: [PATCH] Linting and formatting --- semantic_router/index/base.py | 10 ++-- semantic_router/index/local.py | 6 ++- semantic_router/index/pinecone.py | 82 ++++++++++++++++++++++++------- 3 files changed, 75 insertions(+), 23 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 6d0969fc..e53ca44f 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -21,7 +21,11 @@ class BaseIndex(BaseModel): sync: str = "merge-force-local" def add( - self, embeddings: List[List[float]], routes: List[str], utterances: List[Any] + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[Any], + sync: bool = False, ): """ Add embeddings to the index. @@ -74,7 +78,7 @@ class BaseIndex(BaseModel): This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") - + def _sync_index(self, local_routes: dict): """ Synchronize the local index with the remote index based on the specified mode. @@ -85,7 +89,7 @@ class BaseIndex(BaseModel): - "merge-force-remote": Merge both local and remote taking only remote routes utterances when a route with same route name is present both locally and remotely. - "merge-force-local": Merge both local and remote taking only local routes utterances when a route with same route name is present both locally and remotely. - "merge": Merge both local and remote, merging also local and remote utterances when a route with same route name is present both locally and remotely. - + This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index df9e02c1..b1108873 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -21,7 +21,11 @@ class LocalIndex(BaseIndex): arbitrary_types_allowed = True def add( - self, embeddings: List[List[float]], routes: List[str], utterances: List[str] + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + sync: bool = False, ): embeds = np.array(embeddings) # type: ignore routes_arr = np.array(routes) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 48c186ab..dc86004a 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -199,12 +199,12 @@ class PineconeIndex(BaseIndex): def _sync_index(self, local_routes: dict): remote_routes = self.get_routes() - remote_dict = {route: set() for route, _ in remote_routes} + remote_dict: dict = {route: set() for route, _ in remote_routes} for route, utterance in remote_routes: remote_dict[route].add(utterance) - local_dict = {route: set() for route in local_routes['routes']} - for route, utterance in zip(local_routes['routes'], local_routes['utterances']): + local_dict: dict = {route: set() for route in local_routes["routes"]} + for route, utterance in zip(local_routes["routes"], local_routes["utterances"]): local_dict[route].add(utterance) all_routes = set(remote_dict.keys()).union(local_dict.keys()) @@ -218,13 +218,21 @@ class PineconeIndex(BaseIndex): if self.sync == "error": if local_utterances != remote_utterances: - raise ValueError(f"Synchronization error: Differences found in route '{route}'") - utterances_to_include = set() + raise ValueError( + f"Synchronization error: Differences found in route '{route}'" + ) + utterances_to_include: set = set() elif self.sync == "remote": utterances_to_include = set() elif self.sync == "local": utterances_to_include = local_utterances - remote_utterances - routes_to_delete.extend([(route, utterance) for utterance in remote_utterances if utterance not in local_utterances]) + routes_to_delete.extend( + [ + (route, utterance) + for utterance in remote_utterances + if utterance not in local_utterances + ] + ) elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: utterances_to_include = local_utterances @@ -233,7 +241,13 @@ class PineconeIndex(BaseIndex): elif self.sync == "merge-force-local": if route in local_dict: utterances_to_include = local_utterances - remote_utterances - routes_to_delete.extend([(route, utterance) for utterance in remote_utterances if utterance not in local_utterances]) + routes_to_delete.extend( + [ + (route, utterance) + for utterance in remote_utterances + if utterance not in local_utterances + ] + ) else: utterances_to_include = set() elif self.sync == "merge": @@ -242,12 +256,20 @@ class PineconeIndex(BaseIndex): raise ValueError("Invalid sync mode specified") for utterance in utterances_to_include: - indices = [i for i, x in enumerate(local_routes['utterances']) if x == utterance and local_routes['routes'][i] == route] - routes_to_add.extend([(local_routes['embeddings'][idx], route, utterance) for idx in indices]) + indices = [ + i + for i, x in enumerate(local_routes["utterances"]) + if x == utterance and local_routes["routes"][i] == route + ] + routes_to_add.extend( + [ + (local_routes["embeddings"][idx], route, utterance) + for idx in indices + ] + ) return routes_to_add, routes_to_delete - def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records.""" if self.index is not None: @@ -260,8 +282,8 @@ class PineconeIndex(BaseIndex): embeddings: List[List[float]], routes: List[str], utterances: List[str], - batch_size: int = 100, sync: bool = False, + batch_size: int = 100, ): """Add vectors to Pinecone in batches.""" if self.index is None: @@ -269,19 +291,28 @@ class PineconeIndex(BaseIndex): self.index = self._init_index(force_create=True) if sync: - local_routes = {"routes": routes, "utterances": utterances, "embeddings": embeddings} + local_routes = { + "routes": routes, + "utterances": utterances, + "embeddings": embeddings, + } data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes) - routes_to_delete = {} + routes_to_delete: dict = {} for route, utterance in data_to_delete: routes_to_delete.setdefault(route, []).append(utterance) for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) - ids_to_delete = [r["id"] for r in remote_routes if (r["route"], r["utterance"]) in zip([route]*len(utterances), utterances)] - if ids_to_delete: + ids_to_delete = [ + r["id"] + for r in remote_routes + if (r["route"], r["utterance"]) + in zip([route] * len(utterances), utterances) + ] + if ids_to_delete and self.index: self.index.delete(ids=ids_to_delete) - + else: data_to_upsert = zip(embeddings, routes, utterances) @@ -298,14 +329,27 @@ class PineconeIndex(BaseIndex): clean_route = clean_route_name(route_name) ids, _ = self._get_all(prefix=f"{clean_route}#") return ids - + def _get_routes_with_ids(self, route_name: str): clean_route = clean_route_name(route_name) ids, _ = self._get_all(prefix=f"{clean_route}#") route_tuples = [] for id in ids: - res_meta = self.index.fetch(ids=[id], namespace=self.namespace) - route_tuples.extend([{"id": id, "route": x["metadata"]["sr_route"], "utterance": x["metadata"]["sr_utterance"]} for x in res_meta["vectors"].values()]) + res_meta = ( + self.index.fetch(ids=[id], namespace=self.namespace) + if self.index + else {} + ) + route_tuples.extend( + [ + { + "id": id, + "route": x["metadata"]["sr_route"], + "utterance": x["metadata"]["sr_utterance"], + } + for x in res_meta["vectors"].values() + ] + ) return route_tuples def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): -- GitLab