From 47056f60234551b6c23f4f57fa4f5634146d6e7c Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Thu, 14 Nov 2024 15:54:41 +0100 Subject: [PATCH] chore: lint --- semantic_router/index/base.py | 60 ++-- semantic_router/index/local.py | 24 +- semantic_router/index/pinecone.py | 504 +++++++++++++++--------------- semantic_router/index/qdrant.py | 40 ++- semantic_router/layer.py | 79 +++-- semantic_router/schema.py | 193 ++++++------ tests/unit/test_sync.py | 134 ++++---- 7 files changed, 541 insertions(+), 493 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index be3f6592..d8b7fc5f 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -58,21 +58,21 @@ class BaseIndex(BaseModel): :return: A list of Route objects. :rtype: List[Route] """ - route_tuples = self.get_utterances() + utterances = self.get_utterances() routes_dict: Dict[str, Route] = {} # first create a dictionary of route names to Route objects - for route_name, utterance, function_schema, metadata in route_tuples: + for utt in utterances: # if the route is not in the dictionary, add it - if route_name not in routes_dict: - routes_dict[route_name] = Route( - name=route_name, - utterances=[utterance], - function_schemas=function_schema, - metadata=metadata, + if utt.route not in routes_dict: + routes_dict[utt.route] = Route( + name=utt.route, + utterances=[utt.utterance], + function_schemas=utt.function_schemas, + metadata=utt.metadata, ) else: # otherwise, add the utterance to the route - routes_dict[route_name].utterances.append(utterance) + routes_dict[utt.route].utterances.append(utt.utterance) # then create a list of routes from the dictionary routes: List[Route] = [] for route_name, route in routes_dict.items(): @@ -166,27 +166,27 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") - def _sync_index( - self, - local_route_names: List[str], - local_utterances: List[str], - local_function_schemas: List[Dict[str, Any]], - local_metadata: List[Dict[str, Any]], - dimensions: int, - ): - """ - Synchronize the local index with the remote index based on the specified mode. - Modes: - - "error": Raise an error if local and remote are not synchronized. - - "remote": Take remote as the source of truth and update local to align. - - "local": Take local as the source of truth and update remote to align. - - "merge-force-remote": Merge both local and remote taking only remote routes features when a route with same route name is present both locally and remotely. - - "merge-force-local": Merge both local and remote taking only local routes features when a route with same route name is present both locally and remotely. - - "merge": Merge both local and remote, merging also local and remote features when a route with same route name is present both locally and remotely. - - This method should be implemented by subclasses. - """ - raise NotImplementedError("This method should be implemented by subclasses.") + # def _sync_index( + # self, + # local_route_names: List[str], + # local_utterances: List[str], + # local_function_schemas: List[Dict[str, Any]], + # local_metadata: List[Dict[str, Any]], + # dimensions: int, + # ): + # """ + # Synchronize the local index with the remote index based on the specified mode. + # Modes: + # - "error": Raise an error if local and remote are not synchronized. + # - "remote": Take remote as the source of truth and update local to align. + # - "local": Take local as the source of truth and update remote to align. + # - "merge-force-remote": Merge both local and remote taking only remote routes features when a route with same route name is present both locally and remotely. + # - "merge-force-local": Merge both local and remote taking only local routes features when a route with same route name is present both locally and remotely. + # - "merge": Merge both local and remote, merging also local and remote features when a route with same route name is present both locally and remotely. + + # This method should be implemented by subclasses. + # """ + # raise NotImplementedError("This method should be implemented by subclasses.") def _read_hash(self) -> ConfigParameter: """ diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index faf24084..476c1dd9 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -50,16 +50,16 @@ class LocalIndex(BaseIndex): if self.sync is not None: logger.warning("Sync remove is not implemented for LocalIndex.") - def _sync_index( - self, - local_route_names: List[str], - local_utterances: List[str], - local_function_schemas: List[Dict[str, Any]], - local_metadata: List[Dict[str, Any]], - dimensions: int, - ): - if self.sync is not None: - logger.error("Sync remove is not implemented for LocalIndex.") + # def _sync_index( + # self, + # local_route_names: List[str], + # local_utterances: List[str], + # local_function_schemas: List[Dict[str, Any]], + # local_metadata: List[Dict[str, Any]], + # dimensions: int, + # ): + # if self.sync is not None: + # logger.error("Sync remove is not implemented for LocalIndex.") def get_utterances(self) -> List[Utterance]: """ @@ -70,9 +70,7 @@ class LocalIndex(BaseIndex): """ if self.routes is None or self.utterances is None: return [] - return [ - Utterance.from_tuple(x) for x in zip(self.routes, self.utterances) - ] + return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)] def describe(self) -> Dict: return { diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 85efd809..d899baf1 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -219,253 +219,253 @@ class PineconeIndex(BaseIndex): self.host = index_stats["host"] if index_stats else None # TODO: deprecate? - def _format_routes_dict_for_sync( - self, - local_route_names: List[str], - local_utterances_list: List[str], - local_function_schemas_list: List[Dict[str, Any]], - local_metadata_list: List[Dict[str, Any]], - remote_routes: List[Tuple], - ) -> Tuple[Dict, Dict]: - remote_dict: Dict[str, Dict[str, Any]] = { - route: { - "utterances": set(), - "function_schemas": function_schemas, - "metadata": metadata, - } - for route, utterance, function_schemas, metadata in remote_routes - } - for route, utterance, function_schemas, metadata in remote_routes: - remote_dict[route]["utterances"].add(utterance) - - local_dict: Dict[str, Dict[str, Any]] = {} - for route, utterance, function_schemas, metadata in zip( - local_route_names, - local_utterances_list, - local_function_schemas_list, - local_metadata_list, - ): - if route not in local_dict: - local_dict[route] = { - "utterances": set(), - "function_schemas": function_schemas, - "metadata": metadata, - } - local_dict[route]["utterances"].add(utterance) - local_dict[route]["function_schemas"] = function_schemas - local_dict[route]["metadata"] = metadata - - return local_dict, remote_dict - - def _sync_index( - self, - local_route_names: List[str], - local_utterances_list: List[str], - local_function_schemas_list: List[Dict[str, Any]], - local_metadata_list: List[Dict[str, Any]], - dimensions: int, - ) -> Tuple[List, List, Dict]: - if self.index is None: - self.dimensions = self.dimensions or dimensions - self.index = self._init_index(force_create=True) - - remote_routes = self.get_utterances() - - local_dict, remote_dict = self._format_routes_dict_for_sync( - local_route_names, - local_utterances_list, - local_function_schemas_list, - local_metadata_list, - remote_routes, - ) - - all_routes = set(remote_dict.keys()).union(local_dict.keys()) - - routes_to_add = [] - routes_to_delete = [] - layer_routes = {} - - for route in all_routes: - local_utterances = local_dict.get(route, {}).get("utterances", set()) - remote_utterances = remote_dict.get(route, {}).get("utterances", set()) - local_function_schemas = ( - local_dict.get(route, {}).get("function_schemas", {}) or {} - ) - remote_function_schemas = ( - remote_dict.get(route, {}).get("function_schemas", {}) or {} - ) - local_metadata = local_dict.get(route, {}).get("metadata", {}) - remote_metadata = remote_dict.get(route, {}).get("metadata", {}) - - utterances_to_include = set() - - metadata_changed = local_metadata != remote_metadata - function_schema_changed = local_function_schemas != remote_function_schemas - - if self.sync == "error": - if ( - local_utterances != remote_utterances - or local_function_schemas != remote_function_schemas - or local_metadata != remote_metadata - ): - raise ValueError( - f"Synchronization error: Differences found in route '{route}'" - ) - - if local_utterances: - layer_routes[route] = { - "utterances": list(local_utterances), - "function_schemas": ( - local_function_schemas if local_function_schemas else None - ), - "metadata": local_metadata, - } - - elif self.sync == "remote": - if remote_utterances: - layer_routes[route] = { - "utterances": list(remote_utterances), - "function_schemas": ( - remote_function_schemas if remote_function_schemas else None - ), - "metadata": remote_metadata, - } - - elif self.sync == "local": - utterances_to_include = local_utterances - remote_utterances - routes_to_delete.extend( - [ - (route, utterance) - for utterance in remote_utterances - if utterance not in local_utterances - ] - ) - if local_utterances: - layer_routes[route] = { - "utterances": list(local_utterances), - "function_schemas": ( - local_function_schemas if local_function_schemas else None - ), - "metadata": local_metadata, - } - - elif self.sync == "merge-force-remote": - if route in local_dict and route not in remote_dict: - utterances_to_include = local_utterances - if local_utterances: - layer_routes[route] = { - "utterances": list(local_utterances), - "function_schemas": ( - local_function_schemas - if local_function_schemas - else None - ), - "metadata": local_metadata, - } - else: - if remote_utterances: - layer_routes[route] = { - "utterances": list(remote_utterances), - "function_schemas": ( - remote_function_schemas - if remote_function_schemas - else None - ), - "metadata": remote_metadata, - } - - elif self.sync == "merge-force-local": - if route in local_dict: - utterances_to_include = local_utterances - remote_utterances - routes_to_delete.extend( - [ - (route, utterance) - for utterance in remote_utterances - if utterance not in local_utterances - ] - ) - if local_utterances: - layer_routes[route] = { - "utterances": list(local_utterances), - "function_schemas": ( - local_function_schemas - if local_function_schemas - else None - ), - "metadata": local_metadata, - } - else: - if remote_utterances: - layer_routes[route] = { - "utterances": list(remote_utterances), - "function_schemas": ( - remote_function_schemas - if remote_function_schemas - else None - ), - "metadata": remote_metadata, - } - - elif self.sync == "merge": - utterances_to_include = local_utterances - remote_utterances - if local_utterances or remote_utterances: - # Here metadata are merged, with local metadata taking precedence for same keys - merged_metadata = {**remote_metadata, **local_metadata} - merged_function_schemas = { - **remote_function_schemas, - **local_function_schemas, - } - layer_routes[route] = { - "utterances": list(remote_utterances.union(local_utterances)), - "function_schemas": ( - merged_function_schemas if merged_function_schemas else None - ), - "metadata": merged_metadata, - } - - else: - raise ValueError("Invalid sync mode specified") - - # Add utterances if metadata has changed or if there are new utterances - if (metadata_changed or function_schema_changed) and self.sync in [ - "local", - "merge-force-local", - ]: - for utterance in local_utterances: - routes_to_add.append( - ( - route, - utterance, - local_function_schemas if local_function_schemas else None, - local_metadata, - ) - ) - if (metadata_changed or function_schema_changed) and self.sync == "merge": - for utterance in local_utterances: - routes_to_add.append( - ( - route, - utterance, - ( - merged_function_schemas - if merged_function_schemas - else None - ), - merged_metadata, - ) - ) - elif utterances_to_include: - for utterance in utterances_to_include: - routes_to_add.append( - ( - route, - utterance, - local_function_schemas if local_function_schemas else None, - local_metadata, - ) - ) - - return routes_to_add, routes_to_delete, layer_routes + # def _format_routes_dict_for_sync( + # self, + # local_route_names: List[str], + # local_utterances_list: List[str], + # local_function_schemas_list: List[Dict[str, Any]], + # local_metadata_list: List[Dict[str, Any]], + # remote_routes: List[Tuple], + # ) -> Tuple[Dict, Dict]: + # remote_dict: Dict[str, Dict[str, Any]] = { + # route: { + # "utterances": set(), + # "function_schemas": function_schemas, + # "metadata": metadata, + # } + # for route, utterance, function_schemas, metadata in remote_routes + # } + # for route, utterance, function_schemas, metadata in remote_routes: + # remote_dict[route]["utterances"].add(utterance) + + # local_dict: Dict[str, Dict[str, Any]] = {} + # for route, utterance, function_schemas, metadata in zip( + # local_route_names, + # local_utterances_list, + # local_function_schemas_list, + # local_metadata_list, + # ): + # if route not in local_dict: + # local_dict[route] = { + # "utterances": set(), + # "function_schemas": function_schemas, + # "metadata": metadata, + # } + # local_dict[route]["utterances"].add(utterance) + # local_dict[route]["function_schemas"] = function_schemas + # local_dict[route]["metadata"] = metadata + + # return local_dict, remote_dict + + # def _sync_index( + # self, + # local_route_names: List[str], + # local_utterances_list: List[str], + # local_function_schemas_list: List[Dict[str, Any]], + # local_metadata_list: List[Dict[str, Any]], + # dimensions: int, + # ) -> Tuple[List, List, Dict]: + # if self.index is None: + # self.dimensions = self.dimensions or dimensions + # self.index = self._init_index(force_create=True) + + # remote_routes = self.get_utterances() + + # local_dict, remote_dict = self._format_routes_dict_for_sync( + # local_route_names, + # local_utterances_list, + # local_function_schemas_list, + # local_metadata_list, + # remote_routes, + # ) + + # all_routes = set(remote_dict.keys()).union(local_dict.keys()) + + # routes_to_add = [] + # routes_to_delete = [] + # layer_routes = {} + + # for route in all_routes: + # local_utterances = local_dict.get(route, {}).get("utterances", set()) + # remote_utterances = remote_dict.get(route, {}).get("utterances", set()) + # local_function_schemas = ( + # local_dict.get(route, {}).get("function_schemas", {}) or {} + # ) + # remote_function_schemas = ( + # remote_dict.get(route, {}).get("function_schemas", {}) or {} + # ) + # local_metadata = local_dict.get(route, {}).get("metadata", {}) + # remote_metadata = remote_dict.get(route, {}).get("metadata", {}) + + # utterances_to_include = set() + + # metadata_changed = local_metadata != remote_metadata + # function_schema_changed = local_function_schemas != remote_function_schemas + + # if self.sync == "error": + # if ( + # local_utterances != remote_utterances + # or local_function_schemas != remote_function_schemas + # or local_metadata != remote_metadata + # ): + # raise ValueError( + # f"Synchronization error: Differences found in route '{route}'" + # ) + + # if local_utterances: + # layer_routes[route] = { + # "utterances": list(local_utterances), + # "function_schemas": ( + # local_function_schemas if local_function_schemas else None + # ), + # "metadata": local_metadata, + # } + + # elif self.sync == "remote": + # if remote_utterances: + # layer_routes[route] = { + # "utterances": list(remote_utterances), + # "function_schemas": ( + # remote_function_schemas if remote_function_schemas else None + # ), + # "metadata": remote_metadata, + # } + + # elif self.sync == "local": + # utterances_to_include = local_utterances - remote_utterances + # routes_to_delete.extend( + # [ + # (route, utterance) + # for utterance in remote_utterances + # if utterance not in local_utterances + # ] + # ) + # if local_utterances: + # layer_routes[route] = { + # "utterances": list(local_utterances), + # "function_schemas": ( + # local_function_schemas if local_function_schemas else None + # ), + # "metadata": local_metadata, + # } + + # elif self.sync == "merge-force-remote": + # if route in local_dict and route not in remote_dict: + # utterances_to_include = local_utterances + # if local_utterances: + # layer_routes[route] = { + # "utterances": list(local_utterances), + # "function_schemas": ( + # local_function_schemas + # if local_function_schemas + # else None + # ), + # "metadata": local_metadata, + # } + # else: + # if remote_utterances: + # layer_routes[route] = { + # "utterances": list(remote_utterances), + # "function_schemas": ( + # remote_function_schemas + # if remote_function_schemas + # else None + # ), + # "metadata": remote_metadata, + # } + + # elif self.sync == "merge-force-local": + # if route in local_dict: + # utterances_to_include = local_utterances - remote_utterances + # routes_to_delete.extend( + # [ + # (route, utterance) + # for utterance in remote_utterances + # if utterance not in local_utterances + # ] + # ) + # if local_utterances: + # layer_routes[route] = { + # "utterances": list(local_utterances), + # "function_schemas": ( + # local_function_schemas + # if local_function_schemas + # else None + # ), + # "metadata": local_metadata, + # } + # else: + # if remote_utterances: + # layer_routes[route] = { + # "utterances": list(remote_utterances), + # "function_schemas": ( + # remote_function_schemas + # if remote_function_schemas + # else None + # ), + # "metadata": remote_metadata, + # } + + # elif self.sync == "merge": + # utterances_to_include = local_utterances - remote_utterances + # if local_utterances or remote_utterances: + # # Here metadata are merged, with local metadata taking precedence for same keys + # merged_metadata = {**remote_metadata, **local_metadata} + # merged_function_schemas = { + # **remote_function_schemas, + # **local_function_schemas, + # } + # layer_routes[route] = { + # "utterances": list(remote_utterances.union(local_utterances)), + # "function_schemas": ( + # merged_function_schemas if merged_function_schemas else None + # ), + # "metadata": merged_metadata, + # } + + # else: + # raise ValueError("Invalid sync mode specified") + + # # Add utterances if metadata has changed or if there are new utterances + # if (metadata_changed or function_schema_changed) and self.sync in [ + # "local", + # "merge-force-local", + # ]: + # for utterance in local_utterances: + # routes_to_add.append( + # ( + # route, + # utterance, + # local_function_schemas if local_function_schemas else None, + # local_metadata, + # ) + # ) + # if (metadata_changed or function_schema_changed) and self.sync == "merge": + # for utterance in local_utterances: + # routes_to_add.append( + # ( + # route, + # utterance, + # ( + # merged_function_schemas + # if merged_function_schemas + # else None + # ), + # merged_metadata, + # ) + # ) + # elif utterances_to_include: + # for utterance in utterances_to_include: + # routes_to_add.append( + # ( + # route, + # utterance, + # local_function_schemas if local_function_schemas else None, + # local_metadata, + # ) + # ) + + # return routes_to_add, routes_to_delete, layer_routes def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records.""" @@ -484,8 +484,8 @@ class PineconeIndex(BaseIndex): batch_size: int = 100, ): """Add vectors to Pinecone in batches.""" - temp = '\n'.join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)]) - logger.warning("TEMP | add:\n"+temp) + temp = "\n".join([f"{x[0]}: {x[1]}" for x in zip(routes, utterances)]) + logger.warning("TEMP | add:\n" + temp) if self.index is None: self.dimensions = self.dimensions or len(embeddings[0]) self.index = self._init_index(force_create=True) @@ -508,8 +508,10 @@ class PineconeIndex(BaseIndex): self._batch_upsert(batch) def _remove_and_sync(self, routes_to_delete: dict): - temp = '\n'.join([f"{route}: {utterances}" for route, utterances in routes_to_delete.items()]) - logger.warning("TEMP | _remove_and_sync:\n"+temp) + temp = "\n".join( + [f"{route}: {utterances}" for route, utterances in routes_to_delete.items()] + ) + logger.warning("TEMP | _remove_and_sync:\n" + temp) for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) ids_to_delete = [ diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index cfc492f6..2c28aaaf 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -4,7 +4,7 @@ import numpy as np from pydantic.v1 import Field from semantic_router.index.base import BaseIndex -from semantic_router.schema import ConfigParameter, Metric +from semantic_router.schema import ConfigParameter, Metric, Utterance from semantic_router.utils.logger import logger DEFAULT_COLLECTION_NAME = "semantic-router-index" @@ -163,16 +163,16 @@ class QdrantIndex(BaseIndex): if self.sync is not None: logger.error("Sync remove is not implemented for QdrantIndex.") - def _sync_index( - self, - local_route_names: List[str], - local_utterances_list: List[str], - local_function_schemas: List[Dict[str, Any]], - local_metadata_list: List[Dict[str, Any]], - dimensions: int, - ): - if self.sync is not None: - logger.error("Sync remove is not implemented for QdrantIndex.") + # def _sync_index( + # self, + # local_route_names: List[str], + # local_utterances_list: List[str], + # local_function_schemas: List[Dict[str, Any]], + # local_metadata_list: List[Dict[str, Any]], + # dimensions: int, + # ): + # if self.sync is not None: + # logger.error("Sync remove is not implemented for QdrantIndex.") def add( self, @@ -199,7 +199,7 @@ class QdrantIndex(BaseIndex): batch_size=batch_size, ) - def get_utterances(self) -> List[Tuple]: + def get_utterances(self) -> List[Utterance]: """ Gets a list of route and utterance objects currently stored in the index. @@ -228,21 +228,19 @@ class QdrantIndex(BaseIndex): results.extend(records) - route_tuples: List[ - Tuple[str, str, Optional[Dict[str, Any]], Dict[str, Any]] - ] = [ - ( - x.payload[SR_ROUTE_PAYLOAD_KEY], - x.payload[SR_UTTERANCE_PAYLOAD_KEY], - None, - {}, + utterances: List[Utterance] = [ + Utterance( + route=x.payload[SR_ROUTE_PAYLOAD_KEY], + utterance=x.payload[SR_UTTERANCE_PAYLOAD_KEY], + function_schemas=None, + metadata={}, ) for x in results ] except ValueError as e: logger.warning(f"Index likely empty, error: {e}") return [] - return route_tuples + return utterances def delete(self, route_name: str): from qdrant_client import models diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 64186e90..3013e340 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,4 +1,3 @@ -from difflib import Differ import importlib import json import os @@ -13,9 +12,16 @@ from tqdm.auto import tqdm from semantic_router.encoders import AutoEncoder, BaseEncoder, OpenAIEncoder from semantic_router.index.base import BaseIndex from semantic_router.index.local import LocalIndex +from semantic_router.index.pinecone import PineconeIndex from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route -from semantic_router.schema import ConfigParameter, EncoderType, RouteChoice, Utterance, UtteranceDiff +from semantic_router.schema import ( + ConfigParameter, + EncoderType, + RouteChoice, + Utterance, + UtteranceDiff, +) from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger @@ -183,7 +189,7 @@ class LayerConfig: """ remote_routes = index.get_utterances() return cls.from_tuples( - route_tuples=remote_routes, + route_tuples=[utt.to_tuple() for utt in remote_routes], encoder_type=encoder_type, encoder_name=encoder_name, ) @@ -226,14 +232,17 @@ class LayerConfig: """ utterances = [] for route in self.routes: - utterances.extend([ - Utterance( - route=route.name, - utterance=x, - function_schemas=route.function_schemas, - metadata=route.metadata - ) for x in route.utterances - ]) + utterances.extend( + [ + Utterance( + route=route.name, + utterance=x, + function_schemas=route.function_schemas, + metadata=route.metadata or {}, + ) + for x in route.utterances + ] + ) return utterances def add(self, route: Route): @@ -316,13 +325,14 @@ class RouteLayer: dims = len(self.encoder(["test"])[0]) self.index.dimensions = dims # now init index - self.index.index = self.index._init_index(force_create=True) + if isinstance(self.index, PineconeIndex): + self.index.index = self.index._init_index(force_create=True) if len(self.routes) > 0: local_utterances = self.to_config().to_utterances() remote_utterances = self.index.get_utterances() diff = UtteranceDiff.from_utterances( local_utterances=local_utterances, - remote_utterances=remote_utterances + remote_utterances=remote_utterances, ) sync_strategy = diff.get_sync_strategy(self.auto_sync) self._execute_sync_strategy(sync_strategy) @@ -480,7 +490,7 @@ class RouteLayer: ) # generate sync strategy sync_strategy = diff.to_sync_strategy() - #Â and execute + # and execute self._execute_sync_strategy(sync_strategy) return diff.to_utterance_str() @@ -494,9 +504,7 @@ class RouteLayer: if strategy["remote"]["delete"]: data_to_delete = {} # type: ignore for utt_obj in strategy["remote"]["delete"]: - data_to_delete.setdefault( - utt_obj.route, [] - ).append(utt_obj.utterance) + data_to_delete.setdefault(utt_obj.route, []).append(utt_obj.utterance) # TODO: switch to remove without sync?? self.index._remove_and_sync(data_to_delete) if strategy["remote"]["upsert"]: @@ -505,7 +513,9 @@ class RouteLayer: embeddings=self.encoder(utterances_text), routes=[utt.route for utt in strategy["remote"]["upsert"]], utterances=utterances_text, - function_schemas=[utt.function_schemas for utt in strategy["remote"]["upsert"]], + function_schemas=[ + utt.function_schemas for utt in strategy["remote"]["upsert"] # type: ignore + ], metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]], ) if strategy["local"]["delete"]: @@ -528,15 +538,15 @@ class RouteLayer: name=utt_obj.route, utterances=[utt_obj.utterance], function_schemas=utt_obj.function_schemas, - metadata=utt_obj.metadata + metadata=utt_obj.metadata, ) else: if utt_obj.utterance not in new_routes[utt_obj.route].utterances: new_routes[utt_obj.route].utterances.append(utt_obj.utterance) new_routes[utt_obj.route].function_schemas = utt_obj.function_schemas new_routes[utt_obj.route].metadata = utt_obj.metadata - temp = '\n'.join([f"{name}: {r.utterances}" for name, r in new_routes.items()]) - logger.warning("TEMP | _local_upsert:\n"+temp) + temp = "\n".join([f"{name}: {r.utterances}" for name, r in new_routes.items()]) + logger.warning("TEMP | _local_upsert:\n" + temp) self.routes = list(new_routes.values()) def _local_delete(self, utterances: List[Utterance]): @@ -546,17 +556,19 @@ class RouteLayer: :type utterances: List[Utterance] """ # create dictionary of route names to utterances - route_dict = {} + route_dict: dict[str, List[str]] = {} for utt in utterances: route_dict.setdefault(utt.route, []).append(utt.utterance) - temp = '\n'.join([f"{r}: {u}" for r, u in route_dict.items()]) - logger.warning("TEMP | _local_delete:\n"+temp) + temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()]) + logger.warning("TEMP | _local_delete:\n" + temp) # iterate over current routes and delete specific utterance if found new_routes = [] for route in self.routes: if route.name in route_dict.keys(): # drop utterances that are in route_dict deletion list - new_utterances = list(set(route.utterances) - set(route_dict[route.name])) + new_utterances = list( + set(route.utterances) - set(route_dict[route.name]) + ) if len(new_utterances) == 0: # the route is now empty, so we skip it continue @@ -567,19 +579,22 @@ class RouteLayer: utterances=new_utterances, # use existing function schemas and metadata function_schemas=route.function_schemas, - metadata=route.metadata + metadata=route.metadata, ) ) - logger.warning(f"TEMP | _local_delete OLD | {route.name}: {route.utterances}") - logger.warning(f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}") + logger.warning( + f"TEMP | _local_delete OLD | {route.name}: {route.utterances}" + ) + logger.warning( + f"TEMP | _local_delete NEW | {route.name}: {new_routes[-1].utterances}" + ) else: # the route is not in the route_dict, so we keep it as is new_routes.append(route) - temp = '\n'.join([f"{r}: {u}" for r, u in route_dict.items()]) - logger.warning("TEMP | _local_delete:\n"+temp) - - self.routes = new_routes + temp = "\n".join([f"{r}: {u}" for r, u in route_dict.items()]) + logger.warning("TEMP | _local_delete:\n" + temp) + self.routes = new_routes def _retrieve_top_route( self, vector: List[float], route_filter: Optional[List[str]] = None diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 13f3c5d5..d367e55e 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -20,9 +20,6 @@ class EncoderType(Enum): GOOGLE = "google" BEDROCK = "bedrock" - def to_list(): - return [encoder.value for encoder in EncoderType] - class EncoderInfo(BaseModel): name: str @@ -118,10 +115,15 @@ class Utterance(BaseModel): route=route, utterance=utterance, function_schemas=function_schemas, - metadata=metadata + metadata=metadata, ) def to_tuple(self): + """Convert an Utterance object to a tuple. + + :return: A tuple containing (route, utterance, function schemas, metadata). + :rtype: Tuple + """ return ( self.route, self.utterance, @@ -142,6 +144,7 @@ class SyncMode(Enum): """Synchronization modes for local (route layer) and remote (index) instances. """ + ERROR = "error" REMOTE = "remote" LOCAL = "local" @@ -149,20 +152,23 @@ class SyncMode(Enum): MERGE_FORCE_LOCAL = "merge-force-local" MERGE = "merge" - def to_list() -> List[str]: - return [mode.value for mode in SyncMode] + +SYNC_MODES = [x.value for x in SyncMode] + class UtteranceDiff(BaseModel): diff: List[Utterance] @classmethod def from_utterances( - cls, - local_utterances: List[Utterance], - remote_utterances: List[Utterance] + cls, local_utterances: List[Utterance], remote_utterances: List[Utterance] ): - local_utterances_map = {x.to_str(include_metadata=True): x for x in local_utterances} - remote_utterances_map = {x.to_str(include_metadata=True): x for x in remote_utterances} + local_utterances_map = { + x.to_str(include_metadata=True): x for x in local_utterances + } + remote_utterances_map = { + x.to_str(include_metadata=True): x for x in remote_utterances + } # sort local and remote utterances local_utterances_str = list(local_utterances_map.keys()) local_utterances_str.sort() @@ -179,7 +185,11 @@ class UtteranceDiff(BaseModel): if utterance_diff_tag == "?": # this is a new line from diff string, we can ignore continue - utterance = remote_utterances_map[utterance_str] if utterance_diff_tag == "+" else local_utterances_map[utterance_str] + utterance = ( + remote_utterances_map[utterance_str] + if utterance_diff_tag == "+" + else local_utterances_map[utterance_str] + ) utterance.diff_tag = utterance_diff_tag utterance_diffs.append(utterance) return UtteranceDiff(diff=utterance_diffs) @@ -229,12 +239,16 @@ class UtteranceDiff(BaseModel): :return: A dictionary describing the synchronization strategy. :rtype: dict """ - if sync_mode not in SyncMode.to_list(): - raise ValueError(f"sync_mode must be one of {SyncMode.to_list()}") + if sync_mode not in SYNC_MODES: + raise ValueError(f"sync_mode must be one of {SYNC_MODES}") local_only = self.get_tag("-") - local_only_mapper = {utt.route: (utt.function_schemas, utt.metadata) for utt in local_only} + local_only_mapper = { + utt.route: (utt.function_schemas, utt.metadata) for utt in local_only + } remote_only = self.get_tag("+") - remote_only_mapper = {utt.route: (utt.function_schemas, utt.metadata) for utt in remote_only} + remote_only_mapper = { + utt.route: (utt.function_schemas, utt.metadata) for utt in remote_only + } local_and_remote = self.get_tag(" ") if sync_mode == "error": if len(local_only) > 0 or len(remote_only) > 0: @@ -245,36 +259,21 @@ class UtteranceDiff(BaseModel): ) else: return { - "remote": { - "upsert": [], - "delete": [] - }, - "local": { - "upsert": [], - "delete": [] - } + "remote": {"upsert": [], "delete": []}, + "local": {"upsert": [], "delete": []}, } elif sync_mode == "local": return { "remote": { - "upsert": local_only,# + remote_updates, - "delete": remote_only + "upsert": local_only, # + remote_updates, + "delete": remote_only, }, - "local": { - "upsert": [], - "delete": [] - } + "local": {"upsert": [], "delete": []}, } elif sync_mode == "remote": return { - "remote": { - "upsert": [], - "delete": [] - }, - "local": { - "upsert": remote_only, - "delete": local_only - } + "remote": {"upsert": [], "delete": []}, + "local": {"upsert": remote_only, "delete": local_only}, } elif sync_mode == "merge-force-remote": # merge-to-local merge-join-local # PRIORITIZE LOCAL @@ -282,12 +281,17 @@ class UtteranceDiff(BaseModel): # they are in remote) local_route_names = set([utt.route for utt in local_only]) # if we see route: utterance exists in local, we do not pull it in - #Â from remote + # from remote local_route_utt_strs = set([utt.to_str() for utt in local_only]) # get remote utterances that are in local - remote_to_keep = [utt for utt in remote_only if ( - utt.route in local_route_names and utt.to_str() not in local_route_utt_strs - )] + remote_to_keep = [ + utt + for utt in remote_only + if ( + utt.route in local_route_names + and utt.to_str() not in local_route_utt_strs + ) + ] # overwrite remote routes with local metadata and function schemas logger.info(f"local_only_mapper: {local_only_mapper}") remote_to_update = [ @@ -295,11 +299,14 @@ class UtteranceDiff(BaseModel): route=utt.route, utterance=utt.utterance, metadata=local_only_mapper[utt.route][1], - function_schemas=local_only_mapper[utt.route][0] - ) for utt in remote_only if ( - utt.route in local_only_mapper and ( - utt.metadata != local_only_mapper[utt.route][1] or - utt.function_schemas != local_only_mapper[utt.route][0] + function_schemas=local_only_mapper[utt.route][0], + ) + for utt in remote_only + if ( + utt.route in local_only_mapper + and ( + utt.metadata != local_only_mapper[utt.route][1] + or utt.function_schemas != local_only_mapper[utt.route][0] ) ) ] @@ -308,64 +315,69 @@ class UtteranceDiff(BaseModel): route=utt.route, utterance=utt.utterance, metadata=local_only_mapper[utt.route][1], - function_schemas=local_only_mapper[utt.route][0] - ) for utt in remote_to_keep if utt.to_str() not in [ - x.to_str() for x in remote_to_update - ] + function_schemas=local_only_mapper[utt.route][0], + ) + for utt in remote_to_keep + if utt.to_str() not in [x.to_str() for x in remote_to_update] + ] + # get remote utterances that are NOT in local + remote_to_delete = [ + utt for utt in remote_only if utt.route not in local_route_names ] - #Â get remote utterances that are NOT in local - remote_to_delete = [utt for utt in remote_only if utt.route not in local_route_names] return { "remote": { "upsert": local_only + remote_to_update, - "delete": remote_to_delete + "delete": remote_to_delete, }, - "local": { - "upsert": remote_to_keep, - "delete": [] - } + "local": {"upsert": remote_to_keep, "delete": []}, } elif sync_mode == "merge-force-local": # merge-to-remote merge-join-remote # get set of route names that exist in remote (we keep these if # they are in local) remote_route_names = set([utt.route for utt in remote_only]) # if we see route: utterance exists in remote, we do not pull it in - #Â from local + # from local remote_route_utt_strs = set([utt.to_str() for utt in remote_only]) # get local utterances that are in remote - local_to_keep = [utt for utt in local_only if ( - utt.route in remote_route_names and utt.to_str() not in remote_route_utt_strs - )] + local_to_keep = [ + utt + for utt in local_only + if ( + utt.route in remote_route_names + and utt.to_str() not in remote_route_utt_strs + ) + ] # overwrite remote routes with local metadata and function schemas local_to_keep = [ Utterance( route=utt.route, utterance=utt.utterance, metadata=remote_only_mapper[utt.route][1], - function_schemas=remote_only_mapper[utt.route][0] - ) for utt in local_to_keep + function_schemas=remote_only_mapper[utt.route][0], + ) + for utt in local_to_keep ] # get local utterances that are NOT in remote - local_to_delete = [utt for utt in local_only if utt.route not in remote_route_names] + local_to_delete = [ + utt for utt in local_only if utt.route not in remote_route_names + ] return { - "remote": { - "upsert": local_to_keep, - "delete": [] - }, - "local": { - "upsert": remote_only, - "delete": local_to_delete - } + "remote": {"upsert": local_to_keep, "delete": []}, + "local": {"upsert": remote_only, "delete": local_to_delete}, } elif sync_mode == "merge": # overwrite remote routes with local metadata and function schemas remote_only_updated = [ - Utterance( - route=utt.route, - utterance=utt.utterance, - metadata=local_only_mapper[utt.route][1], - function_schemas=local_only_mapper[utt.route][0] - ) if utt.route in local_only_mapper else utt + ( + Utterance( + route=utt.route, + utterance=utt.utterance, + metadata=local_only_mapper[utt.route][1], + function_schemas=local_only_mapper[utt.route][0], + ) + if utt.route in local_only_mapper + else utt + ) for utt in remote_only ] # propogate same to shared routes @@ -374,25 +386,26 @@ class UtteranceDiff(BaseModel): route=utt.route, utterance=utt.utterance, metadata=local_only_mapper[utt.route][1], - function_schemas=local_only_mapper[utt.route][0] - ) for utt in local_and_remote if ( - utt.route in local_only_mapper and ( - utt.metadata != local_only_mapper[utt.route][1] or - utt.function_schemas != local_only_mapper[utt.route][0] + function_schemas=local_only_mapper[utt.route][0], + ) + for utt in local_and_remote + if ( + utt.route in local_only_mapper + and ( + utt.metadata != local_only_mapper[utt.route][1] + or utt.function_schemas != local_only_mapper[utt.route][0] ) ) ] return { "remote": { "upsert": local_only + shared_updated + remote_only_updated, - "delete": [] + "delete": [], }, - "local": { - "upsert": remote_only_updated + shared_updated, - "delete": [] - } + "local": {"upsert": remote_only_updated + shared_updated, "delete": []}, } - + else: + raise ValueError(f"sync_mode must be one of {SYNC_MODES}") class Metric(Enum): diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index cdc0745f..579e0629 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -214,8 +214,7 @@ class TestRouteLayer: ): index = init_index(index_cls, sync=None) _ = RouteLayer( - encoder=openai_encoder, routes=routes, index=index, - auto_sync="local" + encoder=openai_encoder, routes=routes, index=index, auto_sync="local" ) route_layer = RouteLayer(encoder=openai_encoder, routes=routes_2, index=index) if index_cls is PineconeIndex: @@ -228,12 +227,9 @@ class TestRouteLayer: def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls): index = init_index(index_cls) _ = RouteLayer( - encoder=openai_encoder, routes=routes, index=index, - auto_sync="local" - ) - route_layer_2 = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=index + encoder=openai_encoder, routes=routes, index=index, auto_sync="local" ) + route_layer_2 = RouteLayer(encoder=openai_encoder, routes=routes_2, index=index) if index_cls is PineconeIndex: time.sleep(PINECONE_SLEEP) # allow for index to be populated diff = route_layer_2.get_utterance_diff(include_metadata=True) @@ -254,12 +250,16 @@ class TestRouteLayer: # TEST LOCAL pinecone_index = init_index(index_cls) _ = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, + encoder=openai_encoder, + routes=routes, + index=pinecone_index, ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, - auto_sync="local" + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated assert route_layer.index.get_utterances() == [ @@ -276,13 +276,17 @@ class TestRouteLayer: # TEST REMOTE pinecone_index = init_index(index_cls) _ = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, - auto_sync="local" + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, - auto_sync="remote" + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="remote", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated assert route_layer.index.get_utterances() == [ @@ -293,66 +297,82 @@ class TestRouteLayer: @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_merge_force_remote(self, openai_encoder, routes, routes_2, index_cls): + def test_auto_sync_merge_force_remote( + self, openai_encoder, routes, routes_2, index_cls + ): if index_cls is PineconeIndex: # TEST MERGE FORCE REMOTE pinecone_index = init_index(index_cls) route_layer = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, - auto_sync="local" + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="local", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, - auto_sync="merge-force-remote" + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="merge-force-remote", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated # confirm local and remote are synced assert route_layer.is_synced() - #Â now confirm utterances are correct + # now confirm utterances are correct local_utterances = route_layer.index.get_utterances() - #Â we sort to ensure order is the same + # we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ - Utterance(route='Route 1', utterance='Hello'), - Utterance(route='Route 1', utterance='Hi'), - Utterance(route='Route 2', utterance='Au revoir'), - Utterance(route='Route 2', utterance='Bye'), - Utterance(route='Route 2', utterance='Goodbye'), - Utterance(route='Route 2', utterance='Hi') + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 1", utterance="Hi"), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), ], "The routes in the index should match the local routes" @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" ) - def test_auto_sync_merge_force_local(self, openai_encoder, routes, routes_2, index_cls): + def test_auto_sync_merge_force_local( + self, openai_encoder, routes, routes_2, index_cls + ): if index_cls is PineconeIndex: # TEST MERGE FORCE LOCAL pinecone_index = init_index(index_cls) route_layer = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, - auto_sync="local" + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="local", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, - auto_sync="merge-force-local" + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="merge-force-local", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated # confirm local and remote are synced assert route_layer.is_synced() - #Â now confirm utterances are correct + # now confirm utterances are correct local_utterances = route_layer.index.get_utterances() - #Â we sort to ensure order is the same + # we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ - Utterance(route='Route 1', utterance='Hello', metadata={'type': 'default'}), - Utterance(route='Route 1', utterance='Hi', metadata={'type': 'default'}), - Utterance(route='Route 2', utterance='Au revoir'), - Utterance(route='Route 2', utterance='Bye'), - Utterance(route='Route 2', utterance='Goodbye'), - Utterance(route='Route 2', utterance='Hi'), - Utterance(route='Route 3', utterance='Boo') + Utterance( + route="Route 1", utterance="Hello", metadata={"type": "default"} + ), + Utterance( + route="Route 1", utterance="Hi", metadata={"type": "default"} + ), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), + Utterance(route="Route 3", utterance="Boo"), ], "The routes in the index should match the local routes" @pytest.mark.skipif( @@ -363,35 +383,37 @@ class TestRouteLayer: # TEST MERGE pinecone_index = init_index(index_cls) route_layer = RouteLayer( - encoder=openai_encoder, routes=routes_2, index=pinecone_index, - auto_sync="local" + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated route_layer = RouteLayer( - encoder=openai_encoder, routes=routes, index=pinecone_index, - auto_sync="merge" + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="merge", ) time.sleep(PINECONE_SLEEP) # allow for index to be populated # confirm local and remote are synced assert route_layer.is_synced() - #Â now confirm utterances are correct + # now confirm utterances are correct local_utterances = route_layer.index.get_utterances() - #Â we sort to ensure order is the same + # we sort to ensure order is the same local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) assert local_utterances == [ Utterance( - route='Route 1', utterance='Hello', - metadata={'type': 'default'} + route="Route 1", utterance="Hello", metadata={"type": "default"} ), Utterance( - route='Route 1', utterance='Hi', - metadata={'type': 'default'} + route="Route 1", utterance="Hi", metadata={"type": "default"} ), - Utterance(route='Route 2', utterance='Au revoir'), - Utterance(route='Route 2', utterance='Bye'), - Utterance(route='Route 2', utterance='Goodbye'), - Utterance(route='Route 2', utterance='Hi'), - Utterance(route='Route 3', utterance='Boo') + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), + Utterance(route="Route 3", utterance="Boo"), ], "The routes in the index should match the local routes" # clear index -- GitLab