diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 1ab1db73781ef9aca9a76196eb8d0b4c0bf44700..432f0af98066d9d5ec407694145a55f5767a7724 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -220,6 +220,7 @@ class RouteLayer: self, text: Optional[str] = None, vector: Optional[List[float]] = None, + simulate_static: bool = False, ) -> RouteChoice: # if no vector provided, encode text to get vector if vector is None: @@ -230,7 +231,7 @@ class RouteLayer: route, top_class_scores = self._retrieve_top_route(vector) passed = self._check_threshold(top_class_scores, route) - if passed and route is not None: + if passed and route is not None and not simulate_static: if route.function_schema and text is None: raise ValueError( "Route has a function schema, but no text was provided." @@ -248,6 +249,13 @@ class RouteLayer: else: route.llm = self.llm return route(text) + elif passed and route is not None and simulate_static: + return RouteChoice( + name=route.name, + function_call=None, + similarity_score=None, + trigger=None, + ) else: # if no route passes threshold, return empty route choice return RouteChoice() @@ -497,28 +505,13 @@ class RouteLayer: """ correct = 0 for xq, target_route in zip(Xq, y): - # We can't do route_choice = self(vector=xq) here as it won't work for dynamic routes. - route_choice = self._simulate_static_route_selection(vector=xq) + # We treate dynamic routes as static here, because when evaluating we use only vectors, and dynamic routes expect strings by default. + route_choice = self(vector=xq, simulate_static=True) if route_choice.name == target_route: correct += 1 accuracy = correct / len(Xq) return accuracy - def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice: - """ - Simulate the route selection process treating all routes as static, including threshold checking. - Dynamic routes require a query string to be passed to the __call__ method, but here we work with vectors to boost performance. - Hence, we simulate the route selection process treating all routes as static. - """ - route, scores = self._retrieve_top_route(vector) - passed = self._check_threshold(scores, route) - if passed and route is not None: - return RouteChoice( - name=route.name, function_call=None, similarity_score=None, trigger=None - ) - else: - return RouteChoice() - def _get_route_names(self) -> List[str]: return [route.name for route in self.routes]