From 77d323bad7ba520abe2b4030308f4dbf56ae3a8f Mon Sep 17 00:00:00 2001 From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:52:39 +0300 Subject: [PATCH] Develop the process to create and add the function_schema field for routes with sync="local" --- semantic_router/index/base.py | 2 +- semantic_router/index/pinecone.py | 131 ++++++++++++++++++++---------- semantic_router/layer.py | 31 ++++--- 3 files changed, 109 insertions(+), 55 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 8ef48967..73467887 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -114,7 +114,7 @@ class BaseIndex(BaseModel): local_route_names: List[str], local_utterances: List[str], dimensions: int, - local_function_schemas: List[str] | None = None, + local_function_schemas: List[Dict[str, Any]], ): """ Synchronize the local index with the remote index based on the specified mode. diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 5b88ba57..70833c4a 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -5,7 +5,7 @@ import os import time import json -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union, Tuple import numpy as np import requests @@ -213,89 +213,127 @@ class PineconeIndex(BaseIndex): local_route_names: List[str], local_utterances: List[str], dimensions: int, - local_function_schemas: List[str] | None = None, - ): + local_function_schemas: List[Dict[str, Any]], + ) -> Tuple: + if self.index is None: self.dimensions = self.dimensions or dimensions self.index = self._init_index(force_create=True) remote_routes = self.get_routes() - remote_dict: dict = {route: set() for route, _ in remote_routes} - for route, utterance in remote_routes: - remote_dict[route].add(utterance) + remote_dict = { + route: {"utterances": set(), "function_schemas": set()} + for route, _, _ in remote_routes + } + + for route, utterance, function_schema in remote_routes: + remote_dict[route]["utterances"].add(utterance) + remote_dict[route]["function_schemas"].add(function_schema) - local_dict: dict = {route: set() for route in local_route_names} - for route, utterance in zip(local_route_names, local_utterances): - local_dict[route].add(utterance) + local_dict = { + route: {"utterances": set(), "function_schemas": set()} + for route in local_route_names + } - all_routes = set(remote_dict.keys()).union(local_dict.keys()) + for route, utterance, function_schema in zip( + local_route_names, local_utterances, local_function_schemas + ): + local_dict[route]["utterances"].add(utterance) + local_dict[route]["function_schemas"].add(json.dumps(function_schema)) + all_routes = set(remote_dict.keys()).union(local_dict.keys()) routes_to_add = [] routes_to_delete = [] layer_routes = {} for route in all_routes: - local_utterances = local_dict.get(route, set()) - remote_utterances = remote_dict.get(route, set()) + local_utterances_set = local_dict.get(route, {"utterances": set()})[ + "utterances" + ] + remote_utterances_set = remote_dict.get(route, {"utterances": set()})[ + "utterances" + ] + local_function_schemas_set = local_dict.get( + route, {"function_schemas": set()} + )["function_schemas"] - if not local_utterances and not remote_utterances: + remote_function_schemas_set = remote_dict.get( + route, {"function_schemas": set()} + )["function_schemas"] + + if not local_utterances_set and not remote_utterances_set: continue + utterances_to_include: set = set() + if self.sync == "error": - if local_utterances != remote_utterances: + if local_utterances_set != remote_utterances_set: raise ValueError( f"Synchronization error: Differences found in route '{route}'" ) - utterances_to_include: set = set() - if local_utterances: - layer_routes[route] = list(local_utterances) + if local_utterances_set: + layer_routes[route] = {"utterances": list(local_utterances_set)} + elif self.sync == "remote": - utterances_to_include = set() - if remote_utterances: - layer_routes[route] = list(remote_utterances) + if remote_utterances_set: + layer_routes[route] = {"utterances": list(remote_utterances_set)} + elif self.sync == "local": - utterances_to_include = local_utterances - remote_utterances + utterances_to_include = local_utterances_set - remote_utterances_set routes_to_delete.extend( [ (route, utterance) - for utterance in remote_utterances - if utterance not in local_utterances + for utterance in remote_utterances_set + if utterance not in local_utterances_set ] ) - if local_utterances: - layer_routes[route] = list(local_utterances) + layer_routes[route] = {} + if local_utterances_set: + layer_routes[route]["utterances"] = list(local_utterances_set) + if local_function_schemas_set: + layer_routes[route]["function_schemas"] = list( + local_function_schemas_set + ) + elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: utterances_to_include = set(local_utterances) if local_utterances: - layer_routes[route] = list(local_utterances) + layer_routes[route] = {"utterances": list(local_utterances)} else: - utterances_to_include = set() - if remote_utterances: - layer_routes[route] = list(remote_utterances) + if remote_utterances_set: + layer_routes[route] = { + "utterances": list(remote_utterances_set) + } + elif self.sync == "merge-force-local": if route in local_dict: - utterances_to_include = local_utterances - remote_utterances + utterances_to_include = local_utterances_set - remote_utterances_set routes_to_delete.extend( [ (route, utterance) - for utterance in remote_utterances - if utterance not in local_utterances + for utterance in remote_utterances_set + if utterance not in local_utterances_set ] ) - if local_utterances: - layer_routes[route] = local_utterances + if local_utterances_set: + layer_routes[route] = {"utterances": list(local_utterances_set)} else: - utterances_to_include = set() - if remote_utterances: - layer_routes[route] = list(remote_utterances) + if remote_utterances_set: + layer_routes[route] = { + "utterances": list(remote_utterances_set) + } + elif self.sync == "merge": - utterances_to_include = local_utterances - remote_utterances - if local_utterances or remote_utterances: - layer_routes[route] = list( - remote_utterances.union(local_utterances) - ) + utterances_to_include = local_utterances_set - remote_utterances_set + if local_utterances_set or remote_utterances_set: + layer_routes[route] = { + "utterances": list( + remote_utterances_set.union(local_utterances_set) + ) + } + else: raise ValueError("Invalid sync mode specified") @@ -437,7 +475,14 @@ class PineconeIndex(BaseIndex): """ # Get all records _, metadata = self._get_all(include_metadata=True) - route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata] + route_tuples = [ + ( + route_objects["sr_route"], + route_objects["sr_utterance"], + route_objects["function_schemas"], + ) + for route_objects in metadata + ] return route_tuples def delete(self, route_name: str): diff --git a/semantic_router/layer.py b/semantic_router/layer.py index a77c5844..07a39ea8 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -217,13 +217,13 @@ class RouteLayer: if route.score_threshold is None: route.score_threshold = self.score_threshold - if self.routes: - self._add_routes(routes=self.routes) - # if routes list has been passed, we initialize index now if self.index.sync: self._add_and_sync_routes(routes=self.routes) + if self.routes: + self._add_routes(routes=self.routes) + def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] if not matching_routes: @@ -516,13 +516,20 @@ class RouteLayer: dimensions=len(self.encoder(["dummy"])[0]), ) - layer_routes = [ - Route( - name=route, - utterances=layer_routes_dict[route], + layer_routes = [] + for route in layer_routes_dict.keys(): + route_data = layer_routes_dict[route] + logger.info( + f"route_data[function_schemas][0]: {route_data["function_schemas"][0]}" ) - for route in layer_routes_dict.keys() - ] + if not route_data["function_schemas"][0]: + layer_routes.append( + Route( + name=route, + utterances=route_data["utterances"], + function_schemas=None, + ) + ) data_to_delete: dict = {} for route, utterance in routes_to_delete: @@ -545,11 +552,13 @@ class RouteLayer: self._set_layer_routes(layer_routes) - def _extract_routes_details(self, routes: List[Route]) -> Tuple: + def _extract_routes_details( + self, routes: List[Route] + ) -> Tuple[list[str], list[str], List[Dict[str, Any]]]: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] function_schemas = [ - route.function_schemas if route.function_schemas is not None else "" + route.function_schemas[0] if route.function_schemas is not None else [] for route in routes for _ in route.utterances ] -- GitLab