From 0cef1bdead3b5c4cafd7f112694d93f9fabe2b14 Mon Sep 17 00:00:00 2001
From: Luca Mannini <dev@lucamannini.com>
Date: Thu, 1 Feb 2024 17:50:51 +0100
Subject: [PATCH] Lint

---
 docs/00-introduction.ipynb |  1 -
 semantic_router/layer.py   | 32 +++++++++++++++++++++-----------
 2 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb
index 7d4fed5d..5ec13702 100644
--- a/docs/00-introduction.ipynb
+++ b/docs/00-introduction.ipynb
@@ -199,7 +199,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
     "rl.retrieve_multiple_routes(\"Hi! How are you doing in politics??\")"
    ]
   },
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 31c7cfa0..6f1b6a46 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -257,13 +257,17 @@ class RouteLayer:
         results = self._retrieve(xq=vector_arr)
 
         # decide most relevant routes
-        categories_with_scores = self._semantic_classify_multiple_routes(results, self.score_threshold)
+        categories_with_scores = self._semantic_classify_multiple_routes(
+            results, self.score_threshold
+        )
 
         route_choices = []
         for category, score in categories_with_scores:
             route = self.check_for_matching_routes(category)
             if route:
-                route_choice = RouteChoice(name=route.name, similarity_score=score, route=route)
+                route_choice = RouteChoice(
+                    name=route.name, similarity_score=score, route=route
+                )
                 route_choices.append(route_choice)
 
         return route_choices
@@ -401,7 +405,9 @@ class RouteLayer:
             logger.warning("No classification found for semantic classifier.")
             return "", []
 
-    def _semantic_classify_multiple_routes(self, query_results: List[dict], threshold: float) -> List[Tuple[str, float]]:
+    def _semantic_classify_multiple_routes(
+        self, query_results: List[dict], threshold: float
+    ) -> List[Tuple[str, float]]:
         scores_by_class: Dict[str, List[float]] = {}
         for result in query_results:
             score = result["score"]
@@ -411,7 +417,6 @@ class RouteLayer:
             else:
                 scores_by_class[route] = [score]
 
-
         # Filter classes based on threshold and find max score for each
         classes_above_threshold = []
         for route, scores in scores_by_class.items():
@@ -420,8 +425,7 @@ class RouteLayer:
                 classes_above_threshold.append((route, max_score))
 
         return classes_above_threshold
-        
-        
+
     def _pass_threshold(self, scores: List[float], threshold: float) -> bool:
         if scores:
             return max(scores) > threshold
@@ -541,8 +545,8 @@ def threshold_random_search(
 
 if __name__ == "__main__":
     from semantic_router import Route
-    from semantic_router.layer import RouteLayer
     from semantic_router.encoders import OpenAIEncoder
+    from semantic_router.layer import RouteLayer
 
     # Define routes with example phrases
     politics = Route(
@@ -577,7 +581,13 @@ if __name__ == "__main__":
     rl = RouteLayer(encoder=encoder, routes=routes)
 
     # Test the RouteLayer with example queries
-    print(rl.retrieve_multiple_routes("how's the weather today?"))  # Expected to match the chitchat route
-    print(rl.retrieve_multiple_routes("don't you love politics?"))  # Expected to match the politics route
-    print(rl.retrieve_multiple_routes("I'm interested in learning about llama 2"))  # Expected to return None since it doesn't match any route
-    print(rl.retrieve_multiple_routes("Hi! How are you doing in politics??")) 
\ No newline at end of file
+    print(
+        rl.retrieve_multiple_routes("how's the weather today?")
+    )  # Expected to match the chitchat route
+    print(
+        rl.retrieve_multiple_routes("don't you love politics?")
+    )  # Expected to match the politics route
+    print(
+        rl.retrieve_multiple_routes("I'm interested in learning about llama 2")
+    )  # Expected to return None since it doesn't match any route
+    print(rl.retrieve_multiple_routes("Hi! How are you doing in politics??"))
-- 
GitLab