From 77d323bad7ba520abe2b4030308f4dbf56ae3a8f Mon Sep 17 00:00:00 2001
From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com>
Date: Mon, 26 Aug 2024 23:52:39 +0300
Subject: [PATCH] Develop the process to create and add the function_schema
 field for routes with sync="local"

---
 semantic_router/index/base.py     |   2 +-
 semantic_router/index/pinecone.py | 131 ++++++++++++++++++++----------
 semantic_router/layer.py          |  31 ++++---
 3 files changed, 109 insertions(+), 55 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 8ef48967..73467887 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -114,7 +114,7 @@ class BaseIndex(BaseModel):
         local_route_names: List[str],
         local_utterances: List[str],
         dimensions: int,
-        local_function_schemas: List[str] | None = None,
+        local_function_schemas: List[Dict[str, Any]],
     ):
         """
         Synchronize the local index with the remote index based on the specified mode.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 5b88ba57..70833c4a 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -5,7 +5,7 @@ import os
 import time
 import json
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union, Tuple
 
 import numpy as np
 import requests
@@ -213,89 +213,127 @@ class PineconeIndex(BaseIndex):
         local_route_names: List[str],
         local_utterances: List[str],
         dimensions: int,
-        local_function_schemas: List[str] | None = None,
-    ):
+        local_function_schemas: List[Dict[str, Any]],
+    ) -> Tuple:
+
         if self.index is None:
             self.dimensions = self.dimensions or dimensions
             self.index = self._init_index(force_create=True)
 
         remote_routes = self.get_routes()
 
-        remote_dict: dict = {route: set() for route, _ in remote_routes}
-        for route, utterance in remote_routes:
-            remote_dict[route].add(utterance)
+        remote_dict = {
+            route: {"utterances": set(), "function_schemas": set()}
+            for route, _, _ in remote_routes
+        }
+
+        for route, utterance, function_schema in remote_routes:
+            remote_dict[route]["utterances"].add(utterance)
+            remote_dict[route]["function_schemas"].add(function_schema)
 
-        local_dict: dict = {route: set() for route in local_route_names}
-        for route, utterance in zip(local_route_names, local_utterances):
-            local_dict[route].add(utterance)
+        local_dict = {
+            route: {"utterances": set(), "function_schemas": set()}
+            for route in local_route_names
+        }
 
-        all_routes = set(remote_dict.keys()).union(local_dict.keys())
+        for route, utterance, function_schema in zip(
+            local_route_names, local_utterances, local_function_schemas
+        ):
+            local_dict[route]["utterances"].add(utterance)
+            local_dict[route]["function_schemas"].add(json.dumps(function_schema))
 
+        all_routes = set(remote_dict.keys()).union(local_dict.keys())
         routes_to_add = []
         routes_to_delete = []
         layer_routes = {}
 
         for route in all_routes:
-            local_utterances = local_dict.get(route, set())
-            remote_utterances = remote_dict.get(route, set())
+            local_utterances_set = local_dict.get(route, {"utterances": set()})[
+                "utterances"
+            ]
+            remote_utterances_set = remote_dict.get(route, {"utterances": set()})[
+                "utterances"
+            ]
+            local_function_schemas_set = local_dict.get(
+                route, {"function_schemas": set()}
+            )["function_schemas"]
 
-            if not local_utterances and not remote_utterances:
+            remote_function_schemas_set = remote_dict.get(
+                route, {"function_schemas": set()}
+            )["function_schemas"]
+
+            if not local_utterances_set and not remote_utterances_set:
                 continue
 
+            utterances_to_include: set = set()
+
             if self.sync == "error":
-                if local_utterances != remote_utterances:
+                if local_utterances_set != remote_utterances_set:
                     raise ValueError(
                         f"Synchronization error: Differences found in route '{route}'"
                     )
-                utterances_to_include: set = set()
-                if local_utterances:
-                    layer_routes[route] = list(local_utterances)
+                if local_utterances_set:
+                    layer_routes[route] = {"utterances": list(local_utterances_set)}
+
             elif self.sync == "remote":
-                utterances_to_include = set()
-                if remote_utterances:
-                    layer_routes[route] = list(remote_utterances)
+                if remote_utterances_set:
+                    layer_routes[route] = {"utterances": list(remote_utterances_set)}
+
             elif self.sync == "local":
-                utterances_to_include = local_utterances - remote_utterances
+                utterances_to_include = local_utterances_set - remote_utterances_set
                 routes_to_delete.extend(
                     [
                         (route, utterance)
-                        for utterance in remote_utterances
-                        if utterance not in local_utterances
+                        for utterance in remote_utterances_set
+                        if utterance not in local_utterances_set
                     ]
                 )
-                if local_utterances:
-                    layer_routes[route] = list(local_utterances)
+                layer_routes[route] = {}
+                if local_utterances_set:
+                    layer_routes[route]["utterances"] = list(local_utterances_set)
+                if local_function_schemas_set:
+                    layer_routes[route]["function_schemas"] = list(
+                        local_function_schemas_set
+                    )
+
             elif self.sync == "merge-force-remote":
                 if route in local_dict and route not in remote_dict:
                     utterances_to_include = set(local_utterances)
                     if local_utterances:
-                        layer_routes[route] = list(local_utterances)
+                        layer_routes[route] = {"utterances": list(local_utterances)}
                 else:
-                    utterances_to_include = set()
-                    if remote_utterances:
-                        layer_routes[route] = list(remote_utterances)
+                    if remote_utterances_set:
+                        layer_routes[route] = {
+                            "utterances": list(remote_utterances_set)
+                        }
+
             elif self.sync == "merge-force-local":
                 if route in local_dict:
-                    utterances_to_include = local_utterances - remote_utterances
+                    utterances_to_include = local_utterances_set - remote_utterances_set
                     routes_to_delete.extend(
                         [
                             (route, utterance)
-                            for utterance in remote_utterances
-                            if utterance not in local_utterances
+                            for utterance in remote_utterances_set
+                            if utterance not in local_utterances_set
                         ]
                     )
-                    if local_utterances:
-                        layer_routes[route] = local_utterances
+                    if local_utterances_set:
+                        layer_routes[route] = {"utterances": list(local_utterances_set)}
                 else:
-                    utterances_to_include = set()
-                    if remote_utterances:
-                        layer_routes[route] = list(remote_utterances)
+                    if remote_utterances_set:
+                        layer_routes[route] = {
+                            "utterances": list(remote_utterances_set)
+                        }
+
             elif self.sync == "merge":
-                utterances_to_include = local_utterances - remote_utterances
-                if local_utterances or remote_utterances:
-                    layer_routes[route] = list(
-                        remote_utterances.union(local_utterances)
-                    )
+                utterances_to_include = local_utterances_set - remote_utterances_set
+                if local_utterances_set or remote_utterances_set:
+                    layer_routes[route] = {
+                        "utterances": list(
+                            remote_utterances_set.union(local_utterances_set)
+                        )
+                    }
+
             else:
                 raise ValueError("Invalid sync mode specified")
 
@@ -437,7 +475,14 @@ class PineconeIndex(BaseIndex):
         """
         # Get all records
         _, metadata = self._get_all(include_metadata=True)
