From 25eae2433c9e4ba7dc1d63b705a61ea2858def5b Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Mon, 11 Mar 2024 17:10:37 +0400 Subject: [PATCH] Linting. --- semantic_router/layer.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 0741756c..5acf92e0 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -226,7 +226,7 @@ class RouteLayer: if text is None: raise ValueError("Either text or vector must be provided") vector = self._encode(text=text) - + route, top_class_scores = self._retrieve_top_route(vector) passed = self._check_threshold(top_class_scores, route) @@ -251,8 +251,10 @@ class RouteLayer: else: # if no route passes threshold, return empty route choice return RouteChoice() - - def _retrieve_top_route(self, vector: List[float]) -> Tuple[Optional[Route], List[float]]: + + def _retrieve_top_route( + self, vector: List[float] + ) -> Tuple[Optional[Route], List[float]]: """ Retrieve the top matching route based on the given vector. Returns a tuple of the route (if any) and the scores of the top class. @@ -264,17 +266,20 @@ class RouteLayer: # TODO do we need this check? route = self.check_for_matching_routes(top_class) return route, top_class_scores - + def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool: """ Check if the route's score passes the specified threshold. """ if route is None: return False - threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold + threshold = ( + route.score_threshold + if route.score_threshold is not None + else self.score_threshold + ) return self._pass_threshold(scores, threshold) - def __str__(self): return ( f"RouteLayer(encoder={self.encoder}, " @@ -498,7 +503,7 @@ class RouteLayer: correct += 1 accuracy = correct / len(Xq) return accuracy - + def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice: """ Simulate the route selection process treating all routes as static, including threshold checking. @@ -508,7 +513,9 @@ class RouteLayer: route, scores = self._retrieve_top_route(vector) passed = self._check_threshold(scores, route) if passed: - return RouteChoice(name=route.name, function_call=None, similarity_score=None, trigger=None) + return RouteChoice( + name=route.name, function_call=None, similarity_score=None, trigger=None + ) else: return RouteChoice() -- GitLab