From 25eae2433c9e4ba7dc1d63b705a61ea2858def5b Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 11 Mar 2024 17:10:37 +0400
Subject: [PATCH] Linting.

---
 semantic_router/layer.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 0741756c..5acf92e0 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -226,7 +226,7 @@ class RouteLayer:
             if text is None:
                 raise ValueError("Either text or vector must be provided")
             vector = self._encode(text=text)
-        
+
         route, top_class_scores = self._retrieve_top_route(vector)
         passed = self._check_threshold(top_class_scores, route)
 
@@ -251,8 +251,10 @@ class RouteLayer:
         else:
             # if no route passes threshold, return empty route choice
             return RouteChoice()
-        
-    def _retrieve_top_route(self, vector: List[float]) -> Tuple[Optional[Route], List[float]]:
+
+    def _retrieve_top_route(
+        self, vector: List[float]
+    ) -> Tuple[Optional[Route], List[float]]:
         """
         Retrieve the top matching route based on the given vector.
         Returns a tuple of the route (if any) and the scores of the top class.
@@ -264,17 +266,20 @@ class RouteLayer:
         # TODO do we need this check?
         route = self.check_for_matching_routes(top_class)
         return route, top_class_scores
-    
+
     def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
         """
         Check if the route's score passes the specified threshold.
         """
         if route is None:
             return False
-        threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold
+        threshold = (
+            route.score_threshold
+            if route.score_threshold is not None
+            else self.score_threshold
+        )
         return self._pass_threshold(scores, threshold)
 
-
     def __str__(self):
         return (
             f"RouteLayer(encoder={self.encoder}, "
@@ -498,7 +503,7 @@ class RouteLayer:
                 correct += 1
         accuracy = correct / len(Xq)
         return accuracy
-        
+
     def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice:
         """
         Simulate the route selection process treating all routes as static, including threshold checking.
@@ -508,7 +513,9 @@ class RouteLayer:
         route, scores = self._retrieve_top_route(vector)
         passed = self._check_threshold(scores, route)
         if passed:
-            return RouteChoice(name=route.name, function_call=None, similarity_score=None, trigger=None)
+            return RouteChoice(
+                name=route.name, function_call=None, similarity_score=None, trigger=None
+            )
         else:
             return RouteChoice()
 
-- 
GitLab