diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 24722b437af5f18731eb05ebac8f9c4b6155872c..29c5a3499c322fa9d7c02ef7e28df94f3c6ca76a 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -12,10 +12,11 @@ from semantic_router.route import Route
 from semantic_router.schema import Encoder, EncoderType, RouteChoice
 from semantic_router.utils.logger import logger
 import itertools
-from tqdm import tqdm
+import random
+from tqdm.auto import tqdm
 from typing import Dict
 
 
 def is_valid(layer_config: str) -> bool:
     """Make sure the given string is json format and contains the 3 keys:
     ["encoder_name", "encoder_type", "routes"]"""
@@ -337,44 +338,106 @@ class RouteLayer:
         config = self.to_config()
         config.to_file(file_path)
 
+    def get_route_thresholds(self) -> Dict[str, float]:
+        """Return each route's current score threshold, keyed by route name."""
+        return {route.name: route.score_threshold for route in self.routes}
+
+    def fit(
+        self,
+        test_data: List[Tuple[str, str]],
+        score_threshold_values: Optional[List[float]] = None,
+        num_samples: int = 20,
+    ):
+        """Tune the per-route score thresholds by random search and apply the best.
+
+        ``num_samples`` random threshold assignments, drawn from
+        ``score_threshold_values`` (0.5 to 0.95 in steps of 0.05 when omitted),
+        are scored against ``test_data``, a list of
+        ``(utterance, expected_route_name)`` pairs.
+
+        Returns ``(best_accuracy, best_thresholds)``.
+        """
+        if score_threshold_values is None:
+            # Built per call so the default list is never shared between calls
+            score_threshold_values = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
+        test_route_selection = TestRouteSelection(route_layer=self)
+        # Find the best score threshold for each route
+        best_thresholds, best_accuracy = test_route_selection.random_score_threshold_search(
+            test_data=test_data,
+            score_threshold_values=score_threshold_values,
+            num_samples=num_samples,
+        )
+        test_route_selection.update_route_thresholds(best_thresholds)
+        return best_accuracy, best_thresholds
+
 
 class TestRouteSelection:
-    def __init__(self, route_layer: RouteLayer, test_data: List[Tuple[str, str]], score_threshold_values: List[float] = None):
-        self.route_layer = route_layer
-        self.test_data = test_data
-        self.score_threshold_values = score_threshold_values
+    """Scores a route layer on labelled utterances and searches for better thresholds."""
 
-    def evaluate(self, score_thresholds: Optional[Dict[str, float]] = None) -> float:
-        correct = 0
-        for input_text, expected_route_name in self.test_data:
-            # Set the threshold for each route based on the provided thresholds, if any
-            if score_thresholds:
-                for route in self.route_layer.routes:
-                    route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
-
-            route_choice = self.route_layer(input_text)
-            if route_choice.name == expected_route_name: # Adjusted this line
-                correct += 1
-        accuracy = correct / len(self.test_data)
-        return accuracy
+    def __init__(self, route_layer: RouteLayer):
+        self.route_layer = route_layer
 
-    def grid_search(self):
+    def random_score_threshold_search(
+        self,
+        test_data: List[Tuple[str, str]],
+        score_threshold_values: List[float],
+        num_samples: int,
+    ):
+        """Randomly sample per-route thresholds and keep the most accurate assignment.
+
+        Each of the ``num_samples`` trials picks one value from
+        ``score_threshold_values`` per route, applies it, and measures accuracy
+        on ``test_data``. The thresholds in place before the search are
+        restored afterwards.
+
+        Returns ``(best_thresholds, best_accuracy)``; ``best_thresholds`` is empty
+        when no trial scores above zero.
+        """
         # Define the range of threshold values for each route
         route_names = [route.name for route in self.route_layer.routes]
-
-        best_score = 0
+        # Remember the current thresholds so the trial values can be undone afterwards
+        original_thresholds = [route.score_threshold for route in self.route_layer.routes]
+        best_accuracy = 0
         best_thresholds = {}
+        # Evaluate the performance for each random sample
+        for _ in tqdm(range(num_samples), desc=f"Processing {num_samples} Samples."):
+            # Generate a random threshold for each route
+            score_thresholds = {route: random.choice(score_threshold_values) for route in route_names}
 
-        # Create a list of dictionaries, each representing a possible combination of thresholds
-        threshold_combinations = [dict(zip(route_names, score_thresholds)) for score_thresholds in itertools.product(self.score_threshold_values, repeat=len(route_names))]
+            # Update the route thresholds
+            self.update_route_thresholds(score_thresholds)
 
-        print(f"Processing {len(threshold_combinations)} combinations.")
-
-        # Evaluate the performance for each combination
-        for score_thresholds in tqdm(threshold_combinations):
-            score = self.evaluate(score_thresholds)
-            if score > best_score:
-                best_score = score
+            accuracy = self.evaluate(test_data=test_data)
+            if accuracy > best_accuracy:
+                best_accuracy = accuracy
                 best_thresholds = score_thresholds
 
-        return best_thresholds, best_score
\ No newline at end of file
+        # Undo the last trial so that only an explicit update changes the layer
+        for route, threshold in zip(self.route_layer.routes, original_thresholds):
+            route.score_threshold = threshold
+        return best_thresholds, best_accuracy
+
+    def update_route_thresholds(self, score_thresholds: Optional[Dict[str, float]] = None):
+        """Assign a new score threshold to every route in the layer.
+
+        Routes named in ``score_thresholds`` take the mapped value; any other
+        route falls back to ``self.route_layer.score_threshold``. A ``None`` or
+        empty mapping leaves every threshold untouched.
+        """
+        if score_thresholds:
+            for route in self.route_layer.routes:
+                route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
+
+    def evaluate(self, test_data: List[Tuple[str, str]]) -> float:
+        """Return the fraction of test inputs routed to their expected route.
+
+        Each item of ``test_data`` is an ``(input_text, expected_route_name)``
+        pair. An empty ``test_data`` yields ``0.0`` rather than dividing by zero.
+        """
+        if not test_data:
+            return 0.0
+        correct = sum(
+            self.route_layer(input_text).name == expected_route_name
+            for input_text, expected_route_name in test_data
+        )
+        return correct / len(test_data)