-        route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata]
+        route_tuples = [
+            (
+                route_objects["sr_route"],
+                route_objects["sr_utterance"],
+                route_objects["function_schemas"],
+            )
+            for route_objects in metadata
+        ]
         return route_tuples
 
     def delete(self, route_name: str):
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index a77c5844..07a39ea8 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -217,13 +217,13 @@ class RouteLayer:
             if route.score_threshold is None:
                 route.score_threshold = self.score_threshold
 
-        if self.routes:
-            self._add_routes(routes=self.routes)
-
         # if routes list has been passed, we initialize index now
         if self.index.sync:
             self._add_and_sync_routes(routes=self.routes)
 
+        if self.routes:
+            self._add_routes(routes=self.routes)
+
     def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
         matching_routes = [route for route in self.routes if route.name == top_class]
         if not matching_routes:
@@ -516,13 +516,20 @@ class RouteLayer:
             dimensions=len(self.encoder(["dummy"])[0]),
         )
 
-        layer_routes = [
-            Route(
-                name=route,
-                utterances=layer_routes_dict[route],
+        layer_routes = []
+        for route in layer_routes_dict.keys():
+            route_data = layer_routes_dict[route]
+            logger.info(
+                f"route_data[function_schemas][0]: {route_data["function_schemas"][0]}"
             )
-            for route in layer_routes_dict.keys()
-        ]
+            if not route_data["function_schemas"][0]:
+                layer_routes.append(
+                    Route(
+                        name=route,
+                        utterances=route_data["utterances"],
+                        function_schemas=None,
+                    )
+                )
 
         data_to_delete: dict = {}
         for route, utterance in routes_to_delete:
@@ -545,11 +552,13 @@ class RouteLayer:
 
         self._set_layer_routes(layer_routes)
 
-    def _extract_routes_details(self, routes: List[Route]) -> Tuple:
+    def _extract_routes_details(
+        self, routes: List[Route]
+    ) -> Tuple[list[str], list[str], List[Dict[str, Any]]]:
         route_names = [route.name for route in routes for _ in route.utterances]
         utterances = [utterance for route in routes for utterance in route.utterances]
         function_schemas = [
-            route.function_schemas if route.function_schemas is not None else ""
+            route.function_schemas[0] if route.function_schemas is not None else []
             for route in routes
             for _ in route.utterances
         ]
-- 
GitLab