From e21beba6383657d3a23b30906ffd17d64f642aca Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Tue, 2 Apr 2024 16:23:03 +0400
Subject: [PATCH] Now using Route thresholds if they exist.

Updated _semantic_classify_multiple_routes to use the Route thresholds if they exist.
---
 semantic_router/layer.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index f7c534d9..04d611ad 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -257,9 +257,7 @@ class RouteLayer:
         results = self._retrieve(xq=vector_arr)
 
         # decide most relevant routes
-        categories_with_scores = self._semantic_classify_multiple_routes(
-            results, self.score_threshold
-        )
+        categories_with_scores = self._semantic_classify_multiple_routes(results)
 
         route_choices = []
         for category, score in categories_with_scores:
@@ -397,18 +395,30 @@ class RouteLayer:
         else:
             logger.warning("No classification found for semantic classifier.")
             return "", []
+        
+    def get(self, name: str) -> Optional[Route]:
+        for route in self.routes:
+            if route.name == name:
+                return route
+        logger.error(f"Route `{name}` not found")
+        return None
 
     def _semantic_classify_multiple_routes(
-        self, query_results: List[dict], threshold: float
+        self, query_results: List[dict]
     ) -> List[Tuple[str, float]]:
         scores_by_class = self.group_scores_by_class(query_results)
 
         # Filter classes based on threshold and find max score for each
         classes_above_threshold = []
-        for route, scores in scores_by_class.items():
-            if self._pass_threshold(scores, threshold):
-                max_score = max(scores)
-                classes_above_threshold.append((route, max_score))
+        for route_name, scores in scores_by_class.items():
+            # Use the get method to find the Route object by its name
+            route_obj = self.get(route_name)
+            if route_obj is not None:
+                # Use the Route object's threshold if it exists, otherwise use the provided threshold
+                _threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold
+                if self._pass_threshold(scores, _threshold):
+                    max_score = max(scores)
+                    classes_above_threshold.append((route_name, max_score))
 
         return classes_above_threshold
     
-- 
GitLab