From 1bed311dc0cc285b7756554ad48928a80a122bcf Mon Sep 17 00:00:00 2001 From: Vittorio <vittorio.mayellaro.dev@gmail.com> Date: Thu, 8 Aug 2024 10:54:47 +0200 Subject: [PATCH] Added local execution for layer fitting --- semantic_router/index/base.py | 11 +++++++++++ semantic_router/layer.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index c51c3ff2..d0f12ac6 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -33,6 +33,17 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") + def get_routes(self): + """ + Retrieves a list of routes and their associated utterances from the index. + This method should be implemented by subclasses. + + :returns: A list of tuples, each containing a route name and an associated utterance. + :rtype: list[tuple] + :raises NotImplementedError: If the method is not implemented by the subclass. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + def _remove_and_sync(self, routes_to_delete: dict): """ Remove embeddings in a routes syncing process from the index. diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 25e14b8f..6b548fc0 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -709,7 +709,21 @@ class RouteLayer: y: List[str], batch_size: int = 500, max_iter: int = 500, + local_execution: bool = False, ): + original_index = self.index + if local_execution: + # Switch to a local index for fitting + from semantic_router.index.local import LocalIndex + + remote_routes = self.index.get_routes() + # TODO Enhance by retrieving directly the vectors instead of embedding all utterances again + routes = [route_tuple[0] for route_tuple in remote_routes] + utterances = [route_tuple[1] for route_tuple in remote_routes] + embeddings = self.encoder(utterances) + self.index = LocalIndex() + self.index.add(embeddings=embeddings, routes=routes, utterances=utterances) + # convert inputs into array Xq: List[List[float]] = [] for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): @@ -737,6 +751,10 @@ class RouteLayer: # update route layer to best thresholds self._update_thresholds(score_thresholds=best_thresholds) + if local_execution: + # Switch back to the original index + self.index = original_index + def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float: """ Evaluate the accuracy of the route selection. -- GitLab