Skip to content
Snippets Groups Projects
Unverified Commit 25eae243 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent 2ea1f199
No related branches found
No related tags found
No related merge requests found
......@@ -226,7 +226,7 @@ class RouteLayer:
if text is None:
raise ValueError("Either text or vector must be provided")
vector = self._encode(text=text)
route, top_class_scores = self._retrieve_top_route(vector)
passed = self._check_threshold(top_class_scores, route)
......@@ -251,8 +251,10 @@ class RouteLayer:
else:
# if no route passes threshold, return empty route choice
return RouteChoice()
def _retrieve_top_route(self, vector: List[float]) -> Tuple[Optional[Route], List[float]]:
def _retrieve_top_route(
self, vector: List[float]
) -> Tuple[Optional[Route], List[float]]:
"""
Retrieve the top matching route based on the given vector.
Returns a tuple of the route (if any) and the scores of the top class.
......@@ -264,17 +266,20 @@ class RouteLayer:
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
return route, top_class_scores
def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
"""
Check if the route's score passes the specified threshold.
"""
if route is None:
return False
threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
return self._pass_threshold(scores, threshold)
def __str__(self):
return (
f"RouteLayer(encoder={self.encoder}, "
......@@ -498,7 +503,7 @@ class RouteLayer:
correct += 1
accuracy = correct / len(Xq)
return accuracy
def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice:
"""
Simulate the route selection process treating all routes as static, including threshold checking.
......@@ -508,7 +513,9 @@ class RouteLayer:
route, scores = self._retrieve_top_route(vector)
passed = self._check_threshold(scores, route)
if passed:
return RouteChoice(name=route.name, function_call=None, similarity_score=None, trigger=None)
return RouteChoice(
name=route.name, function_call=None, similarity_score=None, trigger=None
)
else:
return RouteChoice()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment