From 2ea1f19944812909cc3c34dbd45bf3e9972f501b Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 11 Mar 2024 17:03:41 +0400
Subject: [PATCH] Fit and vec_evaluate can now handle dynamic routes by treating them as static.

Introduced _simulate_static_route_selection which simulates static routes, even when dynamic routes are being evaluated in _vec_evaluate(). This was necessary as dynamic routes use text inputs, but we use vector inputs when evaluating for increased performance.

Also refactored a little to avoid code duplication between _simulate_static_route_selection() and __call__().
---
 semantic_router/layer.py | 61 ++++++++++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 3ef16206..0741756c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -225,23 +225,11 @@ class RouteLayer:
         if vector is None:
             if text is None:
                 raise ValueError("Either text or vector must be provided")
-            vector_arr = self._encode(text=text)
-        else:
-            vector_arr = np.array(vector)
-        # get relevant results (scores and routes)
-        results = self._retrieve(xq=vector_arr)
-        # decide most relevant routes
-        top_class, top_class_scores = self._semantic_classify(results)
-        # TODO do we need this check?
-        route = self.check_for_matching_routes(top_class)
-        if route is None:
-            return RouteChoice()
-        threshold = (
-            route.score_threshold
-            if route.score_threshold is not None
-            else self.score_threshold
-        )
-        passed = self._pass_threshold(top_class_scores, threshold)
+            vector = self._encode(text=text)
+
+        route, top_class_scores = self._retrieve_top_route(vector)
+        passed = self._check_threshold(top_class_scores, route)
+
         if passed:
             if route.function_schema and text is None:
                 raise ValueError(
@@ -263,6 +251,29 @@ class RouteLayer:
         else:
             # if no route passes threshold, return empty route choice
             return RouteChoice()
+
+    def _retrieve_top_route(self, vector: List[float]) -> Tuple[Optional[Route], List[float]]:
+        """
+        Retrieve the top matching route based on the given vector.
+        Returns a tuple of the route (if any) and the scores of the top class.
+        """
+        # get relevant results (scores and routes)
+        results = self._retrieve(xq=np.array(vector))
+        # decide most relevant routes
+        top_class, top_class_scores = self._semantic_classify(results)
+        # TODO do we need this check?
+        route = self.check_for_matching_routes(top_class)
+        return route, top_class_scores
+
+    def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
+        """
+        Check if the route's score passes the specified threshold.
+        """
+        if route is None:
+            return False
+        threshold = route.score_threshold if route.score_threshold is not None else self.score_threshold
+        return self._pass_threshold(scores, threshold)
+
 
     def __str__(self):
         return (
@@ -481,11 +492,25 @@ class RouteLayer:
         """
         correct = 0
         for xq, target_route in zip(Xq, y):
-            route_choice = self(vector=xq)
+            # We can't do route_choice = self(vector=xq) here as it won't work for dynamic routes.
+            route_choice = self._simulate_static_route_selection(vector=xq)
             if route_choice.name == target_route:
                 correct += 1
         accuracy = correct / len(Xq)
         return accuracy
+
+    def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice:
+        """
+        Simulate the route selection process treating all routes as static, including threshold checking.
+        Dynamic routes require a query string to be passed to the __call__ method, but here we work with vectors to boost performance.
+        Hence, we simulate the route selection process treating all routes as static.
+        """
+        route, scores = self._retrieve_top_route(vector)
+        passed = self._check_threshold(scores, route)
+        if passed:
+            return RouteChoice(name=route.name, function_call=None, similarity_score=None, trigger=None)
+        else:
+            return RouteChoice()
 
     def _get_route_names(self) -> List[str]:
         return [route.name for route in self.routes]
-- 
GitLab