From ab9dc774eceb50f2fd102ce7fe685427460abf39 Mon Sep 17 00:00:00 2001 From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com> Date: Tue, 27 Aug 2024 02:39:09 +0300 Subject: [PATCH] fix the local sync --- semantic_router/index/pinecone.py | 27 ++++++++++++++------------- semantic_router/layer.py | 29 +++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 70833c4a..f4d003f6 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -223,24 +223,25 @@ class PineconeIndex(BaseIndex): remote_routes = self.get_routes() remote_dict = { - route: {"utterances": set(), "function_schemas": set()} + route: {"utterances": set(), "function_schemas": {}} for route, _, _ in remote_routes } for route, utterance, function_schema in remote_routes: remote_dict[route]["utterances"].add(utterance) - remote_dict[route]["function_schemas"].add(function_schema) + remote_dict[route]["function_schemas"].update(function_schema) local_dict = { - route: {"utterances": set(), "function_schemas": set()} + route: {"utterances": set(), "function_schemas": {}} for route in local_route_names } for route, utterance, function_schema in zip( local_route_names, local_utterances, local_function_schemas ): + logger.info(f"function_schema: {function_schema}") local_dict[route]["utterances"].add(utterance) - local_dict[route]["function_schemas"].add(json.dumps(function_schema)) + local_dict[route]["function_schemas"].update(function_schema) all_routes = set(remote_dict.keys()).union(local_dict.keys()) routes_to_add = [] @@ -254,12 +255,12 @@ class PineconeIndex(BaseIndex): remote_utterances_set = remote_dict.get(route, {"utterances": set()})[ "utterances" ] - local_function_schemas_set = local_dict.get( - route, {"function_schemas": set()} - )["function_schemas"] + local_function_schemas_dict = local_dict.get(route, {}).get( + "function_schemas", {} + ) - remote_function_schemas_set = remote_dict.get( - route, {"function_schemas": set()} + remote_function_schemas_dict = remote_dict.get( + route, {"function_schemas": {}} )["function_schemas"] if not local_utterances_set and not remote_utterances_set: @@ -291,10 +292,10 @@ class PineconeIndex(BaseIndex): layer_routes[route] = {} if local_utterances_set: layer_routes[route]["utterances"] = list(local_utterances_set) - if local_function_schemas_set: - layer_routes[route]["function_schemas"] = list( - local_function_schemas_set - ) + if local_function_schemas_dict: + layer_routes[route][ + "function_schemas" + ] = local_function_schemas_dict elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 07a39ea8..a7f6c1c0 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -226,6 +226,8 @@ class RouteLayer: def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] + logger.info(f"matching_routes: {matching_routes}") + logger.info(f"self.routes: {self.routes}") if not matching_routes: logger.error( f"No route found with name {top_class}. Check to see if any Routes " @@ -516,20 +518,29 @@ class RouteLayer: dimensions=len(self.encoder(["dummy"])[0]), ) - layer_routes = [] + layer_routes: List[Route] = [] + logger.info(f"layer_routes_dict: {layer_routes_dict}") for route in layer_routes_dict.keys(): - route_data = layer_routes_dict[route] - logger.info( - f"route_data[function_schemas][0]: {route_data["function_schemas"][0]}" - ) - if not route_data["function_schemas"][0]: + logger.info(f"route name: {route}") + + route_ = layer_routes_dict[route] + function_schemas = route_.get("function_schemas", None) + if not function_schemas: layer_routes.append( Route( name=route, - utterances=route_data["utterances"], + utterances=route_["utterances"], function_schemas=None, ) ) + else: + layer_routes.append( + Route( + name=route, + utterances=route_["utterances"], + function_schemas=[function_schemas], + ) + ) data_to_delete: dict = {} for route, utterance in routes_to_delete: @@ -550,6 +561,8 @@ class RouteLayer: function_schemas=local_function_schemas, ) + logger.info(f"layer_routes: {layer_routes}") + self._set_layer_routes(layer_routes) def _extract_routes_details( @@ -558,7 +571,7 @@ class RouteLayer: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] function_schemas = [ - route.function_schemas[0] if route.function_schemas is not None else [] + route.function_schemas[0] if route.function_schemas is not None else {} for route in routes for _ in route.utterances ] -- GitLab