diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 9eb99532eb0db9d6128260f77f7fce1027393e6d..c51c3ff2a543d4496ee000fd8e2776faaa0f32cc 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -32,7 +32,7 @@ class BaseIndex(BaseModel): This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") - + def _remove_and_sync(self, routes_to_delete: dict): """ Remove embeddings in a routes syncing process from the index. @@ -86,7 +86,9 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): """ Synchronize the local index with the remote index based on the specified mode. Modes: diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index dbc41f1a6af2ffc11ec05a4d6f6c208bb5a7e358..7150b267587715d09124eb6f94ab242a7795c3f9 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -46,7 +46,9 @@ class LocalIndex(BaseIndex): if self.sync is not None: logger.warning("Sync remove is not implemented for LocalIndex.") - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 1908d90530ce5984438962f1601f926f243195d1..97777df50109757e841323f736ede343825e7f57 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -201,7 +201,9 @@ class PineconeIndex(BaseIndex): logger.warning("Index could not be initialized.") self.host = index_stats["host"] if index_stats else None - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): if self.index is None: self.dimensions = self.dimensions or dimensions self.index = self._init_index(force_create=True) @@ -253,7 +255,7 @@ class PineconeIndex(BaseIndex): layer_routes[route] = list(local_utterances) elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: - utterances_to_include = local_utterances + utterances_to_include = set(local_utterances) if local_utterances: layer_routes[route] = list(local_utterances) else: @@ -288,8 +290,6 @@ class PineconeIndex(BaseIndex): for utterance in utterances_to_include: routes_to_add.append((route, utterance)) - logger.info(f"Layer routes: {layer_routes}") - return routes_to_add, routes_to_delete, layer_routes def _batch_upsert(self, batch: List[Dict]): diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index a77e6f888ba8023cd5ce6d3fc955bd6b1f0615ce..c6e0dc912819ade5a7e11e3fb86d140e8158fad4 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -163,8 +163,10 @@ class QdrantIndex(BaseIndex): def _remove_and_sync(self, routes_to_delete: dict): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") - - def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + + def _sync_index( + self, local_route_names: List[str], local_utterances: List[str], dimensions: int + ): if self.sync is not None: logger.error("Sync remove is not implemented for QdrantIndex.") diff --git a/semantic_router/layer.py b/semantic_router/layer.py index c94d5731544accff48b124cf088036e64414cb7f..63f2c6673041f97ad921a2f09771e01d20f78fff 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -487,14 +487,13 @@ class RouteLayer: utterances=all_utterances, ) - def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting local_route_names, local_utterances = self._extract_routes_details(routes) routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( local_route_names=local_route_names, local_utterances=local_utterances, - dimensions=len(self.encoder(["dummy"])[0]) + dimensions=len(self.encoder(["dummy"])[0]), ) layer_routes = [ @@ -508,8 +507,10 @@ class RouteLayer: self.index._remove_and_sync(data_to_delete) all_utterances_to_add = [utt for _, utt in routes_to_add] - embedded_utterances_to_add = self.encoder(all_utterances_to_add) if all_utterances_to_add else [] - + embedded_utterances_to_add = ( + self.encoder(all_utterances_to_add) if all_utterances_to_add else [] + ) + route_names_to_add = [route for route, _, in routes_to_add] self.index.add( @@ -517,14 +518,16 @@ class RouteLayer: routes=route_names_to_add, utterances=all_utterances_to_add, ) - + self._set_layer_routes(layer_routes) - def _extract_routes_details(self, routes: List[Route]) -> Tuple[List[str], List[str]]: + def _extract_routes_details( + self, routes: List[Route] + ) -> Tuple: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] return route_names, utterances - + def _encode(self, text: str) -> Any: """Given some text, encode it.""" # create query vector