diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 27c1d4245e668c2612d1dd379c02655f36caed29..76aff0ebec5a086168011ba2dbb2459d9d380876 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -17,6 +17,7 @@ from tqdm.auto import tqdm from typing import Dict import random + def is_valid(layer_config: str) -> bool: """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]""" try: @@ -189,7 +190,11 @@ class RouteLayer: top_class, top_class_scores = self._semantic_classify(results) # get chosen route object route = [route for route in self.routes if route.name == top_class][0] - threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold + threshold = ( + route.score_threshold + if route.score_threshold is not None + else self.score_threshold + ) passed = self._pass_threshold(top_class_scores, threshold) if passed: if route.function_schema and not isinstance(route.llm, BaseLLM): @@ -343,35 +348,49 @@ class RouteLayer: return {route.name: route.score_threshold for route in self.routes} def fit( - self, + self, X: List[str], Y: List[str], - score_threshold_values: List[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], - num_samples: int = 20 - ): + score_threshold_values: List[float] = [ + 0.5, + 0.55, + 0.6, + 0.65, + 0.7, + 0.75, + 0.8, + 0.85, + 0.9, + 0.95, + ], + num_samples: int = 20, + ): test_route_selection = TestRouteSelection(route_layer=self) # Find the best score threshold for each route - best_thresholds, best_accuracy = test_route_selection.random_score_threshold_search( + ( + best_thresholds, + best_accuracy, + ) = test_route_selection.random_score_threshold_search( X=X, Y=Y, score_threshold_values=score_threshold_values, - num_samples=num_samples - ) + num_samples=num_samples, + ) test_route_selection.update_route_thresholds(best_thresholds) return best_accuracy, best_thresholds -class TestRouteSelection: +class TestRouteSelection: def __init__(self, route_layer: RouteLayer): self.route_layer = route_layer def random_score_threshold_search( - self, - X: List[str], - Y: List[str], - score_threshold_values: List[float], - num_samples: int, - ): + self, + X: List[str], + Y: List[str], + score_threshold_values: List[float], + num_samples: int, + ): # Define the range of threshold values for each route route_names = [route.name for route in self.route_layer.routes] best_accuracy = 0 @@ -379,7 +398,9 @@ class TestRouteSelection: # Evaluate the performance for each random sample for _ in tqdm(range(num_samples), desc=f"Processing {num_samples} Samples."): # Generate a random threshold for each route - score_thresholds = {route: random.choice(score_threshold_values) for route in route_names} + score_thresholds = { + route: random.choice(score_threshold_values) for route in route_names + } # Update the route thresholds self.update_route_thresholds(score_thresholds) @@ -391,13 +412,17 @@ class TestRouteSelection: return best_thresholds, best_accuracy - def update_route_thresholds(self, score_thresholds: Optional[Dict[str, float]] = None): + def update_route_thresholds( + self, score_thresholds: Optional[Dict[str, float]] = None + ): """ Update the score thresholds for each route. """ if score_thresholds: for route in self.route_layer.routes: - route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold) + route.score_threshold = score_thresholds.get( + route.name, self.route_layer.score_threshold + ) def evaluate(self, X: List[str], Y: List[str]) -> float: """ @@ -410,8 +435,3 @@ class TestRouteSelection: correct += 1 accuracy = correct / len(X) return accuracy - - - - -