diff --git a/docs/00-introduction.ipynb b/docs/00-introduction.ipynb index 7d4fed5d974a7b31077a45f638681b7c8ebbf62d..5ec13702875339d029b2222847d2e91dc8bc1a5e 100644 --- a/docs/00-introduction.ipynb +++ b/docs/00-introduction.ipynb @@ -199,7 +199,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "rl.retrieve_multiple_routes(\"Hi! How are you doing in politics??\")" ] }, diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 31c7cfa081e9b9143d6b406c7d8f711374f72fc0..6f1b6a463b1ce0c39bb192e46e734b24e3ed9117 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -257,13 +257,17 @@ class RouteLayer: results = self._retrieve(xq=vector_arr) # decide most relevant routes - categories_with_scores = self._semantic_classify_multiple_routes(results, self.score_threshold) + categories_with_scores = self._semantic_classify_multiple_routes( + results, self.score_threshold + ) route_choices = [] for category, score in categories_with_scores: route = self.check_for_matching_routes(category) if route: - route_choice = RouteChoice(name=route.name, similarity_score=score, route=route) + route_choice = RouteChoice( + name=route.name, similarity_score=score, route=route + ) route_choices.append(route_choice) return route_choices @@ -401,7 +405,9 @@ class RouteLayer: logger.warning("No classification found for semantic classifier.") return "", [] - def _semantic_classify_multiple_routes(self, query_results: List[dict], threshold: float) -> List[Tuple[str, float]]: + def _semantic_classify_multiple_routes( + self, query_results: List[dict], threshold: float + ) -> List[Tuple[str, float]]: scores_by_class: Dict[str, List[float]] = {} for result in query_results: score = result["score"] @@ -411,7 +417,6 @@ class RouteLayer: else: scores_by_class[route] = [score] - # Filter classes based on threshold and find max score for each classes_above_threshold = [] for route, scores in scores_by_class.items(): @@ -420,8 +425,7 @@ class RouteLayer: classes_above_threshold.append((route, max_score)) return classes_above_threshold - - + def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: return max(scores) > threshold @@ -541,8 +545,8 @@ def threshold_random_search( if __name__ == "__main__": from semantic_router import Route - from semantic_router.layer import RouteLayer from semantic_router.encoders import OpenAIEncoder + from semantic_router.layer import RouteLayer # Define routes with example phrases politics = Route( @@ -577,7 +581,13 @@ if __name__ == "__main__": rl = RouteLayer(encoder=encoder, routes=routes) # Test the RouteLayer with example queries - print(rl.retrieve_multiple_routes("how's the weather today?")) # Expected to match the chitchat route - print(rl.retrieve_multiple_routes("don't you love politics?")) # Expected to match the politics route - print(rl.retrieve_multiple_routes("I'm interested in learning about llama 2")) # Expected to return None since it doesn't match any route - print(rl.retrieve_multiple_routes("Hi! How are you doing in politics??")) \ No newline at end of file + print( + rl.retrieve_multiple_routes("how's the weather today?") + ) # Expected to match the chitchat route + print( + rl.retrieve_multiple_routes("don't you love politics?") + ) # Expected to match the politics route + print( + rl.retrieve_multiple_routes("I'm interested in learning about llama 2") + ) # Expected to return None since it doesn't match any route + print(rl.retrieve_multiple_routes("Hi! How are you doing in politics??"))