From c5dd5dd2e0381e37248d6e0bc3c0094cf61d7082 Mon Sep 17 00:00:00 2001 From: Vits <vittorio.mayellaro.dev@gmail.com> Date: Thu, 18 Jul 2024 01:20:28 +0200 Subject: [PATCH] Fixed pytests --- semantic_router/layer.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index f4042b6c..5852b8db 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -217,18 +217,21 @@ class RouteLayer: if route.score_threshold is None: route.score_threshold = self.score_threshold # if routes list has been passed, we initialize index now - if len(self.routes) > 0: + if self.index.sync: # initialize index now - self._add_routes(routes=self.routes) - elif self.index.sync: - dummy_embedding = self.encoder(["dummy"]) + if len(self.routes) > 0: + self._add_and_sync_routes(routes=self.routes) + else: + dummy_embedding = self.encoder(["dummy"]) - layer_routes = self.index._add_and_sync( - embeddings=dummy_embedding, - routes=[], - utterances=[], - ) - self._set_layer_routes(layer_routes) + layer_routes = self.index._add_and_sync( + embeddings=dummy_embedding, + routes=[], + utterances=[], + ) + self._set_layer_routes(layer_routes) + elif len(self.routes) > 0: + self._add_routes(routes=self.routes) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] @@ -483,6 +486,21 @@ class RouteLayer: # create route array route_names = [route.name for route in routes for _ in route.utterances] # add everything to the index + self.index.add( + embeddings=embedded_utterances, + routes=route_names, + utterances=all_utterances, + ) + + def _add_and_sync_routes(self, routes: List[Route]): + # create embeddings for all routes and sync at startup with remote ones based on sync setting + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + embedded_utterances = self.encoder(all_utterances) + # create route array + route_names = [route.name for route in routes for _ in route.utterances] + # add everything to the index layer_routes = self.index._add_and_sync( embeddings=embedded_utterances, routes=route_names, -- GitLab