diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 23f82da319d478ea30d487a526b63687026ed4ab..83234879b8c3b9dc64a149102f7234ad71c94702 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -182,17 +182,21 @@ class RouteLayer: # initialize index now self._add_routes(routes=self.routes) - def __call__(self, text: str) -> RouteChoice: - results = self._query(text) - top_class, top_class_scores = self._semantic_classify(results) - # get chosen route object + def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] if not matching_routes: logger.error( f"No route found with name {top_class}. Check to see if any Routes have been defined." ) + return None + return matching_routes[0] + + def __call__(self, text: str) -> RouteChoice: + results = self._query(text) + top_class, top_class_scores = self._semantic_classify(results) + route = self.check_for_matching_routes(top_class) + if route is None: return RouteChoice() - route = matching_routes[0] threshold = ( route.score_threshold if route.score_threshold is not None