Skip to content
Snippets Groups Projects
Unverified Commit 95485482 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Ran Black

parent 9c73ba27
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@ from tqdm.auto import tqdm
from typing import Dict
import random
def is_valid(layer_config: str) -> bool:
"""Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
try:
......@@ -189,7 +190,11 @@ class RouteLayer:
top_class, top_class_scores = self._semantic_classify(results)
# get chosen route object
route = [route for route in self.routes if route.name == top_class][0]
threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
passed = self._pass_threshold(top_class_scores, threshold)
if passed:
if route.function_schema and not isinstance(route.llm, BaseLLM):
......@@ -343,35 +348,49 @@ class RouteLayer:
return {route.name: route.score_threshold for route in self.routes}
def fit(
self,
self,
X: List[str],
Y: List[str],
score_threshold_values: List[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
num_samples: int = 20
):
score_threshold_values: List[float] = [
0.5,
0.55,
0.6,
0.65,
0.7,
0.75,
0.8,
0.85,
0.9,
0.95,
],
num_samples: int = 20,
):
test_route_selection = TestRouteSelection(route_layer=self)
# Find the best score threshold for each route
best_thresholds, best_accuracy = test_route_selection.random_score_threshold_search(
(
best_thresholds,
best_accuracy,
) = test_route_selection.random_score_threshold_search(
X=X,
Y=Y,
score_threshold_values=score_threshold_values,
num_samples=num_samples
)
num_samples=num_samples,
)
test_route_selection.update_route_thresholds(best_thresholds)
return best_accuracy, best_thresholds
class TestRouteSelection:
class TestRouteSelection:
def __init__(self, route_layer: RouteLayer):
self.route_layer = route_layer
def random_score_threshold_search(
self,
X: List[str],
Y: List[str],
score_threshold_values: List[float],
num_samples: int,
):
self,
X: List[str],
Y: List[str],
score_threshold_values: List[float],
num_samples: int,
):
# Define the range of threshold values for each route
route_names = [route.name for route in self.route_layer.routes]
best_accuracy = 0
......@@ -379,7 +398,9 @@ class TestRouteSelection:
# Evaluate the performance for each random sample
for _ in tqdm(range(num_samples), desc=f"Processing {num_samples} Samples."):
# Generate a random threshold for each route
score_thresholds = {route: random.choice(score_threshold_values) for route in route_names}
score_thresholds = {
route: random.choice(score_threshold_values) for route in route_names
}
# Update the route thresholds
self.update_route_thresholds(score_thresholds)
......@@ -391,13 +412,17 @@ class TestRouteSelection:
return best_thresholds, best_accuracy
def update_route_thresholds(self, score_thresholds: Optional[Dict[str, float]] = None):
def update_route_thresholds(
self, score_thresholds: Optional[Dict[str, float]] = None
):
"""
Update the score thresholds for each route.
"""
if score_thresholds:
for route in self.route_layer.routes:
route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
route.score_threshold = score_thresholds.get(
route.name, self.route_layer.score_threshold
)
def evaluate(self, X: List[str], Y: List[str]) -> float:
"""
......@@ -410,8 +435,3 @@ class TestRouteSelection:
correct += 1
accuracy = correct / len(X)
return accuracy
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment