From 6494666b7ae2a67818924e0e7207fd55e157f2bc Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Thu, 25 Jan 2024 13:18:02 +0400
Subject: [PATCH] Altered fit() to use X and Y Inputs

---
 semantic_router/layer.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 29c5a349..b03bb25f 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -341,15 +341,17 @@ class RouteLayer:
         return {route.name: route.score_threshold for route in self.routes}
 
     def fit(
-            self, 
-            test_data: List[Tuple[str, str]],
-            score_threshold_values: List[float]=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
-            num_samples: int = 20
-            ):
+        self, 
+        X: List[str],
+        Y: List[str],
+        score_threshold_values: List[float] = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],
+        num_samples: int = 20
+        ):
         test_route_selection = TestRouteSelection(route_layer=self)
         # Find the best score threshold for each route
         best_thresholds, best_accuracy = test_route_selection.random_score_threshold_search(
-            test_data=test_data,
+            X=X,
+            Y=Y,
             score_threshold_values=score_threshold_values,
             num_samples=num_samples
             )
@@ -363,7 +365,8 @@ class TestRouteSelection:
 
     def random_score_threshold_search(
             self, 
-            test_data: List[Tuple[str, str]],
+            X: List[str],
+            Y: List[str],
             score_threshold_values: List[float], 
             num_samples: int,
             ):
@@ -379,7 +382,7 @@ class TestRouteSelection:
             # Update the route thresholds
             self.update_route_thresholds(score_thresholds)
 
-            accuracy = self.evaluate(test_data=test_data)
+            accuracy = self.evaluate(X=X, Y=Y)
             if accuracy > best_accuracy:
                 best_accuracy = accuracy
                 best_thresholds = score_thresholds
@@ -394,18 +397,18 @@ class TestRouteSelection:
             for route in self.route_layer.routes:
                 route.score_threshold = score_thresholds.get(route.name, self.route_layer.score_threshold)
 
-    def evaluate(self, test_data: List[Tuple[str, str]]) -> float:
+    def evaluate(self, X: List[str], Y: List[str]) -> float:
         """
         Evaluate the accuracy of the route selection.
         """
         correct = 0
-        for input_text, expected_route_name in test_data:
+        for input_text, expected_route_name in zip(X, Y):
             route_choice = self.route_layer(input_text)
             if route_choice.name == expected_route_name:
                 correct += 1
-        accuracy = correct / len(test_data)
+        accuracy = correct / len(X)
         return accuracy
-    
+        
 
 
 
-- 
GitLab