From 2ea1f19944812909cc3c34dbd45bf3e9972f501b Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 11 Mar 2024 17:03:41 +0400
Subject: [PATCH] Fit and vec_evaluate can now handle dynamic routes by
 treating them as static.

Introduced _simulate_static_route_selection(), which selects a route as if every route were static, even when dynamic routes are being evaluated in _vec_evaluate().

This was necessary because dynamic routes require text inputs, whereas evaluation uses vector inputs for better performance.

Also refactored a little to avoid code duplication between _simulate_static_route_selection() and __call__().
---
 semantic_router/layer.py | 61 ++++++++++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 3ef16206..0741756c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -225,23 +225,11 @@ class RouteLayer:
         if vector is None:
             if text is None:
                 raise ValueError("Either text or vector must be provided")
-            vector_arr = self._encode(text=text)
-        else:
-            vector_arr = np.array(vector)
-        # get relevant results (scores and routes)
-        results = self._retrieve(xq=vector_arr)
-        # decide most relevant routes
-        top_class, top_class_scores = self._semantic_classify(results)
-        # TODO do we need this check?
-        route = self.check_for_matching_routes(top_class)
-        if route is None:
-            return RouteChoice()
-        threshold = (
-            route.score_threshold
-            if route.score_threshold is not None
-            else self.score_threshold
-        )
-        passed = self._pass_threshold(top_class_scores, threshold)
+            vector = self._encode(text=text)
+        
+        route, top_class_scores = self._retrieve_top_route(vector)
+        passed = self._check_threshold(top_class_scores, route)
+
         if passed:
             if route.function_schema and text is None:
                 raise ValueError(
@@ -263,6 +251,29 @@ class RouteLayer:
         else:
             # if no route passes threshold, return empty route choice
             return RouteChoice()
+        
+    def _retrieve_top_route(
+        self, vector: List[float]
+    ) -> Tuple[Optional[Route], List[float]]:
+        """
+        Find the best-matching route for a query vector.
+        Returns the matched route (if any) together with the top class scores.
+        """
+        # fetch candidate results (scores and routes), then pick the top class
+        matches = self._retrieve(xq=np.array(vector))
+        top_class, top_class_scores = self._semantic_classify(matches)
+        # TODO do we need this check?
+        return self.check_for_matching_routes(top_class), top_class_scores
+    
+    def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
+        """Return whether `scores` pass the route's (or the layer's) threshold."""
+        if route is None:
+            return False
+        limit = route.score_threshold
+        if limit is None:
+            limit = self.score_threshold
+        return self._pass_threshold(scores, limit)
+
 
     def __str__(self):
         return (
@@ -481,11 +492,25 @@ class RouteLayer:
         """
         correct = 0
         for xq, target_route in zip(Xq, y):
-            route_choice = self(vector=xq)
+            # We can't do route_choice = self(vector=xq) here as it won't work for dynamic routes.
+            route_choice = self._simulate_static_route_selection(vector=xq)
             if route_choice.name == target_route:
                 correct += 1
         accuracy = correct / len(Xq)
         return accuracy
+        
+    def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice:
+        """
+        Pick a route for `vector` treating every route as static (thresholds apply).
+        Dynamic routes need the query text in __call__, but evaluation uses vectors for speed.
+        """
+        route, scores = self._retrieve_top_route(vector)
+        # explicit None check narrows Optional[Route] before `route.name` is read
+        if route is not None and self._check_threshold(scores, route):
+            return RouteChoice(
+                name=route.name, function_call=None, similarity_score=None, trigger=None
+            )
+        return RouteChoice()
 
     def _get_route_names(self) -> List[str]:
         return [route.name for route in self.routes]
-- 
GitLab