From c5dd5dd2e0381e37248d6e0bc3c0094cf61d7082 Mon Sep 17 00:00:00 2001
From: Vits <vittorio.mayellaro.dev@gmail.com>
Date: Thu, 18 Jul 2024 01:20:28 +0200
Subject: [PATCH] Fixed pytests

---
 semantic_router/layer.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index f4042b6c..5852b8db 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -217,18 +217,21 @@ class RouteLayer:
             if route.score_threshold is None:
                 route.score_threshold = self.score_threshold
         # if routes list has been passed, we initialize index now
-        if len(self.routes) > 0:
+        if self.index.sync:
             # initialize index now
-            self._add_routes(routes=self.routes)
-        elif self.index.sync:
-            dummy_embedding = self.encoder(["dummy"])
+            if len(self.routes) > 0:
+                self._add_and_sync_routes(routes=self.routes)
+            else:
+                dummy_embedding = self.encoder(["dummy"])
 
-            layer_routes = self.index._add_and_sync(
-                embeddings=dummy_embedding,
-                routes=[],
-                utterances=[],
-            )
-            self._set_layer_routes(layer_routes)
+                layer_routes = self.index._add_and_sync(
+                    embeddings=dummy_embedding,
+                    routes=[],
+                    utterances=[],
+                )
+                self._set_layer_routes(layer_routes)
+        elif len(self.routes) > 0:
+            self._add_routes(routes=self.routes)
 
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
         matching_routes = [route for route in self.routes if route.name == top_class]
@@ -483,6 +486,21 @@ class RouteLayer:
         # create route array
         route_names = [route.name for route in routes for _ in route.utterances]
         # add everything to the index
+        self.index.add(
+            embeddings=embedded_utterances,
+            routes=route_names,
+            utterances=all_utterances,
+        )
+    
+    def _add_and_sync_routes(self, routes: List[Route]):
+        # create embeddings for all routes and sync at startup with remote ones based on sync setting
+        all_utterances = [
+            utterance for route in routes for utterance in route.utterances
+        ]
+        embedded_utterances = self.encoder(all_utterances)
+        # create route array
+        route_names = [route.name for route in routes for _ in route.utterances]
+        # add everything to the index
         layer_routes = self.index._add_and_sync(
             embeddings=embedded_utterances,
             routes=route_names,
-- 
GitLab