diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb index 97b2ce3323debc2e95b39fd8ebd88770ee66a409..dfd85f43a693551c5dde655b34b480a7a6bd42f7 100644 --- a/docs/00-introduction.ipynb +++ b/docs/00-introduction.ipynb @@ -162,7 +162,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-02 22:18:59 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + "\u001b[32m2024-04-02 22:45:21 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" ] } ], diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 04d611ad2357b6a5512ecf70274a1288044af9ed..77e6713e51460a6575748f24c1ac95fc8a144837 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -264,7 +264,7 @@ class RouteLayer: route = self.check_for_matching_routes(category) if route: route_choice = RouteChoice( - name=route.name, similarity_score=score, route=route + name=route.name, similarity_score=score ) route_choices.append(route_choice) @@ -395,7 +395,7 @@ class RouteLayer: else: logger.warning("No classification found for semantic classifier.") return "", [] - + def get(self, name: str) -> Optional[Route]: for route in self.routes: if route.name == name: @@ -415,14 +415,20 @@ class RouteLayer: route_obj = self.get(route_name) if route_obj is not None: # Use the Route object's threshold if it exists, otherwise use the provided threshold - _threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold + _threshold = ( + route_obj.score_threshold + if route_obj.score_threshold is not None + else self.score_threshold + ) if self._pass_threshold(scores, _threshold): max_score = max(scores) classes_above_threshold.append((route_name, max_score)) return classes_above_threshold - - def group_scores_by_class(self, query_results: List[dict]) -> Dict[str, List[float]]: + + def group_scores_by_class( + self, query_results: List[dict] + ) -> Dict[str, List[float]]: scores_by_class: Dict[str, List[float]] = {} for result in query_results: score = result["score"] @@ -433,7 +439,6 @@ class RouteLayer: scores_by_class[route] = [score] return scores_by_class - def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: return max(scores) > threshold