Skip to content
Snippets Groups Projects
Unverified Commit 2ea1f199 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Fit and vec_evaluate can now handle dynamic routes by treating them as static.

Introduced _simulate_static_route_selection which simulates static routes, even when dynamic routes are being evaluated in _vec_evaluate().

This was necessary as dynamic routes use text inputs, but we use vector inputs when evaluating for increased performance.

Also refactored a little to avoid code duplication between _simulate_static_route_selection() and __call__().
parent 0cd4c37a
No related branches found
Tags v0.2.8
No related merge requests found
......@@ -225,23 +225,11 @@ class RouteLayer:
if vector is None:
if text is None:
raise ValueError("Either text or vector must be provided")
vector_arr = self._encode(text=text)
else:
vector_arr = np.array(vector)
# get relevant results (scores and routes)
results = self._retrieve(xq=vector_arr)
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
if route is None:
return RouteChoice()
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
passed = self._pass_threshold(top_class_scores, threshold)
vector = self._encode(text=text)
route, top_class_scores = self._retrieve_top_route(vector)
passed = self._check_threshold(top_class_scores, route)
if passed:
if route.function_schema and text is None:
raise ValueError(
......@@ -263,6 +251,29 @@ class RouteLayer:
else:
# if no route passes threshold, return empty route choice
return RouteChoice()
def _retrieve_top_route(self, vector: List[float]) -> Tuple[Optional[Route], List[float]]:
"""
Retrieve the top matching route based on the given vector.
Returns a tuple of the route (if any) and the scores of the top class.
"""
# get relevant results (scores and routes)
results = self._retrieve(xq=np.array(vector))
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
return route, top_class_scores
def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
"""
Check if the route's score passes the specified threshold.
"""
if route is None:
return False
threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold
return self._pass_threshold(scores, threshold)
def __str__(self):
return (
......@@ -481,11 +492,25 @@ class RouteLayer:
"""
correct = 0
for xq, target_route in zip(Xq, y):
route_choice = self(vector=xq)
# We can't do route_choice = self(vector=xq) here as it won't work for dynamic routes.
route_choice = self._simulate_static_route_selection(vector=xq)
if route_choice.name == target_route:
correct += 1
accuracy = correct / len(Xq)
return accuracy
def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice:
"""
Simulate the route selection process treating all routes as static, including threshold checking.
Dynamic routes require a query string to be passed to the __call__ method, but here we work with vectors to boost performance.
Hence, we simulate the route selection process treating all routes as static.
"""
route, scores = self._retrieve_top_route(vector)
passed = self._check_threshold(scores, route)
if passed:
return RouteChoice(name=route.name, function_call=None, similarity_score=None, trigger=None)
else:
return RouteChoice()
def _get_route_names(self) -> List[str]:
return [route.name for route in self.routes]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment