From 9e30a21fc9077244d2b7588870176442f0bb846f Mon Sep 17 00:00:00 2001 From: Vits <vittorio.mayellaro.dev@gmail.com> Date: Wed, 31 Jul 2024 11:31:11 +0200 Subject: [PATCH] Linting, formatting and removing unnecessary prints --- semantic_router/index/base.py | 6 ++++-- semantic_router/index/local.py | 4 +++- semantic_router/index/pinecone.py | 8 ++++---- semantic_router/index/qdrant.py | 6 ++++-- semantic_router/layer.py | 17 ++++++++++------- 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 9eb99532..c51c3ff2 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -32,7 +32,7 @@ class BaseIndex(BaseModel): This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") - + def _remove_and_sync(self, routes_to_delete: dict): """ Remove embeddings in a routes syncing process from the index. @@ -86,7 +86,9 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): """ Synchronize the local index with the remote index based on the specified mode. Modes: diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index dbc41f1a..7150b267 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -46,7 +46,9 @@ class LocalIndex(BaseIndex): if self.sync is not None: logger.warning("Sync remove is not implemented for LocalIndex.") - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 1908d905..97777df5 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -201,7 +201,9 @@ class PineconeIndex(BaseIndex): logger.warning("Index could not be initialized.") self.host = index_stats["host"] if index_stats else None - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): if self.index is None: self.dimensions = self.dimensions or dimensions self.index = self._init_index(force_create=True) @@ -253,7 +255,7 @@ class PineconeIndex(BaseIndex): layer_routes[route] = list(local_utterances) elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: - utterances_to_include = local_utterances + utterances_to_include = set(local_utterances) if local_utterances: layer_routes[route] = list(local_utterances) else: @@ -288,8 +290,6 @@ class PineconeIndex(BaseIndex): for utterance in utterances_to_include: routes_to_add.append((route, utterance)) - logger.info(f"Layer routes: {layer_routes}") - return routes_to_add, routes_to_delete, layer_routes def _batch_upsert(self, batch: List[Dict]): diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index a77e6f88..c6e0dc91 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -163,8 +163,10 @@ class QdrantIndex(BaseIndex): def _remove_and_sync(self, routes_to_delete: dict): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") - - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): if self.sync is not None: logger.error("Sync remove is not implemented for QdrantIndex.") diff --git a/semantic_router/layer.py b/semantic_router/layer.py index c94d5731..63f2c667 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -487,14 +487,13 @@ class RouteLayer: utterances=all_utterances, ) - def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting local_route_names, local_utterances = self._extract_routes_details(routes) routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( local_route_names=local_route_names, local_utterances=local_utterances, - dimensions=len(self.encoder(["dummy"])[0]) + dimensions=len(self.encoder(["dummy"])[0]), ) layer_routes = [ @@ -508,8 +507,10 @@ class RouteLayer: self.index._remove_and_sync(data_to_delete) all_utterances_to_add = [utt for _, utt in routes_to_add] - embedded_utterances_to_add = self.encoder(all_utterances_to_add) if all_utterances_to_add else [] - + embedded_utterances_to_add = ( + self.encoder(all_utterances_to_add) if all_utterances_to_add else [] + ) + route_names_to_add = [route for route, _, in routes_to_add] self.index.add( @@ -517,14 +518,16 @@ class RouteLayer: routes=route_names_to_add, utterances=all_utterances_to_add, ) - + self._set_layer_routes(layer_routes) - def _extract_routes_details(self, routes: List[Route]) -> Tuple[List[str], List[str]]: + def _extract_routes_details( + self, routes: List[Route] + ) -> Tuple: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] return route_names, utterances - + def _encode(self, text: str) -> Any: """Given some text, encode it.""" # create query vector -- GitLab