From ea7d3fb93dadb7994750efc44360f03a8d246690 Mon Sep 17 00:00:00 2001 From: Vittorio <vittorio.mayellaro.dev@gmail.com> Date: Fri, 4 Oct 2024 13:59:09 +0200 Subject: [PATCH] Implemented is_synced for Pinecone Index and tested it --- semantic_router/index/base.py | 13 +++++ semantic_router/index/pinecone.py | 86 ++++++++++++++++++++++++++----- semantic_router/layer.py | 9 ++++ 3 files changed, 94 insertions(+), 14 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 5ddb586e..be1be557 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -110,6 +110,19 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") + def is_synced( + self, + local_route_names: List[str], + local_utterances_list: List[str], + local_function_schemas_list: List[Dict[str, Any]], + local_metadata_list: List[Dict[str, Any]] + ) -> bool: + """ + Checks whether local and remote index are synchronized. + This method should be implemented by subclasses. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + def _sync_index( self, local_route_names: List[str], diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 08b714d2..7c304c09 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -215,21 +215,14 @@ class PineconeIndex(BaseIndex): logger.warning("Index could not be initialized.") self.host = index_stats["host"] if index_stats else None - def _sync_index( + def _format_routes_dict_for_sync( self, - local_route_names: List[str], - local_utterances_list: List[str], - local_function_schemas_list: List[Dict[str, Any]], + local_route_names: List[str], + local_utterances_list: List[str], + local_function_schemas_list: List[Dict[str, Any]], local_metadata_list: List[Dict[str, Any]], - dimensions: int, - ) -> Tuple[List, List, Dict]: - if self.index is None: - self.dimensions = self.dimensions or dimensions - self.index = self._init_index(force_create=True) - - remote_routes = self.get_routes() - - # Create remote dictionary for storing utterances and metadata + remote_routes: List[Tuple] + ) -> Tuple[Dict, Dict]: remote_dict: Dict[str, Dict[str, Any]] = { route: { "utterances": set(), @@ -241,7 +234,6 @@ class PineconeIndex(BaseIndex): for route, utterance, function_schemas, metadata in remote_routes: remote_dict[route]["utterances"].add(utterance) - # Create local dictionary for storing utterances and metadata local_dict: Dict[str, Dict[str, Any]] = {} for route, utterance, function_schemas, metadata in zip( local_route_names, @@ -259,6 +251,72 @@ class PineconeIndex(BaseIndex): local_dict[route]["function_schemas"] = function_schemas local_dict[route]["metadata"] = metadata + return local_dict, remote_dict + + def is_synced( + self, + local_route_names: List[str], + local_utterances_list: List[str], + local_function_schemas_list: List[Dict[str, Any]], + local_metadata_list: List[Dict[str, Any]] + ) -> bool: + remote_routes = self.get_routes() + + local_dict, remote_dict = self._format_routes_dict_for_sync( + local_route_names, + local_utterances_list, + local_function_schemas_list, + local_metadata_list, + remote_routes + ) + logger.info(f"LOCAL: {local_dict}") + logger.info(f"REMOTE: {remote_dict}") + + all_routes = set(remote_dict.keys()).union(local_dict.keys()) + + for route in all_routes: + local_utterances = local_dict.get(route, {}).get("utterances", set()) + remote_utterances = remote_dict.get(route, {}).get("utterances", set()) + local_function_schemas = ( + local_dict.get(route, {}).get("function_schemas", {}) or {} + ) + remote_function_schemas = ( + remote_dict.get(route, {}).get("function_schemas", {}) or {} + ) + local_metadata = local_dict.get(route, {}).get("metadata", {}) + remote_metadata = remote_dict.get(route, {}).get("metadata", {}) + + if ( + local_utterances != remote_utterances + or local_function_schemas != remote_function_schemas + or local_metadata != remote_metadata + ): + return False + + return True + + def _sync_index( + self, + local_route_names: List[str], + local_utterances_list: List[str], + local_function_schemas_list: List[Dict[str, Any]], + local_metadata_list: List[Dict[str, Any]], + dimensions: int, + ) -> Tuple[List, List, Dict]: + if self.index is None: + self.dimensions = self.dimensions or dimensions + self.index = self._init_index(force_create=True) + + remote_routes = self.get_routes() + + local_dict, remote_dict = self._format_routes_dict_for_sync( + local_route_names, + local_utterances_list, + local_function_schemas_list, + local_metadata_list, + remote_routes + ) + all_routes = set(remote_dict.keys()).union(local_dict.keys()) routes_to_add = [] diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 158c12be..a597e938 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -526,6 +526,15 @@ class RouteLayer: logger.error(f"Failed to add routes to the index: {e}") raise Exception("Indexing error occurred") from e + def is_synced(self) -> bool: + if not self.index.sync: + raise ValueError("Index is not set to sync with remote index.") + + local_route_names, local_utterances, local_function_schemas, local_metadata = ( + self._extract_routes_details(self.routes, include_metadata=True) + ) + return self.index.is_synced(local_route_names, local_utterances, local_function_schemas, local_metadata) + def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting local_route_names, local_utterances, local_function_schemas, local_metadata = ( -- GitLab