From f230223021de7e94fd6002b1dfbdb4ae187136e7 Mon Sep 17 00:00:00 2001
From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com>
Date: Tue, 27 Aug 2024 12:06:33 +0300
Subject: [PATCH] various optimizations for remote and local routes

---
 semantic_router/index/base.py     |   2 +-
 semantic_router/index/local.py    |   2 +-
 semantic_router/index/pinecone.py | 107 ++++++++++++++++--------------
 semantic_router/index/qdrant.py   |   2 +-
 semantic_router/layer.py          |  38 ++++-------
 5 files changed, 74 insertions(+), 77 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 9d824661..750e5d87 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -113,8 +113,8 @@ class BaseIndex(BaseModel):
         self,
         local_route_names: List[str],
         local_utterances: List[str],
-        dimensions: int,
         local_function_schemas: List[Dict[str, Any]],
+        dimensions: int,
     ):
         """
         Synchronize the local index with the remote index based on the specified mode.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 7bc12bba..be4b48dd 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -52,8 +52,8 @@ class LocalIndex(BaseIndex):
         self,
         local_route_names: List[str],
         local_utterances: List[str],
+        local_function_schemas: List[Dict[str, Any]],
         dimensions: int,
-        local_function_schemas: List[str] | None = None,
     ):
         if self.sync is not None:
             logger.error("Sync remove is not implemented for LocalIndex.")
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 31c1bf1b..246f0851 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -212,8 +212,8 @@ class PineconeIndex(BaseIndex):
         self,
         local_route_names: List[str],
         local_utterances: List[str],
-        dimensions: int,
         local_function_schemas: List[Dict[str, Any]],
+        dimensions: int,
     ) -> Tuple:
 
         if self.index is None:
@@ -222,23 +222,20 @@ class PineconeIndex(BaseIndex):
 
         remote_routes = self.get_routes()
 
-        remote_dict = {
+        remote_dict: Dict[str, Dict[str, Union[set, Dict]]] = {
             route: {"utterances": set(), "function_schemas": {}}
             for route, _, _ in remote_routes
         }
 
         for route, utterance, function_schema in remote_routes:
-            logger.info(f"function_schema remote: {function_schema}")
-            remote_dict[route]["utterances"].add(utterance)
+            remote_dict[route]["utterances"].add(utterance)  # type: ignore
 
-            if not function_schema:
-                logger.info(f"function_schema remote is empty for {route}")
-                remote_dict[route]["function_schemas"].update({})
-            else:
-                logger.info(f"function_schema remote is not empty for {route}")
-                remote_dict[route]["function_schemas"].update(function_schema)
+            logger.info(
+                f"function_schema remote is {'empty' if not function_schema else 'not empty'} for {route}"
+            )
+            remote_dict[route]["function_schemas"].update(function_schema or {})
 
-        local_dict = {
+        local_dict: Dict[str, Dict[str, Union[set, Dict]]] = {
             route: {"utterances": set(), "function_schemas": {}}
             for route in local_route_names
         }
@@ -246,14 +243,13 @@ class PineconeIndex(BaseIndex):
         for route, utterance, function_schema in zip(
             local_route_names, local_utterances, local_function_schemas
         ):
-            logger.info(f"function_schema local: {function_schema}")
-            local_dict[route]["utterances"].add(utterance)
+            local_dict[route]["utterances"].add(utterance)  # type: ignore
             local_dict[route]["function_schemas"].update(function_schema)
 
         all_routes = set(remote_dict.keys()).union(local_dict.keys())
         routes_to_add = []
         routes_to_delete = []
-        layer_routes = {}
+        layer_routes: Dict[str, Dict[str, Union[List[str], Dict]]] = {}
 
         for route in all_routes:
             local_utterances_set = local_dict.get(route, {"utterances": set()})[
@@ -276,25 +272,27 @@ class PineconeIndex(BaseIndex):
             utterances_to_include: set = set()
 
             if self.sync == "error":
-                if (local_utterances_set != remote_utterances_set) or (local_function_schemas_dict != remote_function_schemas_dict):
+                if (local_utterances_set != remote_utterances_set) or (
+                    local_function_schemas_dict != remote_function_schemas_dict
+                ):
                     raise ValueError(
                         f"Synchronization error: Differences found in route '{route}'"
                     )
                 if local_utterances_set:
                     layer_routes[route] = {"utterances": list(local_utterances_set)}
-                if local_function_schemas_dict:
-                    layer_routes[route][
-                        "function_schemas"
-                    ] = local_function_schemas_dict
+                if isinstance(local_function_schemas_dict, dict):
+                    layer_routes[route]["function_schemas"] = {
+                        **local_function_schemas_dict
+                    }
             elif self.sync == "remote":
                 if remote_utterances_set:
                     layer_routes[route] = {"utterances": list(remote_utterances_set)}
-                if remote_function_schemas_dict:
-                    layer_routes[route][
-                        "function_schemas"
-                    ] = remote_function_schemas_dict
+                if isinstance(remote_function_schemas_dict, dict):
+                    layer_routes[route]["function_schemas"] = {
+                        **remote_function_schemas_dict
+                    }
             elif self.sync == "local":
-                utterances_to_include = local_utterances_set - remote_utterances_set
+                utterances_to_include = local_utterances_set - remote_utterances_set  # type: ignore
                 routes_to_delete.extend(
                     [
                         (route, utterance)
@@ -305,32 +303,32 @@ class PineconeIndex(BaseIndex):
                 layer_routes[route] = {}
                 if local_utterances_set:
                     layer_routes[route] = {"utterances": list(local_utterances_set)}
-                if local_function_schemas_dict:
-                    layer_routes[route][
-                        "function_schemas"
-                    ] = local_function_schemas_dict
+                if isinstance(local_function_schemas_dict, dict):
+                    layer_routes[route]["function_schemas"] = {
+                        **local_function_schemas_dict
+                    }
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
                     utterances_to_include = set(local_utterances)
                     if local_utterances:
                         layer_routes[route] = {"utterances": list(local_utterances)}
-                    if local_function_schemas_dict:
-                        layer_routes[route][
-                            "function_schemas"
-                        ] = local_function_schemas_dict
+                    if isinstance(local_function_schemas_dict, dict):
+                        layer_routes[route]["function_schemas"] = {
+                            **local_function_schemas_dict
+                        }
                 else:
                     if remote_utterances_set:
                         layer_routes[route] = {
                             "utterances": list(remote_utterances_set)
                         }
-                    if remote_function_schemas_dict:
-                        layer_routes[route][
-                            "function_schemas"
-                        ] = remote_function_schemas_dict
+                    if isinstance(remote_function_schemas_dict, dict):
+                        layer_routes[route]["function_schemas"] = {
+                            **remote_function_schemas_dict
+                        }
 
             elif self.sync == "merge-force-local":
                 if route in local_dict:
-                    utterances_to_include = local_utterances_set - remote_utterances_set
+                    utterances_to_include = local_utterances_set - remote_utterances_set  # type: ignore
                     routes_to_delete.extend(
                         [
                             (route, utterance)
@@ -340,32 +338,41 @@ class PineconeIndex(BaseIndex):
                     )
                     if local_utterances_set:
                         layer_routes[route] = {"utterances": list(local_utterances_set)}
-                    if local_function_schemas_dict:
-                        layer_routes[route][
-                            "function_schemas"
-                        ] = local_function_schemas_dict
+                    if isinstance(local_function_schemas_dict, dict):
+                        layer_routes[route]["function_schemas"] = {
+                            **local_function_schemas_dict
+                        }
                 else:
                     if remote_utterances_set:
                         layer_routes[route] = {
                             "utterances": list(remote_utterances_set)
                         }
-                    if remote_function_schemas_dict:
-                        layer_routes[route][
-                            "function_schemas"
-                        ] = remote_function_schemas_dict
+                    if isinstance(remote_function_schemas_dict, dict):
+                        layer_routes[route]["function_schemas"] = {
+                            **remote_function_schemas_dict
+                        }
             elif self.sync == "merge":
-                utterances_to_include = local_utterances_set - remote_utterances_set
+                utterances_to_include = local_utterances_set - remote_utterances_set  # type: ignore
                 if local_utterances_set or remote_utterances_set:
                     layer_routes[route] = {
                         "utterances": list(
-                            remote_utterances_set.union(local_utterances_set)
+                            remote_utterances_set.union(local_utterances_set)  # type: ignore
                         )
                     }
 
                 if local_function_schemas_dict or remote_function_schemas_dict:
-                    layer_routes[route]["function_schemas"] = {
-                        **remote_function_schemas_dict,
-                        **local_function_schemas_dict,
+                    # Non-dict values count as empty; local keys override remote
+                    layer_routes[route]["function_schemas"] = {  # type: ignore
+                        **(
+                            remote_function_schemas_dict
+                            if isinstance(remote_function_schemas_dict, dict)
+                            else {}
+                        ),
+                        **(
+                            local_function_schemas_dict
+                            if isinstance(local_function_schemas_dict, dict)
+                            else {}
+                        ),
                     }
 
             else:
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 0b414cff..11a0a076 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -168,8 +168,8 @@ class QdrantIndex(BaseIndex):
         self,
         local_route_names: List[str],
         local_utterances: List[str],
+        local_function_schemas: List[Dict[str, Any]],
         dimensions: int,
-        local_function_schemas: List[str] | None = None,
     ):
         if self.sync is not None:
             logger.error("Sync remove is not implemented for QdrantIndex.")
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index fee3b665..6a4c4cdf 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -225,16 +225,16 @@ class RouteLayer:
             self._add_routes(routes=self.routes)
 
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
-        matching_routes = [route for route in self.routes if route.name == top_class]
-        logger.info(f"matching_routes: {matching_routes}")
-        logger.info(f"self.routes: {self.routes}")
-        if not matching_routes:
+        # Stop at the first route whose name matches instead of building a full list
+        matching_route = next(
+            (route for route in self.routes if route.name == top_class), None
+        )
+        if matching_route is None:
             logger.error(
                 f"No route found with name {top_class}. Check to see if any Routes "
                 "have been defined."
             )
-            return None
-        return matching_routes[0]
+        return matching_route
 
     def __call__(
         self,
@@ -496,7 +496,7 @@ class RouteLayer:
                         function_schemas=(
                             route.function_schemas * len(route.utterances)
                             if route.function_schemas
-                            else [""] * len(route.utterances)  # type: ignore
+                            else [{}] * len(route.utterances)
                         ),
                     )
                 except Exception as e:
@@ -519,27 +519,17 @@ class RouteLayer:
         )
 
         layer_routes: List[Route] = []
-        logger.info(f"layer_routes_dict: {layer_routes_dict}")
+
         for route in layer_routes_dict.keys():
-            logger.info(f"route name: {route}")
             route_dict = layer_routes_dict[route]
             function_schemas = route_dict.get("function_schemas", None)
-            if not function_schemas:
-                layer_routes.append(
-                    Route(
-                        name=route,
-                        utterances=route_dict["utterances"],
-                        function_schemas=None,
-                    )
-                )
-            else:
-                layer_routes.append(
-                    Route(
-                        name=route,
-                        utterances=route_dict["utterances"],
-                        function_schemas=[function_schemas],
-                    )
+            layer_routes.append(
+                Route(
+                    name=route,
+                    utterances=route_dict["utterances"],
+                    function_schemas=[function_schemas] if function_schemas else None,
                 )
+            )
 
         data_to_delete: dict = {}
         for route, utterance in routes_to_delete:
-- 
GitLab