diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 15b3fc0d2700529a0daee273aab43df2af03c7dc..4f169588c220829bb6ac89d7f91f7649926e175a 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -11,6 +11,10 @@ from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route from semantic_router.schema import Encoder, EncoderType, RouteChoice from semantic_router.utils.logger import logger +import itertools +from tqdm import tqdm + +from typing import Dict def is_valid(layer_config: str) -> bool: @@ -183,10 +187,11 @@ class RouteLayer: def __call__(self, text: str) -> RouteChoice: results = self._query(text) top_class, top_class_scores = self._semantic_classify(results) - passed = self._pass_threshold(top_class_scores, self.score_threshold) + # get chosen route object + route = [route for route in self.routes if route.name == top_class][0] + threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold + passed = self._pass_threshold(top_class_scores, threshold) if passed: - # get chosen route object - route = [route for route in self.routes if route.name == top_class][0] if route.function_schema and not isinstance(route.llm, BaseLLM): if not self.llm: logger.warning( @@ -331,3 +336,45 @@ class RouteLayer: def to_yaml(self, file_path: str): config = self.to_config() config.to_file(file_path) + + +class TestRouteSelection: + def __init__(self, route_layer: RouteLayer, test_data: List[Tuple[str, str]], score_threshold_values: Optional[List[float]] = None): + self.route_layer = route_layer + self.test_data = test_data + self.score_threshold_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + def evaluate(self, score_thresholds: Optional[Dict[str, float]] = None) -> float: + correct = 0 + for input_text, expected_route_name in self.test_data: + # Set the threshold for each route based on the provided thresholds, if any + if score_thresholds: + for route in self.route_layer.routes: + route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold) + + route_choice = self.route_layer(input_text) + if route_choice.name == expected_route_name: # Adjusted this line + correct += 1 + accuracy = correct / len(self.test_data) + return accuracy + + def grid_search(self): + # Define the range of threshold values for each route + route_names = [route.name for route in self.route_layer.routes] + + best_score = 0 + best_thresholds = {} + + # Create a list of dictionaries, each representing a possible combination of thresholds + threshold_combinations = [dict(zip(route_names, score_thresholds)) for score_thresholds in itertools.product(self.score_threshold_values, repeat=len(route_names))] + + print(f"Processing {len(threshold_combinations)} combinations.") + + # Evaluate the performance for each combination + for score_thresholds in tqdm(threshold_combinations): + score = self.evaluate(score_thresholds) + if score > best_score: + best_score = score + best_thresholds = score_thresholds + + return best_thresholds, best_score \ No newline at end of file diff --git a/semantic_router/route.py b/semantic_router/route.py index 112f60fd4332c31d78ca2a4f5dd342a1306b3a24..ceb14a53c90473f72263625636c89b9d911400df 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -44,6 +44,7 @@ class Route(BaseModel): description: Optional[str] = None function_schema: Optional[Dict[str, Any]] = None llm: Optional[BaseLLM] = None + score_threshold: Optional[float] = None def __call__(self, query: str) -> RouteChoice: if self.function_schema: