From fbdb7be815cb17111cef4d929431c62f0566d1bd Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Sat, 9 Nov 2024 21:10:58 +0100
Subject: [PATCH] fix: resolve errors and lint

---
 semantic_router/index/base.py     | 118 ++++++++++++++++++++++++++----
 semantic_router/index/pinecone.py |  60 ++-------------
 semantic_router/index/postgres.py |  13 ----
 semantic_router/index/qdrant.py   |   4 +-
 semantic_router/layer.py          |  42 ++++++-----
 5 files changed, 135 insertions(+), 102 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index bd5c7c94..da4fba54 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -1,4 +1,5 @@
 from typing import Any, List, Optional, Tuple, Union, Dict
+import json
 
 import numpy as np
 from pydantic.v1 import BaseModel
@@ -38,6 +39,20 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    def get_utterances(self) -> List[Tuple]:
+        """Gets a list of route and utterance objects currently stored in the
+        index, including additional metadata.
+
+        :return: A list of tuples, each containing route, utterance, function
+        schema and additional metadata.
+        :rtype: List[Tuple]
+        """
+        _, metadata = self._get_all(include_metadata=True)
+        route_tuples: List[
+            Tuple[str, str, Optional[Dict[str, Any]], Dict[str, Any]]
+        ] = [(x["sr_route"], x["sr_utterance"], None, {}) for x in metadata]
+        return route_tuples
+
     def get_routes(self) -> List[Route]:
         """Gets a list of route objects currently stored in the index.
 
@@ -45,25 +60,24 @@ class BaseIndex(BaseModel):
         :rtype: List[Route]
         """
         route_tuples = self.get_utterances()
-        routes_dict: Dict[str, List[str]] = {}
-        # first create a dictionary of routes mapping to all their utterances,
-        # function_schema, and metadata
+        routes_dict: Dict[str, Route] = {}
+        # first create a dictionary of route names to Route objects
         for route_name, utterance, function_schema, metadata in route_tuples:
-            routes_dict.setdefault(
-                route_name,
-                {
-                    "function_schemas": None,
-                    "metadata": {},
-                },
-            )
-            routes_dict[route_name]["utterances"] = routes_dict[route_name].get(
-                "utterances", []
-            )
-            routes_dict[route_name]["utterances"].append(utterance)
+            # if the route is not in the dictionary, add it
+            if route_name not in routes_dict:
+                routes_dict[route_name] = Route(
+                    name=route_name,
+                    utterances=[utterance],
+                    function_schemas=function_schema,
+                    metadata=metadata,
+                )
+            else:
+                # otherwise, add the utterance to the route
+                routes_dict[route_name].utterances.append(utterance)
         # then create a list of routes from the dictionary
         routes: List[Route] = []
-        for route_name, route_data in routes_dict.items():
-            routes.append(Route(name=route_name, **route_data))
+        for route_name, route in routes_dict.items():
+            routes.append(route)
         return routes
 
     def _remove_and_sync(self, routes_to_delete: dict):
@@ -181,5 +195,77 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
+        """
+        Retrieves all vector IDs from the index.
+
+        This method should be implemented by subclasses.
+
+        :param prefix: The prefix to filter the vectors by.
+        :type prefix: Optional[str]
+        :param include_metadata: Whether to include metadata in the response.
+        :type include_metadata: bool
+        :return: A tuple containing a list of vector IDs and a list of metadata dictionaries.
+        :rtype: tuple[list[str], list[dict]]
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    async def _async_get_all(
+        self, prefix: Optional[str] = None, include_metadata: bool = False
+    ) -> tuple[list[str], list[dict]]:
+        """Retrieves all vector IDs from the index asynchronously.
+
+        This method should be implemented by subclasses.
+
+        :param prefix: The prefix to filter the vectors by.
+        :type prefix: Optional[str]
+        :param include_metadata: Whether to include metadata in the response.
+        :type include_metadata: bool
+        :return: A tuple containing a list of vector IDs and a list of metadata dictionaries.
+        :rtype: tuple[list[str], list[dict]]
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    async def _async_get_routes(self) -> List[Tuple]:
+        """Asynchronously gets a list of route and utterance objects currently
+        stored in the index, including additional metadata.
+
+        :return: A list of tuples, each containing route, utterance, function
+        schema and additional metadata.
+        :rtype: List[Tuple]
+        """
+        _, metadata = await self._async_get_all(include_metadata=True)
+        route_info = parse_route_info(metadata=metadata)
+        return route_info  # type: ignore
+
     class Config:
         arbitrary_types_allowed = True
