From 884e15e117a7a647abb295e3525f8aedf0a25582 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Tue, 2 Apr 2024 22:49:39 +0400
Subject: [PATCH] Linting.

---
 docs/00-introduction.ipynb |  2 +-
 semantic_router/layer.py   | 17 +++++++++++------
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb
index 97b2ce33..dfd85f43 100644
--- a/docs/00-introduction.ipynb
+++ b/docs/00-introduction.ipynb
@@ -162,7 +162,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2024-04-02 22:18:59 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
+      "\u001b[32m2024-04-02 22:45:21 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
      ]
     }
    ],
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 04d611ad..77e6713e 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -264,7 +264,7 @@ class RouteLayer:
             route = self.check_for_matching_routes(category)
             if route:
                 route_choice = RouteChoice(
-                    name=route.name, similarity_score=score, route=route
+                    name=route.name, similarity_score=score
                 )
                 route_choices.append(route_choice)
 
@@ -395,7 +395,7 @@ class RouteLayer:
         else:
             logger.warning("No classification found for semantic classifier.")
             return "", []
-        
+
     def get(self, name: str) -> Optional[Route]:
         for route in self.routes:
             if route.name == name:
@@ -415,14 +415,20 @@ class RouteLayer:
             route_obj = self.get(route_name)
             if route_obj is not None:
                 # Use the Route object's threshold if it exists, otherwise use the provided threshold
-                _threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold
+                _threshold = (
+                    route_obj.score_threshold
+                    if route_obj.score_threshold is not None
+                    else self.score_threshold
+                )
                 if self._pass_threshold(scores, _threshold):
                     max_score = max(scores)
                     classes_above_threshold.append((route_name, max_score))
 
         return classes_above_threshold
-    
-    def group_scores_by_class(self, query_results: List[dict]) -> Dict[str, List[float]]:
+
+    def group_scores_by_class(
+        self, query_results: List[dict]
+    ) -> Dict[str, List[float]]:
         scores_by_class: Dict[str, List[float]] = {}
         for result in query_results:
             score = result["score"]
@@ -433,7 +439,6 @@ class RouteLayer:
                 scores_by_class[route] = [score]
         return scores_by_class
 
-
     def _pass_threshold(self, scores: List[float], threshold: float) -> bool:
         if scores:
             return max(scores) > threshold
-- 
GitLab