diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 08c6ad6c77fec6bb0b2a119c8654ff720036abbc..5ddb586e75afa471090bcd92caca9d5ea5478952 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -27,6 +27,7 @@ class BaseIndex(BaseModel): routes: List[str], utterances: List[Any], function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], ): """ Add embeddings to the index. @@ -113,8 +114,9 @@ class BaseIndex(BaseModel): self, local_route_names: List[str], local_utterances: List[str], - dimensions: int, local_function_schemas: List[Dict[str, Any]], + local_metadata: List[Dict[str, Any]], + dimensions: int, ): """ Synchronize the local index with the remote index based on the specified mode. diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 68a859b5ed7049f6043558e8cb129338c641ec56..09e23ffc1b6fb7f152cd4e6aad12ca5a8df89391 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -28,6 +28,7 @@ class LocalIndex(BaseIndex): routes: List[str], utterances: List[str], function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], ): embeds = np.array(embeddings) # type: ignore routes_arr = np.array(routes) @@ -52,8 +53,9 @@ class LocalIndex(BaseIndex): self, local_route_names: List[str], local_utterances: List[str], - dimensions: int, local_function_schemas: List[Dict[str, Any]], + local_metadata: List[Dict[str, Any]], + dimensions: int, ): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 6f405841878db5563bc4e454cb51ddf20db2613e..0231f53f127d5d684df5fdcae927fb59d75349aa 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -24,6 +24,7 @@ class PineconeRecord(BaseModel): route: str utterance: str function_schema: str + metadata: Dict[str, Any] = {} # Additional metadata dictionary def __init__(self, **data): super().__init__(**data) @@ -31,16 +32,19 @@ class PineconeRecord(BaseModel): # Use SHA-256 for a more secure hash utterance_id = hashlib.sha256(self.utterance.encode()).hexdigest() self.id = f"{clean_route}#{utterance_id}" + self.metadata.update( + { + "sr_route": self.route, + "sr_utterance": self.utterance, + "sr_function_schema": self.function_schema, + } + ) def to_dict(self): return { "id": self.id, "values": self.values, - "metadata": { - "sr_route": self.route, - "sr_utterance": self.utterance, - "sr_function_schemas": self.function_schema, - }, + "metadata": self.metadata, } @@ -214,179 +218,162 @@ class PineconeIndex(BaseIndex): def _sync_index( self, local_route_names: List[str], - local_utterances: List[str], + local_utterances_list: List[str], + local_function_schemas_list: List[Dict[str, Any]], + local_metadata_list: List[Dict[str, Any]], dimensions: int, - local_function_schemas: List[Dict[str, Any]], - ) -> Tuple: - + ) -> Tuple[List,List,Dict]: if self.index is None: self.dimensions = self.dimensions or dimensions self.index = self._init_index(force_create=True) remote_routes = self.get_routes() - remote_dict: Dict[str, Dict[str, Union[set, Dict]]] = { - route: {"utterances": set(), "function_schemas": {}} - for route, _, _ in remote_routes - } - - for route, utterance, function_schema in remote_routes: - remote_dict[route]["utterances"].add(utterance) # type: ignore - - logger.info( - f"function_schema remote is {'empty' if not function_schema else 'not empty'} for {route}" - ) - remote_dict[route]["function_schemas"].update(function_schema or {}) - - local_dict: Dict[str, Dict[str, Union[set, Dict]]] = { - route: {"utterances": set(), "function_schemas": {}} - for route in local_route_names + # Create remote dictionary for storing utterances and metadata + remote_dict: Dict[str, Dict[str, Any]] = { + route: {"utterances": set(), "function_schemas": function_schemas, "metadata": metadata} + for route, utterance, function_schemas, metadata in remote_routes } + for route, utterance, function_schemas, metadata in remote_routes: + remote_dict[route]["utterances"].add(utterance) - for route, utterance, function_schema in zip( - local_route_names, local_utterances, local_function_schemas + # Create local dictionary for storing utterances and metadata + local_dict: Dict[str, Dict[str, Any]] = {} + for route, utterance, function_schemas, metadata in zip( + local_route_names, local_utterances_list, local_function_schemas_list, local_metadata_list ): - local_dict[route]["utterances"].add(utterance) # type: ignore - local_dict[route]["function_schemas"].update(function_schema) + if route not in local_dict: + local_dict[route] = {"utterances": set(), "function_schemas": function_schemas, "metadata": metadata} + local_dict[route]["utterances"].add(utterance) + local_dict[route]["function_schemas"] = function_schemas + local_dict[route]["metadata"] = metadata all_routes = set(remote_dict.keys()).union(local_dict.keys()) + routes_to_add = [] routes_to_delete = [] - layer_routes: Dict[str, Dict[str, Union[List[str], Dict]]] = {} + layer_routes = {} for route in all_routes: - local_utterances_set = local_dict.get(route, {"utterances": set()})[ - "utterances" - ] - remote_utterances_set = remote_dict.get(route, {"utterances": set()})[ - "utterances" - ] - local_function_schemas_dict = local_dict.get(route, {}).get( - "function_schemas", {} - ) - - remote_function_schemas_dict = remote_dict.get( - route, {"function_schemas": {}} - )["function_schemas"] + local_utterances = local_dict.get(route, {}).get("utterances", set()) + remote_utterances = remote_dict.get(route, {}).get("utterances", set()) + local_function_schemas = local_dict.get(route, {}).get("function_schemas", {}) + remote_function_schemas = remote_dict.get(route, {}).get("function_schemas", {}) + local_metadata = local_dict.get(route, {}).get("metadata", {}) + remote_metadata = remote_dict.get(route, {}).get("metadata", {}) - if not local_utterances_set and not remote_utterances_set: - continue + utterances_to_include = set() - utterances_to_include: set = set() + metadata_changed = local_metadata != remote_metadata + function_schema_changed = local_function_schemas != remote_function_schemas if self.sync == "error": - if (local_utterances_set != remote_utterances_set) or ( - local_function_schemas_dict != remote_function_schemas_dict + if ( + local_utterances != remote_utterances + or local_function_schemas != remote_function_schemas + or local_metadata != remote_metadata ): raise ValueError( f"Synchronization error: Differences found in route '{route}'" ) - if local_utterances_set: - layer_routes[route] = {"utterances": list(local_utterances_set)} - if isinstance(local_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **local_function_schemas_dict + + if local_utterances: + layer_routes[route] = { + "utterances": list(local_utterances), + "function_schemas": local_function_schemas, + "metadata": local_metadata, } + elif self.sync == "remote": - if remote_utterances_set: - layer_routes[route] = {"utterances": list(remote_utterances_set)} - if isinstance(remote_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **remote_function_schemas_dict + if remote_utterances: + layer_routes[route] = { + "utterances": list(remote_utterances), + "function_schemas": remote_function_schemas, + "metadata": remote_metadata, } + elif self.sync == "local": - utterances_to_include = local_utterances_set - remote_utterances_set # type: ignore + utterances_to_include = local_utterances - remote_utterances routes_to_delete.extend( [ (route, utterance) - for utterance in remote_utterances_set - if utterance not in local_utterances_set + for utterance in remote_utterances + if utterance not in local_utterances ] ) - layer_routes[route] = {} - if local_utterances_set: - layer_routes[route] = {"utterances": list(local_utterances_set)} - if isinstance(local_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **local_function_schemas_dict + if local_utterances: + layer_routes[route] = { + "utterances": list(local_utterances), + "function_schemas": local_function_schemas, + "metadata": local_metadata, } + elif self.sync == "merge-force-remote": if route in local_dict and route not in remote_dict: - utterances_to_include = ( - local_utterances_set - if isinstance(local_utterances_set, set) - else set() - ) - if local_utterances_set: - layer_routes[route] = {"utterances": list(local_utterances_set)} - if isinstance(local_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **local_function_schemas_dict + utterances_to_include = local_utterances + if local_utterances: + layer_routes[route] = { + "utterances": list(local_utterances), + "function_schemas": local_function_schemas, + "metadata": local_metadata, } else: - if remote_utterances_set: + if remote_utterances: layer_routes[route] = { - "utterances": list(remote_utterances_set) - } - if isinstance(remote_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **remote_function_schemas_dict + "utterances": list(remote_utterances), + "function_schemas": remote_function_schemas, + "metadata": remote_metadata, } elif self.sync == "merge-force-local": if route in local_dict: - utterances_to_include = local_utterances_set - remote_utterances_set # type: ignore + utterances_to_include = local_utterances - remote_utterances routes_to_delete.extend( [ (route, utterance) - for utterance in remote_utterances_set - if utterance not in local_utterances_set + for utterance in remote_utterances + if utterance not in local_utterances ] ) - if local_utterances_set: - layer_routes[route] = {"utterances": list(local_utterances_set)} - if isinstance(local_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **local_function_schemas_dict + if local_utterances: + layer_routes[route] = { + "utterances": list(local_utterances), + "function_schemas": local_function_schemas, + "metadata": local_metadata, } else: - if remote_utterances_set: + if remote_utterances: layer_routes[route] = { - "utterances": list(remote_utterances_set) - } - if isinstance(remote_function_schemas_dict, dict): - layer_routes[route]["function_schemas"] = { - **remote_function_schemas_dict + "utterances": list(remote_utterances), + "function_schemas": remote_function_schemas, + "metadata": remote_metadata, } + elif self.sync == "merge": - utterances_to_include = local_utterances_set - remote_utterances_set # type: ignore - if local_utterances_set or remote_utterances_set: + utterances_to_include = local_utterances - remote_utterances + if local_utterances or remote_utterances: + # Here metadata are merged, with local metadata taking precedence for same keys + merged_metadata = {**remote_metadata, **local_metadata} + merged_function_schemas = {**remote_function_schemas, **local_function_schemas} layer_routes[route] = { - "utterances": list( - remote_utterances_set.union(local_utterances_set) # type: ignore - ) - } - - if local_function_schemas_dict or remote_function_schemas_dict: - # Ensure both are dictionaries before merging - layer_routes[route]["function_schemas"] = { # type: ignore - **( - remote_function_schemas_dict - if isinstance(remote_function_schemas_dict, dict) - else {} - ), - **( - local_function_schemas_dict - if isinstance(local_function_schemas_dict, dict) - else {} - ), + "utterances": list(remote_utterances.union(local_utterances)), + "function_schemas": merged_function_schemas, + "metadata": merged_metadata, } else: raise ValueError("Invalid sync mode specified") - for utterance in utterances_to_include: - routes_to_add.append((route, utterance)) + # Add utterances if metadata has changed or if there are new utterances + if (metadata_changed or function_schema_changed) and self.sync in ["local", "merge-force-local"]: + for utterance in local_utterances: + routes_to_add.append((route, utterance, local_function_schemas, local_metadata)) + if (metadata_changed or function_schema_changed) and self.sync == "merge": + for utterance in local_utterances: + routes_to_add.append((route, utterance, merged_function_schemas, merged_metadata)) + elif utterances_to_include: + for utterance in utterances_to_include: + routes_to_add.append((route, utterance, local_function_schemas, local_metadata)) return routes_to_add, routes_to_delete, layer_routes @@ -403,6 +390,7 @@ class PineconeIndex(BaseIndex): routes: List[str], utterances: List[str], function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], batch_size: int = 100, ): """Add vectors to Pinecone in batches.""" @@ -416,9 +404,10 @@ class PineconeIndex(BaseIndex): route=route, utterance=utterance, function_schema=json.dumps(function_schema), + metadata=metadata, ).to_dict() - for vector, route, utterance, function_schema in zip( - embeddings, routes, utterances, function_schemas # type: ignore + for vector, route, utterance, function_schema, metadata in zip( + embeddings, routes, utterances, function_schemas, metadata_list # type: ignore ) ] @@ -484,26 +473,30 @@ class PineconeIndex(BaseIndex): def get_routes(self) -> List[Tuple]: """ - Gets a list of route and utterance objects currently stored in the index. + Gets a list of route and utterance objects currently stored in the index, including additional metadata. Returns: - List[Tuple]: A list of (route_name, utterance) objects. + List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata. """ - # Get all records _, metadata = self._get_all(include_metadata=True) route_tuples = [ ( - route_objects["sr_route"], - route_objects["sr_utterance"], + data.get("sr_route", ""), + data.get("sr_utterance", ""), ( - json.loads(route_objects["sr_function_schemas"]) - if route_objects["sr_function_schemas"] + json.loads(data["sr_function_schema"]) + if data.get("sr_function_schema", "") else {} ), + { + key: value + for key, value in data.items() + if key not in ["sr_route", "sr_utterance", "sr_function_schema"] + }, ) - for route_objects in metadata + for data in metadata ] - return route_tuples + return route_tuples # type: ignore def delete(self, route_name: str): route_vec_ids = self._get_route_ids(route_name=route_name) @@ -765,16 +758,32 @@ class PineconeIndex(BaseIndex): response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {}) ) - async def _async_get_routes(self) -> list[tuple]: + async def _async_get_routes(self) -> List[Tuple]: """ - Gets a list of route and utterance objects currently stored in the index. + Asynchronously gets a list of route and utterance objects currently stored in the index, including additional metadata. Returns: - List[Tuple]: A list of (route_name, utterance) objects. + List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata. """ _, metadata = await self._async_get_all(include_metadata=True) - route_tuples = [(x["sr_route"], x["sr_utterance"]) for x in metadata] - return route_tuples + route_info = [ + ( + data.get("sr_route", ""), + data.get("sr_utterance", ""), + ( + json.loads(data["sr_function_schema"]) + if data["sr_function_schema"] + else {} + ), + { + key: value + for key, value in data.items() + if key not in ["sr_route", "sr_utterance", "sr_function_schema"] + }, + ) + for data in metadata + ] + return route_info # type: ignore def __len__(self): return self.index.describe_index_stats()["total_vector_count"] diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 0d18381f4347f5d1dd6fcdd67c8566230a9b6499..ff63ec09c419a787033a8412a3517fa1bcb49d58 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -258,8 +258,9 @@ class PostgresIndex(BaseIndex): self, embeddings: List[List[float]], routes: List[str], - utterances: List[Any], + utterances: List[str], function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], ) -> None: """ Adds vectors to the index. diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 180bb3b38ae1b9acc40fb9b0a7aa61e8df7d46a7..b372c49cbcfe1d01ebb0b0976e342bf2559c3548 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -167,9 +167,10 @@ class QdrantIndex(BaseIndex): def _sync_index( self, local_route_names: List[str], - local_utterances: List[str], - dimensions: int, + local_utterances_list: List[str], local_function_schemas: List[Dict[str, Any]], + local_metadata_list: List[Dict[str, Any]], + dimensions: int, ): if self.sync is not None: logger.error("Sync remove is not implemented for QdrantIndex.") @@ -180,6 +181,7 @@ class QdrantIndex(BaseIndex): routes: List[str], utterances: List[str], function_schemas: Optional[List[Dict[str, Any]]] = None, + metadata_list: List[Dict[str, Any]] = [], batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, ): self.dimensions = self.dimensions or len(embeddings[0]) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index f5c270cafece0824a6c31158f7ae69c0b5dcc41d..7baa27c8c6dd69cea9fc6b7477085012873a9475 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -218,12 +218,15 @@ class RouteLayer: route.score_threshold = self.score_threshold # if routes list has been passed, we initialize index now if self.index.sync: - self._add_and_sync_routes(routes=self.routes) - elif self.routes: + # initialize index now + if len(self.routes) > 0: + self._add_and_sync_routes(routes=self.routes) + else: + self._add_and_sync_routes(routes=[]) + elif len(self.routes) > 0: self._add_routes(routes=self.routes) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: - # Use next with a generator expression for optimization matching_route = next( (route for route in self.routes if route.name == top_class), None ) @@ -232,6 +235,7 @@ class RouteLayer: f"No route found with name {top_class}. Check to see if any Routes " "have been defined." ) + return None return matching_route def __call__( @@ -388,14 +392,6 @@ class RouteLayer: ) return self._pass_threshold(scores, threshold) - def _set_layer_routes(self, new_routes: List[Route]): - """ - Set and override the current routes with a new list of routes. - - :param new_routes: List of Route objects to set as the current routes. - """ - self.routes = new_routes - def __str__(self): return ( f"RouteLayer(encoder={self.encoder}, " @@ -421,16 +417,9 @@ class RouteLayer: return cls(encoder=encoder, routes=config.routes, index=index) def add(self, route: Route): - logger.info(f"Adding `{route.name}` route") - # create embeddings - embeds = self.encoder(route.utterances) - # if route has no score_threshold, use default - if route.score_threshold is None: - route.score_threshold = self.score_threshold - - # add routes to the index + embedded_utterances = self.encoder(route.utterances) self.index.add( - embeddings=embeds, + embeddings=embedded_utterances, routes=[route.name] * len(route.utterances), utterances=route.utterances, function_schemas=( @@ -438,6 +427,7 @@ class RouteLayer: if route.function_schemas else [{}] * len(route.utterances) ), + metadata_list=[route.metadata] * len(route.utterances), ) self.routes.append(route) @@ -483,36 +473,18 @@ class RouteLayer: if not routes: logger.warning("No routes provided to add.") return - - route_names = [] - all_embeddings = [] - all_utterances: List[str] = [] - all_function_schemas = [] - - for route in routes: - logger.info(f"Adding `{route.name}` route") - route_embeddings = self.encoder(route.utterances) - - # Set score_threshold if not already set - route.score_threshold = route.score_threshold or self.score_threshold - - # Prepare data for batch insertion - route_names.extend([route.name] * len(route.utterances)) - all_embeddings.extend(route_embeddings) - all_utterances.extend(route.utterances) - all_function_schemas.extend( - route.function_schemas * len(route.utterances) - if route.function_schemas - else [{}] * len(route.utterances) - ) - + # create embeddings for all routes + route_names, all_utterances, all_metadata = self._extract_routes_details( + routes, include_metadata=True + ) + embedded_utterances = self.encoder(all_utterances) try: # Batch insertion into the index self.index.add( - embeddings=all_embeddings, + embeddings=embedded_utterances, routes=route_names, utterances=all_utterances, - function_schemas=all_function_schemas, + metadata_list=all_metadata, ) except Exception as e: logger.error(f"Failed to add routes to the index: {e}") @@ -520,55 +492,63 @@ class RouteLayer: def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting - local_route_names, local_utterances, local_function_schemas = ( - self._extract_routes_details(routes) + local_route_names, local_utterances, local_function_schemas, local_metadata = ( + self._extract_routes_details(routes, include_metadata=True) ) routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( - local_route_names=local_route_names, - local_utterances=local_utterances, + local_route_names, + local_utterances, + local_function_schemas, + local_metadata, dimensions=len(self.encoder(["dummy"])[0]), - local_function_schemas=local_function_schemas, ) - layer_routes: List[Route] = [] + logger.info(f"Routes to add: {routes_to_add}") + logger.info(f"Routes to delete: {routes_to_delete}") + logger.info(f"Layer routes: {layer_routes_dict}") - for route in layer_routes_dict.keys(): - route_dict = layer_routes_dict[route] - function_schemas = route_dict.get("function_schemas", None) - layer_routes.append( - Route( - name=route, - utterances=route_dict["utterances"], - function_schemas=[function_schemas] if function_schemas else None, - ) - ) - - data_to_delete: dict = {} + data_to_delete = {} # type: ignore for route, utterance in routes_to_delete: data_to_delete.setdefault(route, []).append(utterance) self.index._remove_and_sync(data_to_delete) - all_utterances_to_add = [utt for _, utt in routes_to_add] + # Prepare data for addition + if routes_to_add: + ( + route_names_to_add, + all_utterances_to_add, + function_schemas_to_add, + metadata_to_add, + ) = map(list, zip(*routes_to_add)) + else: + ( + route_names_to_add, + all_utterances_to_add, + function_schemas_to_add, + metadata_to_add, + ) = ([], [], [], []) + embedded_utterances_to_add = ( self.encoder(all_utterances_to_add) if all_utterances_to_add else [] ) - route_names_to_add = [route for route, _, in routes_to_add] - self.index.add( embeddings=embedded_utterances_to_add, routes=route_names_to_add, utterances=all_utterances_to_add, - function_schemas=local_function_schemas, + function_schemas=function_schemas_to_add, + metadata_list=metadata_to_add, ) - logger.info(f"layer_routes: {layer_routes}") - - self._set_layer_routes(layer_routes) + # Update local route layer state + self.routes = [ + Route(name=route, utterances=data.get("utterances", []), function_schemas=[data.get("function_schemas", None)], metadata=data.get("metadata", {})) + for route, data in layer_routes_dict.items() + ] def _extract_routes_details( - self, routes: List[Route] + self, routes: List[Route], include_metadata: bool = False ) -> Tuple[list[str], list[str], List[Dict[str, Any]]]: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] @@ -577,6 +557,10 @@ class RouteLayer: for route in routes for _ in route.utterances ] + + if include_metadata: + metadata = [route.metadata for route in routes for _ in route.utterances] + return route_names, utterances, function_schemas, metadata return route_names, utterances, function_schemas def _encode(self, text: str) -> Any: @@ -771,11 +755,15 @@ class RouteLayer: remote_routes = self.index.get_routes() # TODO Enhance by retrieving directly the vectors instead of embedding all utterances again - routes = [route_tuple[0] for route_tuple in remote_routes] - utterances = [route_tuple[1] for route_tuple in remote_routes] + routes, utterances, metadata = map(list, zip(*remote_routes)) embeddings = self.encoder(utterances) self.index = LocalIndex() - self.index.add(embeddings=embeddings, routes=routes, utterances=utterances) + self.index.add( + embeddings=embeddings, + routes=routes, + utterances=utterances, + metadata_list=metadata, + ) # convert inputs into array Xq: List[List[float]] = [] diff --git a/semantic_router/route.py b/semantic_router/route.py index 3fc3f0407e2c7bda06c627bcedeb4bae48e59ac9..41fd0bf2c4e2b496b92f22f841c8b565dda8bdad 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -50,6 +50,7 @@ class Route(BaseModel): function_schemas: Optional[List[Dict[str, Any]]] = None llm: Optional[BaseLLM] = None score_threshold: Optional[float] = None + metadata: Optional[Dict[str, Any]] = {} class Config: arbitrary_types_allowed = True