+
+
+def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]:
+    """Parses metadata from index to extract route, utterance, function
+    schema and additional metadata.
+
+    :param metadata: List of metadata dictionaries.
+    :type metadata: List[Dict[str, Any]]
+    :return: A list of tuples, each containing route, utterance, function schema and additional metadata.
+    :rtype: List[Tuple]
+    """
+    route_info = []
+    for record in metadata:
+        sr_route = record.get("sr_route", "")
+        sr_utterance = record.get("sr_utterance", "")
+        sr_function_schema = json.loads(record.get("sr_function_schema", "{}"))
+        if sr_function_schema == {}:
+            sr_function_schema = None
+
+        additional_metadata = {
+            key: value
+            for key, value in record.items()
+            if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
+        }
+        # TODO: Not a fan of tuple packing here
+        route_info.append(
+            (sr_route, sr_utterance, sr_function_schema, additional_metadata)
+        )
+    return route_info
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 25d2f5d9..c413a014 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -539,6 +539,13 @@ class PineconeIndex(BaseIndex):
     def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
         """
         Retrieves all vector IDs from the Pinecone index using pagination.
+
+        :param prefix: The prefix to filter the vectors by.
+        :type prefix: Optional[str]
+        :param include_metadata: Whether to include metadata in the response.
+        :type include_metadata: bool
+        :return: A tuple containing a list of vector IDs and a list of metadata dictionaries.
+        :rtype: tuple[list[str], list[dict]]
         """
         if self.index is None:
             raise ValueError("Index is None, could not retrieve vector IDs.")
@@ -561,18 +568,6 @@ class PineconeIndex(BaseIndex):
 
         return all_vector_ids, metadata
 
-    def get_utterances(self) -> List[Tuple]:
-        """Gets a list of route and utterance objects currently stored in the
-        index, including additional metadata.
-
-        :return: A list of tuples, each containing route, utterance, function
-        schema and additional metadata.
-        :rtype: List[Tuple]
-        """
-        _, metadata = self._get_all(include_metadata=True)
-        route_tuples = parse_route_info(metadata=metadata)
-        return route_tuples
-
     def delete(self, route_name: str):
         route_vec_ids = self._get_route_ids(route_name=route_name)
         if self.index is not None:
@@ -877,46 +872,5 @@ class PineconeIndex(BaseIndex):
                 response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})
             )
 
-    async def _async_get_routes(self) -> List[Tuple]:
-        """Asynchronously gets a list of route and utterance objects currently
-        stored in the index, including additional metadata.
-
-        :return: A list of tuples, each containing route, utterance, function
-        schema and additional metadata.
-        :rtype: List[Tuple]
-        """
-        _, metadata = await self._async_get_all(include_metadata=True)
-        route_info = parse_route_info(metadata=metadata)
-        return route_info  # type: ignore
-
     def __len__(self):
         return self.index.describe_index_stats()["total_vector_count"]
-
-
-def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]:
-    """Parses metadata from Pinecone index to extract route, utterance, function
-    schema and additional metadata.
-
-    :param metadata: List of metadata dictionaries.
-    :type metadata: List[Dict[str, Any]]
-    :return: A list of tuples, each containing route, utterance, function schema and additional metadata.
-    :rtype: List[Tuple]
-    """
-    route_info = []
-    for record in metadata:
-        sr_route = record.get("sr_route", "")
-        sr_utterance = record.get("sr_utterance", "")
-        sr_function_schema = json.loads(record.get("sr_function_schema", "{}"))
-        if sr_function_schema == {}:
-            sr_function_schema = None
-
-        additional_metadata = {
-            key: value
-            for key, value in record.items()
-            if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
-        }
-        # TODO: Not a fan of tuple packing here
-        route_info.append(
-            (sr_route, sr_utterance, sr_function_schema, additional_metadata)
-        )
-    return route_info
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 2889110b..eadfeb84 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -422,19 +422,6 @@ class PostgresIndex(BaseIndex):
 
         return all_vector_ids, metadata
 
