From 32434bc777ca598161d7ccbf5c307588fe16d05d Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Thu, 25 Jan 2024 14:23:31 +0400 Subject: [PATCH] check_for_matching_routes method added To tidy up __call__ in RouteLayer. --- semantic_router/layer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 23f82da3..83234879 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -182,17 +182,21 @@ class RouteLayer: # initialize index now self._add_routes(routes=self.routes) - def __call__(self, text: str) -> RouteChoice: - results = self._query(text) - top_class, top_class_scores = self._semantic_classify(results) - # get chosen route object + def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] if not matching_routes: logger.error( f"No route found with name {top_class}. Check to see if any Routes have been defined." ) + return None + return matching_routes[0] + + def __call__(self, text: str) -> RouteChoice: + results = self._query(text) + top_class, top_class_scores = self._semantic_classify(results) + route = self.check_for_matching_routes(top_class) + if route is None: return RouteChoice() - route = matching_routes[0] threshold = ( route.score_threshold if route.score_threshold is not None -- GitLab