From e21beba6383657d3a23b30906ffd17d64f642aca Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Tue, 2 Apr 2024 16:23:03 +0400 Subject: [PATCH] Now using Route thresholds if they exist. Updated _semantic_classify_multiple_routes to use the Route thresholds if they exist. --- semantic_router/layer.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index f7c534d9..04d611ad 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -257,9 +257,7 @@ class RouteLayer: results = self._retrieve(xq=vector_arr) # decide most relevant routes - categories_with_scores = self._semantic_classify_multiple_routes( - results, self.score_threshold - ) + categories_with_scores = self._semantic_classify_multiple_routes(results) route_choices = [] for category, score in categories_with_scores: @@ -397,18 +395,30 @@ class RouteLayer: else: logger.warning("No classification found for semantic classifier.") return "", [] + + def get(self, name: str) -> Optional[Route]: + for route in self.routes: + if route.name == name: + return route + logger.error(f"Route `{name}` not found") + return None def _semantic_classify_multiple_routes( - self, query_results: List[dict], threshold: float + self, query_results: List[dict] ) -> List[Tuple[str, float]]: scores_by_class = self.group_scores_by_class(query_results) # Filter classes based on threshold and find max score for each classes_above_threshold = [] - for route, scores in scores_by_class.items(): - if self._pass_threshold(scores, threshold): - max_score = max(scores) - classes_above_threshold.append((route, max_score)) + for route_name, scores in scores_by_class.items(): + # Use the get method to find the Route object by its name + route_obj = self.get(route_name) + if route_obj is not None: + # Use the Route object's threshold if it exists, otherwise use the provided threshold + _threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold + if self._pass_threshold(scores, _threshold): + max_score = max(scores) + classes_above_threshold.append((route_name, max_score)) return classes_above_threshold -- GitLab