Skip to content
Snippets Groups Projects
Unverified Commit b1b2105f authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Initial Code

parent 45e7ca0d
Branches
Tags
No related merge requests found
...@@ -11,6 +11,10 @@ from semantic_router.llms import BaseLLM, OpenAILLM ...@@ -11,6 +11,10 @@ from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice from semantic_router.schema import Encoder, EncoderType, RouteChoice
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
import itertools
from tqdm import tqdm
from typing import Dict
def is_valid(layer_config: str) -> bool: def is_valid(layer_config: str) -> bool:
...@@ -183,10 +187,11 @@ class RouteLayer: ...@@ -183,10 +187,11 @@ class RouteLayer:
def __call__(self, text: str) -> RouteChoice: def __call__(self, text: str) -> RouteChoice:
results = self._query(text) results = self._query(text)
top_class, top_class_scores = self._semantic_classify(results) top_class, top_class_scores = self._semantic_classify(results)
passed = self._pass_threshold(top_class_scores, self.score_threshold) # get chosen route object
route = [route for route in self.routes if route.name == top_class][0]
threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold
passed = self._pass_threshold(top_class_scores, threshold)
if passed: if passed:
# get chosen route object
route = [route for route in self.routes if route.name == top_class][0]
if route.function_schema and not isinstance(route.llm, BaseLLM): if route.function_schema and not isinstance(route.llm, BaseLLM):
if not self.llm: if not self.llm:
logger.warning( logger.warning(
...@@ -331,3 +336,45 @@ class RouteLayer: ...@@ -331,3 +336,45 @@ class RouteLayer:
def to_yaml(self, file_path: str): def to_yaml(self, file_path: str):
config = self.to_config() config = self.to_config()
config.to_file(file_path) config.to_file(file_path)
class TestRouteSelection:
def __init__(self, route_layer: RouteLayer, test_data: List[Tuple[str, str]], score_threshold_values: Optional[List[float]] = None):
self.route_layer = route_layer
self.test_data = test_data
self.score_threshold_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
def evaluate(self, score_thresholds: Optional[Dict[str, float]] = None) -> float:
correct = 0
for input_text, expected_route_name in self.test_data:
# Set the threshold for each route based on the provided thresholds, if any
if score_thresholds:
for route in self.route_layer.routes:
route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
route_choice = self.route_layer(input_text)
if route_choice.name == expected_route_name: # Adjusted this line
correct += 1
accuracy = correct / len(self.test_data)
return accuracy
def grid_search(self):
# Define the range of threshold values for each route
route_names = [route.name for route in self.route_layer.routes]
best_score = 0
best_thresholds = {}
# Create a list of dictionaries, each representing a possible combination of thresholds
threshold_combinations = [dict(zip(route_names, score_thresholds)) for score_thresholds in itertools.product(self.score_threshold_values, repeat=len(route_names))]
print(f"Processing {len(threshold_combinations)} combinations.")
# Evaluate the performance for each combination
for score_thresholds in tqdm(threshold_combinations):
score = self.evaluate(score_thresholds)
if score > best_score:
best_score = score
best_thresholds = score_thresholds
return best_thresholds, best_score
\ No newline at end of file
...@@ -44,6 +44,7 @@ class Route(BaseModel): ...@@ -44,6 +44,7 @@ class Route(BaseModel):
description: Optional[str] = None description: Optional[str] = None
function_schema: Optional[Dict[str, Any]] = None function_schema: Optional[Dict[str, Any]] = None
llm: Optional[BaseLLM] = None llm: Optional[BaseLLM] = None
score_threshold: Optional[float] = None
def __call__(self, query: str) -> RouteChoice: def __call__(self, query: str) -> RouteChoice:
if self.function_schema: if self.function_schema:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment