From 1fc4eb52b894b8b2989144b21ab08428e32a9ff9 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Tue, 2 Apr 2024 16:06:57 +0400 Subject: [PATCH] Moving repeated code into a separate method. Created the group_scores_by_class method. --- semantic_router/layer.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 895e90a9..f7c534d9 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -385,14 +385,7 @@ class RouteLayer: return [] def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]: - scores_by_class: Dict[str, List[float]] = {} - for result in query_results: - score = result["score"] - route = result["route"] - if route in scores_by_class: - scores_by_class[route].append(score) - else: - scores_by_class[route] = [score] + scores_by_class = self.group_scores_by_class(query_results) # Calculate total score for each class total_scores = {route: sum(scores) for route, scores in scores_by_class.items()} @@ -408,14 +401,7 @@ class RouteLayer: def _semantic_classify_multiple_routes( self, query_results: List[dict], threshold: float ) -> List[Tuple[str, float]]: - scores_by_class: Dict[str, List[float]] = {} - for result in query_results: - score = result["score"] - route = result["route"] - if route in scores_by_class: - scores_by_class[route].append(score) - else: - scores_by_class[route] = [score] + scores_by_class = self.group_scores_by_class(query_results) # Filter classes based on threshold and find max score for each classes_above_threshold = [] @@ -425,6 +411,18 @@ class RouteLayer: classes_above_threshold.append((route, max_score)) return classes_above_threshold + + def group_scores_by_class(self, query_results: List[dict]) -> Dict[str, List[float]]: + scores_by_class: Dict[str, List[float]] = {} + for result in query_results: + score = result["score"] + route = result["route"] + if route in scores_by_class: + scores_by_class[route].append(score) + else: + scores_by_class[route] = [score] + return scores_by_class + def _pass_threshold(self, scores: List[float], threshold: float) -> bool: if scores: -- GitLab