From 1fc4eb52b894b8b2989144b21ab08428e32a9ff9 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Tue, 2 Apr 2024 16:06:57 +0400
Subject: [PATCH] Moving repeated code into a separate method.

Created the group_scores_by_class method.
---
 semantic_router/layer.py | 30 ++++++++++++++----------------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 895e90a9..f7c534d9 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -385,14 +385,7 @@ class RouteLayer:
             return []
 
     def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]:
-        scores_by_class: Dict[str, List[float]] = {}
-        for result in query_results:
-            score = result["score"]
-            route = result["route"]
-            if route in scores_by_class:
-                scores_by_class[route].append(score)
-            else:
-                scores_by_class[route] = [score]
+        scores_by_class = self.group_scores_by_class(query_results)
 
         # Calculate total score for each class
         total_scores = {route: sum(scores) for route, scores in scores_by_class.items()}
@@ -408,14 +401,7 @@ class RouteLayer:
     def _semantic_classify_multiple_routes(
         self, query_results: List[dict], threshold: float
     ) -> List[Tuple[str, float]]:
-        scores_by_class: Dict[str, List[float]] = {}
-        for result in query_results:
-            score = result["score"]
-            route = result["route"]
-            if route in scores_by_class:
-                scores_by_class[route].append(score)
-            else:
-                scores_by_class[route] = [score]
+        scores_by_class = self.group_scores_by_class(query_results)
 
         # Filter classes based on threshold and find max score for each
         classes_above_threshold = []
@@ -425,6 +411,18 @@ class RouteLayer:
                 classes_above_threshold.append((route, max_score))
 
         return classes_above_threshold
+    
+    def group_scores_by_class(self, query_results: List[dict]) -> Dict[str, List[float]]:
+        scores_by_class: Dict[str, List[float]] = {}
+        for result in query_results:
+            score = result["score"]
+            route = result["route"]
+            if route in scores_by_class:
+                scores_by_class[route].append(score)
+            else:
+                scores_by_class[route] = [score]
+        return scores_by_class
+
 
     def _pass_threshold(self, scores: List[float], threshold: float) -> bool:
         if scores:
-- 
GitLab