From 0cef1bdead3b5c4cafd7f112694d93f9fabe2b14 Mon Sep 17 00:00:00 2001 From: Luca Mannini <dev@lucamannini.com> Date: Thu, 1 Feb 2024 17:50:51 +0100 Subject: [PATCH] Lint --- docs/00-introduction.ipynb | 1 - semantic_router/layer.py | 32 +++++++++++++++++++++----------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb index 7d4fed5d..5ec13702 100644 --- a/docs/00-introduction.ipynb +++ b/docs/00-introduction.ipynb @@ -199,7 +199,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "rl.retrieve_multiple_routes(\"Hi! How are you doing in politics??\")" ] }, diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 31c7cfa0..6f1b6a46 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -257,13 +257,17 @@ class RouteLayer: results = self._retrieve(xq=vector_arr) # decide most relevant routes - categories_with_scores = self._semantic_classify_multiple_routes(results, self.score_threshold) + categories_with_scores = self._semantic_classify_multiple_routes( + results, self.score_threshold + ) route_choices = [] for category, score in categories_with_scores: route = self.check_for_matching_routes(category) if route: - route_choice = RouteChoice(name=route.name, similarity_score=score, route=route) + route_choice = RouteChoice( + name=route.name, similarity_score=score, route=route + ) route_choices.append(route_choice) return route_choices @@ -401,7 +405,9 @@ class RouteLayer: logger.warning("No classification found for semantic classifier.") return "", [] - def _semantic_classify_multiple_routes(self, query_results: List[dict], threshold: float) -> List[Tuple[str, float]]: + def _semantic_classify_multiple_routes( + self, query_results: List[dict], threshold: float + ) -> List[Tuple[str, float]]: scores_by_class: Dict[str, List[float]] = {} for result in query_results: score = result["score"] @@ -411,7 +417,6 @@ class RouteLayer: else: scores_by_class[route] = [score] - # Filter classes based on threshold and find max score for each classes_above_threshold = [] for route, scores in scores_by_class.items(): @@ -420,8 +425,7 @@ class RouteLayer: classes_above_threshold.append((route, max_score)) return classes_above_threshold - - + def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: return max(scores) > threshold @@ -541,8 +545,8 @@ def threshold_random_search( if __name__ == "__main__": from semantic_router import Route - from semantic_router.layer import RouteLayer from semantic_router.encoders import OpenAIEncoder + from semantic_router.layer import RouteLayer # Define routes with example phrases politics = Route( @@ -577,7 +581,13 @@ if __name__ == "__main__": rl = RouteLayer(encoder=encoder, routes=routes) # Test the RouteLayer with example queries - print(rl.retrieve_multiple_routes("how's the weather today?")) # Expected to match the chitchat route - print(rl.retrieve_multiple_routes("don't you love politics?")) # Expected to match the politics route - print(rl.retrieve_multiple_routes("I'm interested in learning about llama 2")) # Expected to return None since it doesn't match any route - print(rl.retrieve_multiple_routes("Hi! How are you doing in politics??")) \ No newline at end of file + print( + rl.retrieve_multiple_routes("how's the weather today?") + ) # Expected to match the chitchat route + print( + rl.retrieve_multiple_routes("don't you love politics?") + ) # Expected to match the politics route + print( + rl.retrieve_multiple_routes("I'm interested in learning about llama 2") + ) # Expected to return None since it doesn't match any route + print(rl.retrieve_multiple_routes("Hi! How are you doing in politics??")) -- GitLab