diff --git a/semantic_router/layer.py b/semantic_router/layer.py index f7c534d92b51d33be8596cae904b7461d5512284..04d611ad2357b6a5512ecf70274a1288044af9ed 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -257,9 +257,7 @@ class RouteLayer: results = self._retrieve(xq=vector_arr) # decide most relevant routes - categories_with_scores = self._semantic_classify_multiple_routes( - results, self.score_threshold - ) + categories_with_scores = self._semantic_classify_multiple_routes(results) route_choices = [] for category, score in categories_with_scores: @@ -397,18 +395,30 @@ class RouteLayer: else: logger.warning("No classification found for semantic classifier.") return "", [] + + def get(self, name: str) -> Optional[Route]: + for route in self.routes: + if route.name == name: + return route + logger.error(f"Route `{name}` not found") + return None def _semantic_classify_multiple_routes( - self, query_results: List[dict], threshold: float + self, query_results: List[dict] ) -> List[Tuple[str, float]]: scores_by_class = self.group_scores_by_class(query_results) # Filter classes based on threshold and find max score for each classes_above_threshold = [] - for route, scores in scores_by_class.items(): - if self._pass_threshold(scores, threshold): - max_score = max(scores) - classes_above_threshold.append((route, max_score)) + for route_name, scores in scores_by_class.items(): + # Use the get method to find the Route object by its name + route_obj = self.get(route_name) + if route_obj is not None: + # Use the Route object's threshold if it exists, otherwise use the provided threshold + _threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold + if self._pass_threshold(scores, _threshold): + max_score = max(scores) + classes_above_threshold.append((route_name, max_score)) return classes_above_threshold