diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index c51c3ff2a543d4496ee000fd8e2776faaa0f32cc..d0f12ac682ab1686dd9d931e73082d7695a2e1fe 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -33,6 +33,17 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") + def get_routes(self): + """ + Retrieves a list of routes and their associated utterances from the index. + This method should be implemented by subclasses. + + :returns: A list of tuples, each containing a route name and an associated utterance. + :rtype: list[tuple] + :raises NotImplementedError: If the method is not implemented by the subclass. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + def _remove_and_sync(self, routes_to_delete: dict): """ Remove embeddings in a routes syncing process from the index. diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 25e14b8f4343c03c13bc709cae7b5d2bf37ddb8c..6b548fc068b7e4d235c50b0d6fb75922aa145ea0 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -709,7 +709,21 @@ class RouteLayer: y: List[str], batch_size: int = 500, max_iter: int = 500, + local_execution: bool = False, ): + original_index = self.index + if local_execution: + # Switch to a local index for fitting + from semantic_router.index.local import LocalIndex + + remote_routes = self.index.get_routes() + # TODO Enhance by retrieving directly the vectors instead of embedding all utterances again + routes = [route_tuple[0] for route_tuple in remote_routes] + utterances = [route_tuple[1] for route_tuple in remote_routes] + embeddings = self.encoder(utterances) + self.index = LocalIndex() + self.index.add(embeddings=embeddings, routes=routes, utterances=utterances) + # convert inputs into array Xq: List[List[float]] = [] for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): @@ -737,6 +751,10 @@ class RouteLayer: # update route layer to best thresholds self._update_thresholds(score_thresholds=best_thresholds) + if local_execution: + # Switch back to the original index + self.index = original_index + def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float: """ Evaluate the accuracy of the route selection.