From 41d1c15b80f64175ef9f77869d3dcadf4dbc80b2 Mon Sep 17 00:00:00 2001 From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com> Date: Tue, 27 Aug 2024 15:26:11 +0300 Subject: [PATCH] optimize the '_add_routes' function --- semantic_router/layer.py | 60 +++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 6a4c4cdf..aa4a8ef6 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -482,28 +482,44 @@ class RouteLayer: self.routes.append(route) def _add_routes(self, routes: List[Route]): - if routes: - for route in routes: - logger.info(f"Adding `{route.name}` route") - embeddings = self.encoder(route.utterances) - if route.score_threshold is None: - route.score_threshold = self.score_threshold - try: - self.index.add( - embeddings=embeddings, - routes=[route.name] * len(route.utterances), - utterances=route.utterances, - function_schemas=( - route.function_schemas * len(route.utterances) - if route.function_schemas - else [{}] * len(route.utterances) - ), - ) - except Exception as e: - logger.error( - f"Failed to add route `{route.name}` to the index: {e}" - ) - raise Exception(f"Indexing error for route `{route.name}`") from e + if not routes: + logger.warning("No routes provided to add.") + return + + route_names = [] + all_embeddings = [] + all_utterances = [] + all_function_schemas = [] + + for route in routes: + logger.info(f"Adding `{route.name}` route") + route_embeddings = self.encoder(route.utterances) + + # Set score_threshold if not already set + route.score_threshold = route.score_threshold or self.score_threshold + + # Prepare data for batch insertion + route_names.extend([route.name] * len(route.utterances)) + all_embeddings.extend(route_embeddings) + all_utterances.extend(route.utterances) + all_function_schemas.extend( + route.function_schemas * len(route.utterances) + if route.function_schemas + else [{}] * len(route.utterances) + ) + + try: + # Batch insertion into the index + self.index.add( + embeddings=all_embeddings, + routes=route_names, + utterances=all_utterances, + function_schemas=all_function_schemas, + ) + except Exception as e: + logger.error(f"Failed to add routes to the index: {e}") + raise Exception("Indexing error occurred") from e + def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting -- GitLab