diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 29c5a3499c322fa9d7c02ef7e28df94f3c6ca76a..b03bb25ff6f42a374956dc543ebfe90134c23432 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -341,15 +341,17 @@ class RouteLayer: return {route.name: route.score_threshold for route in self.routes} def fit( - self, - test_data: List[Tuple[str, str]], - score_threshold_values: List[float]=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], - num_samples: int = 20 - ): + self, + X: List[str], + Y: List[str], + score_threshold_values: List[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], + num_samples: int = 20 + ): test_route_selection = TestRouteSelection(route_layer=self) # Find the best score threshold for each route best_thresholds, best_accuracy = test_route_selection.random_score_threshold_search( - test_data=test_data, + X=X, + Y=Y, score_threshold_values=score_threshold_values, num_samples=num_samples ) @@ -363,7 +365,8 @@ class TestRouteSelection: def random_score_threshold_search( self, - test_data: List[Tuple[str, str]], + X: List[str], + Y: List[str], score_threshold_values: List[float], num_samples: int, ): @@ -379,7 +382,7 @@ class TestRouteSelection: # Update the route thresholds self.update_route_thresholds(score_thresholds) - accuracy = self.evaluate(test_data=test_data) + accuracy = self.evaluate(X=X, Y=Y) if accuracy > best_accuracy: best_accuracy = accuracy best_thresholds = score_thresholds @@ -394,18 +397,18 @@ class TestRouteSelection: for route in self.route_layer.routes: route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold) - def evaluate(self, test_data: List[Tuple[str, str]]) -> float: + def evaluate(self, X: List[str], Y: List[str]) -> float: """ Evaluate the accuracy of the route selection. """ correct = 0 - for input_text, expected_route_name in test_data: + for input_text, expected_route_name in zip(X, Y): route_choice = self.route_layer(input_text) if route_choice.name == expected_route_name: correct += 1 - accuracy = correct / len(test_data) + accuracy = correct / len(X) return accuracy - +