From ab9dc774eceb50f2fd102ce7fe685427460abf39 Mon Sep 17 00:00:00 2001
From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com>
Date: Tue, 27 Aug 2024 02:39:09 +0300
Subject: [PATCH] fix the local sync

---
 semantic_router/index/pinecone.py | 27 ++++++++++++++-------------
 semantic_router/layer.py          | 29 +++++++++++++++++++++--------
 2 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 70833c4a..f4d003f6 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -223,24 +223,25 @@ class PineconeIndex(BaseIndex):
         remote_routes = self.get_routes()
 
         remote_dict = {
-            route: {"utterances": set(), "function_schemas": set()}
+            route: {"utterances": set(), "function_schemas": {}}
             for route, _, _ in remote_routes
         }
 
         for route, utterance, function_schema in remote_routes:
             remote_dict[route]["utterances"].add(utterance)
-            remote_dict[route]["function_schemas"].add(function_schema)
+            remote_dict[route]["function_schemas"].update(function_schema)
 
         local_dict = {
-            route: {"utterances": set(), "function_schemas": set()}
+            route: {"utterances": set(), "function_schemas": {}}
             for route in local_route_names
         }
 
         for route, utterance, function_schema in zip(
             local_route_names, local_utterances, local_function_schemas
         ):
+            logger.info(f"function_schema: {function_schema}")
             local_dict[route]["utterances"].add(utterance)
-            local_dict[route]["function_schemas"].add(json.dumps(function_schema))
+            local_dict[route]["function_schemas"].update(function_schema)
 
         all_routes = set(remote_dict.keys()).union(local_dict.keys())
         routes_to_add = []
@@ -254,12 +255,12 @@ class PineconeIndex(BaseIndex):
             remote_utterances_set = remote_dict.get(route, {"utterances": set()})[
                 "utterances"
             ]
-            local_function_schemas_set = local_dict.get(
-                route, {"function_schemas": set()}
-            )["function_schemas"]
+            local_function_schemas_dict = local_dict.get(route, {}).get(
+                "function_schemas", {}
+            )
 
-            remote_function_schemas_set = remote_dict.get(
-                route, {"function_schemas": set()}
+            remote_function_schemas_dict = remote_dict.get(
+                route, {"function_schemas": {}}
             )["function_schemas"]
 
             if not local_utterances_set and not remote_utterances_set:
@@ -291,10 +292,10 @@ class PineconeIndex(BaseIndex):
                 layer_routes[route] = {}
                 if local_utterances_set:
                     layer_routes[route]["utterances"] = list(local_utterances_set)
-                if local_function_schemas_set:
-                    layer_routes[route]["function_schemas"] = list(
-                        local_function_schemas_set
-                    )
+                if local_function_schemas_dict:
+                    layer_routes[route][
+                        "function_schemas"
+                    ] = local_function_schemas_dict
 
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 07a39ea8..a7f6c1c0 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -226,6 +226,8 @@ class RouteLayer:
 
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
         matching_routes = [route for route in self.routes if route.name == top_class]
+        logger.info(f"matching_routes: {matching_routes}")
+        logger.info(f"self.routes: {self.routes}")
         if not matching_routes:
             logger.error(
                 f"No route found with name {top_class}. Check to see if any Routes "
@@ -516,20 +518,29 @@ class RouteLayer:
             dimensions=len(self.encoder(["dummy"])[0]),
         )
 
-        layer_routes = []
+        layer_routes: List[Route] = []
+        logger.info(f"layer_routes_dict: {layer_routes_dict}")
         for route in layer_routes_dict.keys():
-            route_data = layer_routes_dict[route]
-            logger.info(
-                f"route_data[function_schemas][0]: {route_data["function_schemas"][0]}"
-            )
-            if not route_data["function_schemas"][0]:
+            logger.info(f"route name: {route}")
+
+            route_ = layer_routes_dict[route]
+            function_schemas = route_.get("function_schemas", None)
+            if not function_schemas:
                 layer_routes.append(
                     Route(
                         name=route,
-                        utterances=route_data["utterances"],
+                        utterances=route_["utterances"],
                         function_schemas=None,
                     )
                 )
+            else:
+                layer_routes.append(
+                    Route(
+                        name=route,
+                        utterances=route_["utterances"],
+                        function_schemas=[function_schemas],
+                    )
+                )
 
         data_to_delete: dict = {}
         for route, utterance in routes_to_delete:
@@ -550,6 +561,8 @@ class RouteLayer:
             function_schemas=local_function_schemas,
         )
 
+        logger.info(f"layer_routes: {layer_routes}")
+
         self._set_layer_routes(layer_routes)
 
     def _extract_routes_details(
@@ -558,7 +571,7 @@ class RouteLayer:
         route_names = [route.name for route in routes for _ in route.utterances]
         utterances = [utterance for route in routes for utterance in route.utterances]
         function_schemas = [
-            route.function_schemas[0] if route.function_schemas is not None else []
+            route.function_schemas[0] if route.function_schemas is not None else {}
             for route in routes
             for _ in route.utterances
         ]
-- 
GitLab