From fbdb7be815cb17111cef4d929431c62f0566d1bd Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Sat, 9 Nov 2024 21:10:58 +0100 Subject: [PATCH] fix: resolve errors and lint --- semantic_router/index/base.py | 118 ++++++++++++++++++++++++++---- semantic_router/index/pinecone.py | 60 ++------------- semantic_router/index/postgres.py | 13 ---- semantic_router/index/qdrant.py | 4 +- semantic_router/layer.py | 42 ++++++----- 5 files changed, 135 insertions(+), 102 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index bd5c7c94..da4fba54 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -1,4 +1,5 @@ from typing import Any, List, Optional, Tuple, Union, Dict +import json import numpy as np from pydantic.v1 import BaseModel @@ -38,6 +39,20 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") + def get_utterances(self) -> List[Tuple]: + """Gets a list of route and utterance objects currently stored in the + index, including additional metadata. + + :return: A list of tuples, each containing route, utterance, function + schema and additional metadata. + :rtype: List[Tuple] + """ + _, metadata = self._get_all(include_metadata=True) + route_tuples: List[ + Tuple[str, str, Optional[Dict[str, Any]], Dict[str, Any]] + ] = [(x["sr_route"], x["sr_utterance"], None, {}) for x in metadata] + return route_tuples + def get_routes(self) -> List[Route]: """Gets a list of route objects currently stored in the index. @@ -45,25 +60,24 @@ class BaseIndex(BaseModel): :rtype: List[Route] """ route_tuples = self.get_utterances() - routes_dict: Dict[str, List[str]] = {} - # first create a dictionary of routes mapping to all their utterances, - # function_schema, and metadata + routes_dict: Dict[str, Route] = {} + # first create a dictionary of route names to Route objects for route_name, utterance, function_schema, metadata in route_tuples: - routes_dict.setdefault( - route_name, - { - "function_schemas": None, - "metadata": {}, - }, - ) - routes_dict[route_name]["utterances"] = routes_dict[route_name].get( - "utterances", [] - ) - routes_dict[route_name]["utterances"].append(utterance) + # if the route is not in the dictionary, add it + if route_name not in routes_dict: + routes_dict[route_name] = Route( + name=route_name, + utterances=[utterance], + function_schemas=function_schema, + metadata=metadata, + ) + else: + # otherwise, add the utterance to the route + routes_dict[route_name].utterances.append(utterance) # then create a list of routes from the dictionary routes: List[Route] = [] - for route_name, route_data in routes_dict.items(): - routes.append(Route(name=route_name, **route_data)) + for route_name, route in routes_dict.items(): + routes.append(route) return routes def _remove_and_sync(self, routes_to_delete: dict): @@ -181,5 +195,77 @@ class BaseIndex(BaseModel): """ raise NotImplementedError("This method should be implemented by subclasses.") + def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): + """ + Retrieves all vector IDs from the index. + + This method should be implemented by subclasses. + + :param prefix: The prefix to filter the vectors by. + :type prefix: Optional[str] + :param include_metadata: Whether to include metadata in the response. + :type include_metadata: bool + :return: A tuple containing a list of vector IDs and a list of metadata dictionaries. + :rtype: tuple[list[str], list[dict]] + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + async def _async_get_all( + self, prefix: Optional[str] = None, include_metadata: bool = False + ) -> tuple[list[str], list[dict]]: + """Retrieves all vector IDs from the index asynchronously. + + This method should be implemented by subclasses. + + :param prefix: The prefix to filter the vectors by. + :type prefix: Optional[str] + :param include_metadata: Whether to include metadata in the response. + :type include_metadata: bool + :return: A tuple containing a list of vector IDs and a list of metadata dictionaries. + :rtype: tuple[list[str], list[dict]] + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + async def _async_get_routes(self) -> List[Tuple]: + """Asynchronously gets a list of route and utterance objects currently + stored in the index, including additional metadata. + + :return: A list of tuples, each containing route, utterance, function + schema and additional metadata. + :rtype: List[Tuple] + """ + _, metadata = await self._async_get_all(include_metadata=True) + route_info = parse_route_info(metadata=metadata) + return route_info # type: ignore + class Config: arbitrary_types_allowed = True + + +def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]: + """Parses metadata from index to extract route, utterance, function + schema and additional metadata. + + :param metadata: List of metadata dictionaries. + :type metadata: List[Dict[str, Any]] + :return: A list of tuples, each containing route, utterance, function schema and additional metadata. + :rtype: List[Tuple] + """ + route_info = [] + for record in metadata: + sr_route = record.get("sr_route", "") + sr_utterance = record.get("sr_utterance", "") + sr_function_schema = json.loads(record.get("sr_function_schema", "{}")) + if sr_function_schema == {}: + sr_function_schema = None + + additional_metadata = { + key: value + for key, value in record.items() + if key not in ["sr_route", "sr_utterance", "sr_function_schema"] + } + # TODO: Not a fan of tuple packing here + route_info.append( + (sr_route, sr_utterance, sr_function_schema, additional_metadata) + ) + return route_info diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 25d2f5d9..c413a014 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -539,6 +539,13 @@ class PineconeIndex(BaseIndex): def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ Retrieves all vector IDs from the Pinecone index using pagination. + + :param prefix: The prefix to filter the vectors by. + :type prefix: Optional[str] + :param include_metadata: Whether to include metadata in the response. + :type include_metadata: bool + :return: A tuple containing a list of vector IDs and a list of metadata dictionaries. + :rtype: tuple[list[str], list[dict]] """ if self.index is None: raise ValueError("Index is None, could not retrieve vector IDs.") @@ -561,18 +568,6 @@ class PineconeIndex(BaseIndex): return all_vector_ids, metadata - def get_utterances(self) -> List[Tuple]: - """Gets a list of route and utterance objects currently stored in the - index, including additional metadata. - - :return: A list of tuples, each containing route, utterance, function - schema and additional metadata. - :rtype: List[Tuple] - """ - _, metadata = self._get_all(include_metadata=True) - route_tuples = parse_route_info(metadata=metadata) - return route_tuples - def delete(self, route_name: str): route_vec_ids = self._get_route_ids(route_name=route_name) if self.index is not None: @@ -877,46 +872,5 @@ class PineconeIndex(BaseIndex): response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {}) ) - async def _async_get_routes(self) -> List[Tuple]: - """Asynchronously gets a list of route and utterance objects currently - stored in the index, including additional metadata. - - :return: A list of tuples, each containing route, utterance, function - schema and additional metadata. - :rtype: List[Tuple] - """ - _, metadata = await self._async_get_all(include_metadata=True) - route_info = parse_route_info(metadata=metadata) - return route_info # type: ignore - def __len__(self): return self.index.describe_index_stats()["total_vector_count"] - - -def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]: - """Parses metadata from Pinecone index to extract route, utterance, function - schema and additional metadata. - - :param metadata: List of metadata dictionaries. - :type metadata: List[Dict[str, Any]] - :return: A list of tuples, each containing route, utterance, function schema and additional metadata. - :rtype: List[Tuple] - """ - route_info = [] - for record in metadata: - sr_route = record.get("sr_route", "") - sr_utterance = record.get("sr_utterance", "") - sr_function_schema = json.loads(record.get("sr_function_schema", "{}")) - if sr_function_schema == {}: - sr_function_schema = None - - additional_metadata = { - key: value - for key, value in record.items() - if key not in ["sr_route", "sr_utterance", "sr_function_schema"] - } - # TODO: Not a fan of tuple packing here - route_info.append( - (sr_route, sr_utterance, sr_function_schema, additional_metadata) - ) - return route_info diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 2889110b..eadfeb84 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -422,19 +422,6 @@ class PostgresIndex(BaseIndex): return all_vector_ids, metadata - def get_utterances(self) -> List[Tuple]: - """ - Gets a list of route and utterance objects currently stored in the index. - - :return: A list of (route_name, utterance, function_schema, metadata) tuples. - :rtype: List[Tuple] - """ - # Get all records with metadata - _, metadata = self._get_all(include_metadata=True) - # Create a list of (route_name, utterance, function_schema, metadata) tuples - route_tuples = [(x["sr_route"], x["sr_utterance"], None, {}) for x in metadata] - return route_tuples - def delete_all(self): """ Deletes all records from the Postgres index. diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 10268688..4f564c8d 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -227,7 +227,9 @@ class QdrantIndex(BaseIndex): results.extend(records) - route_tuples = [ + route_tuples: List[ + Tuple[str, str, Optional[Dict[str, Any]], Dict[str, Any]] + ] = [ ( x.payload[SR_ROUTE_PAYLOAD_KEY], x.payload[SR_UTTERANCE_PAYLOAD_KEY], diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 8637dfdf..efa1289e 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -127,7 +127,9 @@ class LayerConfig: @classmethod def from_tuples( cls, - route_tuples: List[Tuple[str, str]], + route_tuples: List[ + Tuple[str, str, Optional[List[Dict[str, Any]]], Dict[str, Any]] + ], encoder_type: str = "openai", encoder_name: Optional[str] = None, ): @@ -142,25 +144,25 @@ class LayerConfig: :param encoder_name: The name of the encoder to use, defaults to None. :type encoder_name: Optional[str], optional """ - routes: List[Route] = [] - routes_dict: Dict[str, List[str]] = {} - # first create a dictionary of routes mapping to all their utterances, - # function_schema, and metadata + routes_dict: Dict[str, Route] = {} + # first create a dictionary of route names to Route objects + # TODO: duplicated code with BaseIndex.get_routes() for route_name, utterance, function_schema, metadata in route_tuples: - routes_dict.setdefault( - route_name, - { - "function_schemas": None, - "metadata": {}, - }, - ) - routes_dict[route_name]["utterances"] = routes_dict[route_name].get( - "utterances", [] - ) - routes_dict[route_name]["utterances"].append(utterance) + # if the route is not in the dictionary, add it + if route_name not in routes_dict: + routes_dict[route_name] = Route( + name=route_name, + utterances=[utterance], + function_schemas=function_schema, + metadata=metadata, + ) + else: + # otherwise, add the utterance to the route + routes_dict[route_name].utterances.append(utterance) # then create a list of routes from the dictionary - for route_name, route_data in routes_dict.items(): - routes.append(Route(name=route_name, **route_data)) + routes: List[Route] = [] + for route_name, route in routes_dict.items(): + routes.append(route) return cls(routes=routes, encoder_type=encoder_type, encoder_name=encoder_name) @classmethod @@ -216,7 +218,7 @@ class LayerConfig: elif ext in [".yaml", ".yml"]: yaml.safe_dump(self.to_dict(), f) - def _get_diff(self, other: "LayerConfig") -> List[Dict[str, Any]]: + def _get_diff(self, other: "LayerConfig") -> List[str]: """Get the difference between two LayerConfigs. :param other: The LayerConfig to compare to. @@ -224,6 +226,8 @@ class LayerConfig: :return: A list of differences between the two LayerConfigs. :rtype: List[Dict[str, Any]] """ + # TODO: formalize diffs into likely LayerDiff objects that can then + # output different formats as required to enable smarter syncs self_yaml = yaml.dump(self.to_dict()) other_yaml = yaml.dump(other.to_dict()) differ = Differ() -- GitLab