Skip to content
Snippets Groups Projects
Unverified Commit 49fb8f2e authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Included Random Search and Fit

Created a random_score_threshold_search method in TestRouteSelection for finding optimum score threshold values across all routes, using a random search algorithm.

Also created a fit method in RouteLayer for running the random_score_threshold_search and for
parent 09b5111a
No related branches found
No related tags found
No related merge requests found
...@@ -12,10 +12,10 @@ from semantic_router.route import Route ...@@ -12,10 +12,10 @@ from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice from semantic_router.schema import Encoder, EncoderType, RouteChoice
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
import itertools import itertools
from tqdm import tqdm from tqdm.auto import tqdm
from typing import Dict from typing import Dict
import random
def is_valid(layer_config: str) -> bool: def is_valid(layer_config: str) -> bool:
"""Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]""" """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
...@@ -337,44 +337,76 @@ class RouteLayer: ...@@ -337,44 +337,76 @@ class RouteLayer:
config = self.to_config() config = self.to_config()
config.to_file(file_path) config.to_file(file_path)
def get_route_thresholds(self) -> Dict[str, float]:
return {route.name: route.score_threshold for route in self.routes}
def fit(
self,
test_data: List[Tuple[str, str]],
score_threshold_values: List[float]=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
num_samples: int = 20
):
test_route_selection = TestRouteSelection(route_layer=self)
# Find the best score threshold for each route
best_thresholds, best_accuracy = test_route_selection.random_score_threshold_search(
test_data=test_data,
score_threshold_values=score_threshold_values,
num_samples=num_samples
)
test_route_selection.update_route_thresholds(best_thresholds)
return best_accuracy, best_thresholds
class TestRouteSelection: class TestRouteSelection:
def __init__(self, route_layer: RouteLayer, test_data: List[Tuple[str, str]], score_threshold_values: List[float] = None):
self.route_layer = route_layer
self.test_data = test_data
self.score_threshold_values = score_threshold_values
def evaluate(self, score_thresholds: Optional[Dict[str, float]] = None) -> float: def __init__(self, route_layer: RouteLayer):
correct = 0 self.route_layer = route_layer
for input_text, expected_route_name in self.test_data:
# Set the threshold for each route based on the provided thresholds, if any
if score_thresholds:
for route in self.route_layer.routes:
route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
route_choice = self.route_layer(input_text)
if route_choice.name == expected_route_name: # Adjusted this line
correct += 1
accuracy = correct / len(self.test_data)
return accuracy
def grid_search(self): def random_score_threshold_search(
self,
test_data: List[Tuple[str, str]],
score_threshold_values: List[float],
num_samples: int,
):
# Define the range of threshold values for each route # Define the range of threshold values for each route
route_names = [route.name for route in self.route_layer.routes] route_names = [route.name for route in self.route_layer.routes]
best_accuracy = 0
best_score = 0
best_thresholds = {} best_thresholds = {}
# Evaluate the performance for each random sample
for _ in tqdm(range(num_samples), desc=f"Processing {num_samples} Samples."):
# Generate a random threshold for each route
score_thresholds = {route: random.choice(score_threshold_values) for route in route_names}
# Create a list of dictionaries, each representing a possible combination of thresholds # Update the route thresholds
threshold_combinations = [dict(zip(route_names, score_thresholds)) for score_thresholds in itertools.product(self.score_threshold_values, repeat=len(route_names))] self.update_route_thresholds(score_thresholds)
print(f"Processing {len(threshold_combinations)} combinations.") accuracy = self.evaluate(test_data=test_data)
if accuracy > best_accuracy:
# Evaluate the performance for each combination best_accuracy = accuracy
for score_thresholds in tqdm(threshold_combinations):
score = self.evaluate(score_thresholds)
if score > best_score:
best_score = score
best_thresholds = score_thresholds best_thresholds = score_thresholds
return best_thresholds, best_score return best_thresholds, best_accuracy
\ No newline at end of file
def update_route_thresholds(self, score_thresholds: Optional[Dict[str, float]] = None):
"""
Update the score thresholds for each route.
"""
if score_thresholds:
for route in self.route_layer.routes:
route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
def evaluate(self, test_data: List[Tuple[str, str]]) -> float:
"""
Evaluate the accuracy of the route selection.
"""
correct = 0
for input_text, expected_route_name in test_data:
route_choice = self.route_layer(input_text)
if route_choice.name == expected_route_name:
correct += 1
accuracy = correct / len(test_data)
return accuracy
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment