diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 76388d1dd7dfcdde1d5630324f82361a1341936b..9eb99532eb0db9d6128260f77f7fce1027393e6d 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -32,15 +32,10 @@ class BaseIndex(BaseModel): This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") - - def _add_and_sync( - self, - embeddings: List[List[float]], - routes: List[str], - utterances: List[Any], - ): + + def _remove_and_sync(self, routes_to_delete: dict): """ - Add embeddings to the index and manage index syncing if necessary. + Remove embeddings in a routes syncing process from the index. This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") @@ -91,7 +86,7 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") - def _sync_index(self, local_routes: dict): + def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): """ Synchronize the local index with the remote index based on the specified mode. Modes: diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 5426ec76d8efc856a2c036a6dcd8afc33d69d091..dbc41f1a6af2ffc11ec05a4d6f6c208bb5a7e358 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -42,15 +42,13 @@ class LocalIndex(BaseIndex): self.routes = np.concatenate([self.routes, routes_arr]) self.utterances = np.concatenate([self.utterances, utterances_arr]) - def _add_and_sync( - self, - embeddings: List[List[float]], - routes: List[str], - utterances: List[str], - ): + def _remove_and_sync(self, routes_to_delete: dict): + if self.sync is not None: + logger.warning("Sync remove is not implemented for LocalIndex.") + + def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): if self.sync is not None: - logger.warning("Sync add is not implemented for LocalIndex.") - self.add(embeddings, routes, utterances) + logger.error("Sync remove is not implemented for LocalIndex.") def get_routes(self) -> List[Tuple]: """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index a578eb01b2edbf1177dd9a9ddb79c1dbaa425173..e88fa1485e151720fe1a39767043b0367f159358 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel, Field from semantic_router.index.base import BaseIndex from semantic_router.utils.logger import logger +from semantic_router.route import Route def clean_route_name(route_name: str) -> str: @@ -201,33 +202,49 @@ class PineconeIndex(BaseIndex): logger.warning("Index could not be initialized.") self.host = index_stats["host"] if index_stats else None - def _sync_index(self, local_routes: dict): + def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): + if self.index is None: + self.dimensions = self.dimensions or dimensions + self.index = self._init_index(force_create=True) + remote_routes = self.get_routes() + remote_dict: dict = {route: set() for route, _ in remote_routes} for route, utterance in remote_routes: remote_dict[route].add(utterance) - local_dict: dict = {route: set() for route in local_routes["routes"]} - for route, utterance in zip(local_routes["routes"], local_routes["utterances"]): + local_dict: dict = {route: set() for route in local_route_names} + for route, utterance in zip(local_route_names, local_utterances): local_dict[route].add(utterance) + logger.info(f"Local routes: {local_dict}") + logger.info(f"Remote routes: {remote_dict}") + all_routes = set(remote_dict.keys()).union(local_dict.keys()) routes_to_add = [] routes_to_delete = [] + layer_routes = {} for route in all_routes: local_utterances = local_dict.get(route, set()) remote_utterances = remote_dict.get(route, set()) + if not local_utterances and not remote_utterances: + continue + if self.sync == "error": if local_utterances != remote_utterances: raise ValueError( f"Synchronization error: Differences found in route '{route}'" ) utterances_to_include: set = set() + if local_utterances: + layer_routes[route] = list(local_utterances) elif self.sync == "remote": utterances_to_include = set() + if remote_utterances: + layer_routes[route] = list(remote_utterances) elif self.sync == "local": utterances_to_include = local_utterances - remote_utterances routes_to_delete.extend( @@ -237,11 +254,17 @@ class PineconeIndex(BaseIndex): if utterance not in local_utterances ] ) + if local_utterances: + layer_routes[route] = list(local_utterances) elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: utterances_to_include = local_utterances + if local_utterances: + layer_routes[route] = list(local_utterances) else: utterances_to_include = set() + if remote_utterances: + layer_routes[route] = list(remote_utterances) elif self.sync == "merge-force-local": if route in local_dict: utterances_to_include = local_utterances - remote_utterances @@ -252,27 +275,27 @@ class PineconeIndex(BaseIndex): if utterance not in local_utterances ] ) + if local_utterances: + layer_routes[route] = local_utterances else: utterances_to_include = set() + if remote_utterances: + layer_routes[route] = list(remote_utterances) elif self.sync == "merge": utterances_to_include = local_utterances - remote_utterances + if local_utterances or remote_utterances: + layer_routes[route] = list( + remote_utterances.union(local_utterances) + ) else: raise ValueError("Invalid sync mode specified") for utterance in utterances_to_include: - indices = [ - i - for i, x in enumerate(local_routes["utterances"]) - if x == utterance and local_routes["routes"][i] == route - ] - routes_to_add.extend( - [ - (local_routes["embeddings"][idx], route, utterance) - for idx in indices - ] - ) + routes_to_add.append((route, utterance)) + + logger.info(f"Layer routes: {layer_routes}") - return routes_to_add, routes_to_delete + return routes_to_add, routes_to_delete, layer_routes def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records.""" @@ -301,54 +324,18 @@ class PineconeIndex(BaseIndex): for i in range(0, len(vectors_to_upsert), batch_size): batch = vectors_to_upsert[i : i + batch_size] self._batch_upsert(batch) - - def _add_and_sync( - self, - embeddings: List[List[float]], - routes: List[str], - utterances: List[str], - batch_size: int = 100, - ): - """Add vectors to Pinecone in batches.""" - if self.index is None: - self.dimensions = self.dimensions or len(embeddings[0]) - self.index = self._init_index(force_create=True) - - local_routes = { - "routes": routes, - "utterances": utterances, - "embeddings": embeddings, - } - if self.sync is not None: - data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes) - routes_to_delete: dict = {} - for route, utterance in data_to_delete: - routes_to_delete.setdefault(route, []).append(utterance) - - for route, utterances in routes_to_delete.items(): - remote_routes = self._get_routes_with_ids(route_name=route) - ids_to_delete = [ - r["id"] - for r in remote_routes - if (r["route"], r["utterance"]) - in zip([route] * len(utterances), utterances) - ] - if ids_to_delete and self.index: - self.index.delete(ids=ids_to_delete) - else: - data_to_upsert = [ - (vector, route, utterance) - for vector, route, utterance in zip(embeddings, routes, utterances) + + def _remove_and_sync(self, routes_to_delete: dict): + for route, utterances in routes_to_delete.items(): + remote_routes = self._get_routes_with_ids(route_name=route) + ids_to_delete = [ + r["id"] + for r in remote_routes + if (r["route"], r["utterance"]) + in zip([route] * len(utterances), utterances) ] - - vectors_to_upsert = [ - PineconeRecord(values=vector, route=route, utterance=utterance).to_dict() - for vector, route, utterance in data_to_upsert - ] - - for i in range(0, len(vectors_to_upsert), batch_size): - batch = vectors_to_upsert[i : i + batch_size] - self._batch_upsert(batch) + if ids_to_delete and self.index: + self.index.delete(ids=ids_to_delete) def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 0fff231495d236120bc1dd4d3756fbd6a9359e38..a77e6f888ba8023cd5ce6d3fc955bd6b1f0615ce 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -160,16 +160,13 @@ class QdrantIndex(BaseIndex): **self.config, ) - def _add_and_sync( - self, - embeddings: List[List[float]], - routes: List[str], - utterances: List[str], - batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, - ): + def _remove_and_sync(self, routes_to_delete: dict): + if self.sync is not None: + logger.error("Sync remove is not implemented for LocalIndex.") + + def _sync_index(self, local_route_names: List[str], local_utterances: List[str], dimensions: int): if self.sync is not None: - logger.warning("Sync add is not implemented for QdrantIndex") - self.add(embeddings, routes, utterances, batch_size) + logger.error("Sync remove is not implemented for QdrantIndex.") def add( self, diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 81f690b594ba5e3ed5151e206720746de6a4bb0c..8888460dbd78f3a8b34b5ba167dba6f356494390 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -217,8 +217,13 @@ class RouteLayer: if route.score_threshold is None: route.score_threshold = self.score_threshold # if routes list has been passed, we initialize index now - if len(self.routes) > 0: + if self.index.sync: # initialize index now + if len(self.routes) > 0: + self._add_and_sync_routes(routes=self.routes) + else: + self._add_and_sync_routes(routes=[]) + elif len(self.routes) > 0: self._add_routes(routes=self.routes) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: @@ -385,6 +390,14 @@ class RouteLayer: ) return self._pass_threshold(scores, threshold) + def _set_layer_routes(self, new_routes: List[Route]): + """ + Set and override the current routes with a new list of routes. + + :param new_routes: List of Route objects to set as the current routes. + """ + self.routes = new_routes + def __str__(self): return ( f"RouteLayer(encoder={self.encoder}, " @@ -464,19 +477,57 @@ class RouteLayer: def _add_routes(self, routes: List[Route]): # create embeddings for all routes - all_utterances = [ - utterance for route in routes for utterance in route.utterances - ] + route_names, all_utterances = self._extract_routes_details(routes) embedded_utterances = self.encoder(all_utterances) # create route array - route_names = [route.name for route in routes for _ in route.utterances] # add everything to the index - self.index._add_and_sync( + self.index.add( embeddings=embedded_utterances, routes=route_names, utterances=all_utterances, ) + + def _add_and_sync_routes(self, routes: List[Route]): + # create embeddings for all routes and sync at startup with remote ones based on sync setting + local_route_names, local_utterances = self._extract_routes_details(routes) + routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( + local_route_names=local_route_names, + local_utterances=local_utterances, + dimensions=len(self.encoder(["dummy"])[0]) + ) + + logger.info(f"ROUTES TO ADD: {(routes_to_add)}") + logger.info(f"ROUTES TO DELETE: {(routes_to_delete)}") + + layer_routes = [ + Route(name=route, utterances=layer_routes_dict[route]) + for route in layer_routes_dict.keys() + ] + + data_to_delete: dict = {} + for route, utterance in routes_to_delete: + data_to_delete.setdefault(route, []).append(utterance) + self.index._remove_and_sync(data_to_delete) + + all_utterances_to_add = [utt for _, utt in routes_to_add] + embedded_utterances_to_add = self.encoder(all_utterances_to_add) if all_utterances_to_add else [] + + route_names_to_add = [route for route, _, in routes_to_add] + + self.index.add( + embeddings=embedded_utterances_to_add, + routes=route_names_to_add, + utterances=all_utterances_to_add, + ) + + self._set_layer_routes(layer_routes) + + def _extract_routes_details(self, routes: List[Route]) -> Tuple[List[str], List[str]]: + route_names = [route.name for route in routes for _ in route.utterances] + utterances = [utterance for route in routes for utterance in route.utterances] + return route_names, utterances + def _encode(self, text: str) -> Any: """Given some text, encode it.""" # create query vector