-    def get_utterances(self) -> List[Tuple]:
-        """
-        Gets a list of route and utterance objects currently stored in the index.
-
-        :return: A list of (route_name, utterance, function_schema, metadata) tuples.
-        :rtype: List[Tuple]
-        """
-        # Get all records with metadata
-        _, metadata = self._get_all(include_metadata=True)
-        # Create a list of (route_name, utterance, function_schema, metadata) tuples
-        route_tuples = [(x["sr_route"], x["sr_utterance"], None, {}) for x in metadata]
-        return route_tuples
-
     def delete_all(self):
         """
         Deletes all records from the Postgres index.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 10268688..4f564c8d 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -227,7 +227,9 @@ class QdrantIndex(BaseIndex):
 
             results.extend(records)
 
-        route_tuples = [
+        route_tuples: List[
+            Tuple[str, str, Optional[Dict[str, Any]], Dict[str, Any]]
+        ] = [
             (
                 x.payload[SR_ROUTE_PAYLOAD_KEY],
                 x.payload[SR_UTTERANCE_PAYLOAD_KEY],
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 8637dfdf..efa1289e 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -127,7 +127,9 @@ class LayerConfig:
     @classmethod
     def from_tuples(
         cls,
-        route_tuples: List[Tuple[str, str]],
+        route_tuples: List[
+            Tuple[str, str, Optional[List[Dict[str, Any]]], Dict[str, Any]]
+        ],
         encoder_type: str = "openai",
         encoder_name: Optional[str] = None,
     ):
@@ -142,25 +144,25 @@ class LayerConfig:
         :param encoder_name: The name of the encoder to use, defaults to None.
         :type encoder_name: Optional[str], optional
         """
-        routes: List[Route] = []
-        routes_dict: Dict[str, List[str]] = {}
-        # first create a dictionary of routes mapping to all their utterances,
-        # function_schema, and metadata
+        routes_dict: Dict[str, Route] = {}
+        # first create a dictionary of route names to Route objects
+        # TODO: duplicated code with BaseIndex.get_routes()
         for route_name, utterance, function_schema, metadata in route_tuples:
-            routes_dict.setdefault(
-                route_name,
-                {
-                    "function_schemas": None,
-                    "metadata": {},
-                },
-            )
-            routes_dict[route_name]["utterances"] = routes_dict[route_name].get(
-                "utterances", []
-            )
-            routes_dict[route_name]["utterances"].append(utterance)
+            # if the route is not in the dictionary, add it
+            if route_name not in routes_dict:
+                routes_dict[route_name] = Route(
+                    name=route_name,
+                    utterances=[utterance],
+                    function_schemas=function_schema,
+                    metadata=metadata,
+                )
+            else:
+                # otherwise, add the utterance to the route
+                routes_dict[route_name].utterances.append(utterance)
         # then create a list of routes from the dictionary
-        for route_name, route_data in routes_dict.items():
-            routes.append(Route(name=route_name, **route_data))
+        routes: List[Route] = []
+        for route_name, route in routes_dict.items():
+            routes.append(route)
         return cls(routes=routes, encoder_type=encoder_type, encoder_name=encoder_name)
 
     @classmethod
@@ -216,7 +218,7 @@ class LayerConfig:
             elif ext in [".yaml", ".yml"]:
                 yaml.safe_dump(self.to_dict(), f)
 
-    def _get_diff(self, other: "LayerConfig") -> List[Dict[str, Any]]:
+    def _get_diff(self, other: "LayerConfig") -> List[str]:
         """Get the difference between two LayerConfigs.
 
         :param other: The LayerConfig to compare to.
@@ -224,6 +226,8 @@ class LayerConfig:
         :return: A list of differences between the two LayerConfigs.
         :rtype: List[Dict[str, Any]]
         """
+        # TODO: formalize diffs into likely LayerDiff objects that can then
+        # output different formats as required to enable smarter syncs
         self_yaml = yaml.dump(self.to_dict())
         other_yaml = yaml.dump(other.to_dict())
         differ = Differ()
-- 
GitLab