Skip to content
Snippets Groups Projects
Unverified Commit e21beba6 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Now using Route thresholds if they exist.

Updated _semantic_classify_multiple_routes to use the Route thresholds if they exist.
parent 1fc4eb52
Branches
Tags
No related merge requests found
...@@ -257,9 +257,7 @@ class RouteLayer: ...@@ -257,9 +257,7 @@ class RouteLayer:
results = self._retrieve(xq=vector_arr) results = self._retrieve(xq=vector_arr)
# decide most relevant routes # decide most relevant routes
categories_with_scores = self._semantic_classify_multiple_routes( categories_with_scores = self._semantic_classify_multiple_routes(results)
results, self.score_threshold
)
route_choices = [] route_choices = []
for category, score in categories_with_scores: for category, score in categories_with_scores:
...@@ -397,18 +395,30 @@ class RouteLayer: ...@@ -397,18 +395,30 @@ class RouteLayer:
else: else:
logger.warning("No classification found for semantic classifier.") logger.warning("No classification found for semantic classifier.")
return "", [] return "", []
def get(self, name: str) -> Optional[Route]:
for route in self.routes:
if route.name == name:
return route
logger.error(f"Route `{name}` not found")
return None
def _semantic_classify_multiple_routes( def _semantic_classify_multiple_routes(
self, query_results: List[dict], threshold: float self, query_results: List[dict]
) -> List[Tuple[str, float]]: ) -> List[Tuple[str, float]]:
scores_by_class = self.group_scores_by_class(query_results) scores_by_class = self.group_scores_by_class(query_results)
# Filter classes based on threshold and find max score for each # Filter classes based on threshold and find max score for each
classes_above_threshold = [] classes_above_threshold = []
for route, scores in scores_by_class.items(): for route_name, scores in scores_by_class.items():
if self._pass_threshold(scores, threshold): # Use the get method to find the Route object by its name
max_score = max(scores) route_obj = self.get(route_name)
classes_above_threshold.append((route, max_score)) if route_obj is not None:
# Use the Route object's threshold if it exists, otherwise use the provided threshold
_threshold = route_obj.score_threshold if route_obj.score_threshold is not None else self.score_threshold
if self._pass_threshold(scores, _threshold):
max_score = max(scores)
classes_above_threshold.append((route_name, max_score))
return classes_above_threshold return classes_above_threshold
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment