From f230223021de7e94fd6002b1dfbdb4ae187136e7 Mon Sep 17 00:00:00 2001 From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com> Date: Tue, 27 Aug 2024 12:06:33 +0300 Subject: [PATCH] various optimizations for remote and local routes --- semantic_router/index/base.py | 2 +- semantic_router/index/local.py | 2 +- semantic_router/index/pinecone.py | 107 ++++++++++++++++-------------- semantic_router/index/qdrant.py | 2 +- semantic_router/layer.py | 38 ++++------- 5 files changed, 74 insertions(+), 77 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 9d824661..750e5d87 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -113,8 +113,8 @@ class BaseIndex(BaseModel): self, local_route_names: List[str], local_utterances: List[str], - dimensions: int, local_function_schemas: List[Dict[str, Any]], + dimensions: int, ): """ Synchronize the local index with the remote index based on the specified mode. diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 7bc12bba..be4b48dd 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -52,8 +52,8 @@ class LocalIndex(BaseIndex): self, local_route_names: List[str], local_utterances: List[str], + local_function_schemas: List[Dict[str, Any]], dimensions: int, - local_function_schemas: List[str] | None = None, ): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 31c1bf1b..246f0851 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -212,8 +212,8 @@ class PineconeIndex(BaseIndex): self, local_route_names: List[str], local_utterances: List[str], - dimensions: int, local_function_schemas: List[Dict[str, Any]], + dimensions: int, ) -> Tuple: if self.index is None: @@ -222,23 +222,20 @@ class PineconeIndex(BaseIndex): remote_routes = self.get_routes() - remote_dict = { + remote_dict: Dict[str, Dict[str, Union[set, Dict]]] = { route: {"utterances": set(), "function_schemas": {}} for route, _, _ in remote_routes } for route, utterance, function_schema in remote_routes: - logger.info(f"function_schema remote: {function_schema}") - remote_dict[route]["utterances"].add(utterance) + remote_dict[route]["utterances"].add(utterance) # type: ignore - if not function_schema: - logger.info(f"function_schema remote is empty for {route}") - remote_dict[route]["function_schemas"].update({}) - else: - logger.info(f"function_schema remote is not empty for {route}") - remote_dict[route]["function_schemas"].update(function_schema) + logger.info( + f"function_schema remote is {'empty' if not function_schema else 'not empty'} for {route}" + ) + remote_dict[route]["function_schemas"].update(function_schema or {}) - local_dict = { + local_dict: Dict[str, Dict[str, Union[set, Dict]]] = { route: {"utterances": set(), "function_schemas": {}} for route in local_route_names } @@ -246,14 +243,13 @@ class PineconeIndex(BaseIndex): for route, utterance, function_schema in zip( local_route_names, local_utterances, local_function_schemas ): - logger.info(f"function_schema local: {function_schema}") - local_dict[route]["utterances"].add(utterance) + local_dict[route]["utterances"].add(utterance) # type: ignore local_dict[route]["function_schemas"].update(function_schema) all_routes = set(remote_dict.keys()).union(local_dict.keys()) routes_to_add = [] routes_to_delete = [] - layer_routes = {} + layer_routes: Dict[str, Dict[str, Union[List[str], Dict]]] = {} for route in all_routes: local_utterances_set = local_dict.get(route, {"utterances": set()})[ @@ -276,25 +272,27 @@ class PineconeIndex(BaseIndex): utterances_to_include: set = set() if self.sync == "error": - if (local_utterances_set != remote_utterances_set) or (local_function_schemas_dict != remote_function_schemas_dict): + if (local_utterances_set != remote_utterances_set) or ( + local_function_schemas_dict != remote_function_schemas_dict + ): raise ValueError( f"Synchronization error: Differences found in route '{route}'" ) if local_utterances_set: layer_routes[route] = {"utterances": list(local_utterances_set)} - if local_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = local_function_schemas_dict + if isinstance(local_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **local_function_schemas_dict + } elif self.sync == "remote": if remote_utterances_set: layer_routes[route] = {"utterances": list(remote_utterances_set)} - if remote_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = remote_function_schemas_dict + if isinstance(remote_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **remote_function_schemas_dict + } elif self.sync == "local": - utterances_to_include = local_utterances_set - remote_utterances_set + utterances_to_include = local_utterances_set - remote_utterances_set # type: ignore routes_to_delete.extend( [ (route, utterance) @@ -305,32 +303,32 @@ class PineconeIndex(BaseIndex): layer_routes[route] = {} if local_utterances_set: layer_routes[route] = {"utterances": list(local_utterances_set)} - if local_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = local_function_schemas_dict + if isinstance(local_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **local_function_schemas_dict + } elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: utterances_to_include = set(local_utterances) if local_utterances: layer_routes[route] = {"utterances": list(local_utterances)} - if local_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = local_function_schemas_dict + if isinstance(local_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **local_function_schemas_dict + } else: if remote_utterances_set: layer_routes[route] = { "utterances": list(remote_utterances_set) } - if remote_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = remote_function_schemas_dict + if isinstance(remote_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **remote_function_schemas_dict + } elif self.sync == "merge-force-local": if route in local_dict: - utterances_to_include = local_utterances_set - remote_utterances_set + utterances_to_include = local_utterances_set - remote_utterances_set # type: ignore routes_to_delete.extend( [ (route, utterance) @@ -340,32 +338,41 @@ class PineconeIndex(BaseIndex): ) if local_utterances_set: layer_routes[route] = {"utterances": list(local_utterances_set)} - if local_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = local_function_schemas_dict + if isinstance(local_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **local_function_schemas_dict + } else: if remote_utterances_set: layer_routes[route] = { "utterances": list(remote_utterances_set) } - if remote_function_schemas_dict: - layer_routes[route][ - "function_schemas" - ] = remote_function_schemas_dict + if isinstance(remote_function_schemas_dict, dict): + layer_routes[route]["function_schemas"] = { + **remote_function_schemas_dict + } elif self.sync == "merge": - utterances_to_include = local_utterances_set - remote_utterances_set + utterances_to_include = local_utterances_set - remote_utterances_set # type: ignore if local_utterances_set or remote_utterances_set: layer_routes[route] = { "utterances": list( - remote_utterances_set.union(local_utterances_set) + remote_utterances_set.union(local_utterances_set) # type: ignore ) } if local_function_schemas_dict or remote_function_schemas_dict: - layer_routes[route]["function_schemas"] = { - **remote_function_schemas_dict, - **local_function_schemas_dict, + # Ensure both are dictionaries before merging + layer_routes[route]["function_schemas"] = { # type: ignore + **( + remote_function_schemas_dict + if isinstance(remote_function_schemas_dict, dict) + else {} + ), + **( + local_function_schemas_dict + if isinstance(local_function_schemas_dict, dict) + else {} + ), } else: diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 0b414cff..11a0a076 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -168,8 +168,8 @@ class QdrantIndex(BaseIndex): self, local_route_names: List[str], local_utterances: List[str], + local_function_schemas: List[Dict[str, Any]], dimensions: int, - local_function_schemas: List[str] | None = None, ): if self.sync is not None: logger.error("Sync remove is not implemented for QdrantIndex.") diff --git a/semantic_router/layer.py b/semantic_router/layer.py index fee3b665..6a4c4cdf 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -225,16 +225,16 @@ class RouteLayer: self._add_routes(routes=self.routes) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: - matching_routes = [route for route in self.routes if route.name == top_class] - logger.info(f"matching_routes: {matching_routes}") - logger.info(f"self.routes: {self.routes}") - if not matching_routes: + # Use next with a generator expression for optimization + matching_route = next( + (route for route in self.routes if route.name == top_class), None + ) + if matching_route is None: logger.error( f"No route found with name {top_class}. Check to see if any Routes " "have been defined." ) - return None - return matching_routes[0] + return matching_route def __call__( self, @@ -496,7 +496,7 @@ class RouteLayer: function_schemas=( route.function_schemas * len(route.utterances) if route.function_schemas - else [""] * len(route.utterances) # type: ignore + else [{}] * len(route.utterances) ), ) except Exception as e: @@ -519,27 +519,17 @@ class RouteLayer: ) layer_routes: List[Route] = [] - logger.info(f"layer_routes_dict: {layer_routes_dict}") + for route in layer_routes_dict.keys(): - logger.info(f"route name: {route}") route_dict = layer_routes_dict[route] function_schemas = route_dict.get("function_schemas", None) - if not function_schemas: - layer_routes.append( - Route( - name=route, - utterances=route_dict["utterances"], - function_schemas=None, - ) - ) - else: - layer_routes.append( - Route( - name=route, - utterances=route_dict["utterances"], - function_schemas=[function_schemas], - ) + layer_routes.append( + Route( + name=route, + utterances=route_dict["utterances"], + function_schemas=[function_schemas] if function_schemas else None, ) + ) data_to_delete: dict = {} for route, utterance in routes_to_delete: -- GitLab