From 884e15e117a7a647abb295e3525f8aedf0a25582 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Tue, 2 Apr 2024 22:49:39 +0400 Subject: [PATCH] Linting. --- docs/00-introduction.ipynb | 2 +- semantic_router/layer.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb index 97b2ce33..dfd85f43 100644 --- a/docs/00-introduction.ipynb +++ b/docs/00-introduction.ipynb @@ -162,7 +162,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-04-02 22:18:59 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" + "\u001b[32m2024-04-02 22:45:21 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n" ] } ], diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 04d611ad..77e6713e 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -264,7 +264,7 @@ class RouteLayer: route = self.check_for_matching_routes(category) if route: route_choice = RouteChoice( - name=route.name, similarity_score=score, route=route + name=route.name, similarity_score=score ) route_choices.append(route_choice) @@ -395,7 +395,7 @@ class RouteLayer: else: logger.warning("No classification found for semantic classifier.") return "", [] - + def get(self, name: str) -> Optional[Route]: for route in self.routes: if route.name == name: @@ -415,14 +415,20 @@ class RouteLayer: route_obj = self.get(route_name) if route_obj is not None: # Use the Route object's threshold if it exists, otherwise use the provided threshold - _threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold + _threshold = ( + route_obj.score_threshold + if route_obj.score_threshold is not None + else self.score_threshold + ) if self._pass_threshold(scores, _threshold): max_score = max(scores) classes_above_threshold.append((route_name, max_score)) return classes_above_threshold - - def group_scores_by_class(self, query_results: List[dict]) -> Dict[str, List[float]]: + + def group_scores_by_class( + self, query_results: List[dict] + ) -> Dict[str, List[float]]: scores_by_class: Dict[str, List[float]] = {} for result in query_results: score = result["score"] @@ -433,7 +439,6 @@ class RouteLayer: scores_by_class[route] = [score] return scores_by_class - def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: return max(scores) > threshold -- GitLab