From 6eb71c1c5b6a5309c7c178e4bd74ca361d6d3c9e Mon Sep 17 00:00:00 2001
From: jamescalam <james.briggs@hotmail.com>
Date: Thu, 28 Nov 2024 13:01:25 +0100
Subject: [PATCH] chore: mypy lint

---
 semantic_router/encoders/__init__.py  |  4 ++--
 semantic_router/encoders/bm25.py      | 10 +++------
 semantic_router/index/base.py         |  4 +++-
 semantic_router/index/hybrid_local.py | 18 ++++++++++------
 semantic_router/index/local.py        |  4 +++-
 semantic_router/index/pinecone.py     | 29 ++++++++++++++++++-------
 semantic_router/index/postgres.py     |  3 ++-
 semantic_router/index/qdrant.py       |  4 +++-
 semantic_router/routers/base.py       | 31 +++++++++++++++++++++------
 semantic_router/routers/hybrid.py     | 20 ++++++++++++-----
 semantic_router/schema.py             |  5 +++++
 tests/unit/test_hybrid_layer.py       | 10 ++++++++-
 12 files changed, 102 insertions(+), 40 deletions(-)

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index 85b32e2c..07e468d8 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -15,7 +15,7 @@ from semantic_router.encoders.openai import OpenAIEncoder
 from semantic_router.encoders.tfidf import TfidfEncoder
 from semantic_router.encoders.vit import VitEncoder
 from semantic_router.encoders.zure import AzureOpenAIEncoder
-from semantic_router.schema import EncoderType
+from semantic_router.schema import EncoderType, SparseEmbedding
 
 __all__ = [
     "AurelioSparseEncoder",
@@ -79,5 +79,5 @@ class AutoEncoder:
         else:
             raise ValueError(f"Encoder type '{type}' not supported")
 
-    def __call__(self, texts: List[str]) -> List[List[float]]:
+    def __call__(self, texts: List[str]) -> List[List[float]] | List[SparseEmbedding]:
         return self.model(texts)
diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py
index 3357ded8..ea0a89a6 100644
--- a/semantic_router/encoders/bm25.py
+++ b/semantic_router/encoders/bm25.py
@@ -53,12 +53,8 @@ class BM25Encoder(TfidfEncoder):
         else:
             raise ValueError("No documents to encode.")
 
-        embeds = [[0.0] * len(self.idx_mapping)] * len(docs)
+        embeds = []
         for i, output in enumerate(sparse_dicts):
-            indices = output["indices"]
-            values = output["values"]
-            for idx, val in zip(indices, values):
-                if idx in self.idx_mapping:
-                    position = self.idx_mapping[idx]
-                    embeds[i][position] = val
+            if isinstance(output, dict):
+                embeds.append(SparseEmbedding.from_pinecone_dict(output))
         return embeds
diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 65f2cf1e..97fe3bd4 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -4,7 +4,7 @@ import json
 import numpy as np
 from pydantic.v1 import BaseModel
 
-from semantic_router.schema import ConfigParameter, Utterance
+from semantic_router.schema import ConfigParameter, SparseEmbedding, Utterance
 from semantic_router.route import Route
 from semantic_router.utils.logger import logger
 
@@ -108,6 +108,7 @@ class BaseIndex(BaseModel):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query_vector and return top_k results.
@@ -120,6 +121,7 @@ class BaseIndex(BaseModel):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query_vector and return top_k results.
diff --git a/semantic_router/index/hybrid_local.py b/semantic_router/index/hybrid_local.py
index f2821422..e2a75778 100644
--- a/semantic_router/index/hybrid_local.py
+++ b/semantic_router/index/hybrid_local.py
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Dict
 import numpy as np
 from numpy.linalg import norm
 
-from semantic_router.schema import ConfigParameter, Utterance
+from semantic_router.schema import ConfigParameter, SparseEmbedding, Utterance
 from semantic_router.index.local import LocalIndex
 from semantic_router.utils.logger import logger
 from typing import Any
@@ -76,6 +76,8 @@ class HybridLocalIndex(LocalIndex):
         return sum(vec_a[i] * vec_b.get(i, 0) for i in vec_a)
 
     def _sparse_index_dot_product(self, vec_a: dict[int, float]) -> list[float]:
+        if self.sparse_index is None:
+            raise ValueError("self.sparse_index is not populated.")
         dot_products = [
             self._sparse_dot_product(vec_a, vec_b) for vec_b in self.sparse_index
         ]
@@ -86,7 +88,7 @@ class HybridLocalIndex(LocalIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
-        sparse_vector: Optional[dict[int, float]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """Search the index for the query and return top_k results.
 
@@ -103,9 +105,13 @@ class HybridLocalIndex(LocalIndex):
             raise ValueError("Route filter is not supported for HybridLocalIndex.")
 
         xq_d = vector.copy()
-        if sparse_vector is None:
-            raise ValueError("Sparse vector is required for HybridLocalIndex.")
-        xq_s = sparse_vector.copy()
+        # align sparse vector type
+        if isinstance(sparse_vector, SparseEmbedding):
+            xq_s = sparse_vector.to_dict()
+        elif isinstance(sparse_vector, dict):
+            xq_s = sparse_vector
+        else:
+            raise ValueError("Sparse vector must be a SparseEmbedding or dict.")
 
         if self.index is not None and self.sparse_index is not None:
             # calculate dense vec similarity
@@ -130,7 +136,7 @@ class HybridLocalIndex(LocalIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
-        sparse_vector: Optional[dict[int, float]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """Search the index for the query and return top_k results. This method calls the
         sync `query` method as everything uses numpy computations which is CPU-bound
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 9d33163e..83cc4f51 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Dict
 
 import numpy as np
 
-from semantic_router.schema import ConfigParameter, Utterance
+from semantic_router.schema import ConfigParameter, SparseEmbedding, Utterance
 from semantic_router.index.base import BaseIndex
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.utils.logger import logger
@@ -68,6 +68,7 @@ class LocalIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query and return top_k results.
@@ -97,6 +98,7 @@ class LocalIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query and return top_k results.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 7b4dba53..de324f28 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -237,19 +237,19 @@ class PineconeIndex(BaseIndex):
         embeddings: List[List[float]],
         routes: List[str],
         utterances: List[str],
-        function_schemas: Optional[List[Dict[str, Any]]] = None,
+        function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
         metadata_list: List[Dict[str, Any]] = [],
         batch_size: int = 100,
-        sparse_embeddings: Optional[List[dict[int, float]]] = None,
+        sparse_embeddings: Optional[Optional[List[dict[int, float]]]] = None,
     ):
         """Add vectors to Pinecone in batches."""
         if self.index is None:
             self.dimensions = self.dimensions or len(embeddings[0])
             self.index = self._init_index(force_create=True)
         if function_schemas is None:
-            function_schemas = [None] * len(embeddings)
+            function_schemas = [{}] * len(embeddings)
         if sparse_embeddings is None:
-            sparse_embeddings = [None] * len(embeddings)
+            sparse_embeddings = [{}] * len(embeddings)
 
         vectors_to_upsert = [
             PineconeRecord(
@@ -261,7 +261,12 @@ class PineconeIndex(BaseIndex):
                 metadata=metadata,
             ).to_dict()
             for vector, route, utterance, function_schema, metadata, sparse_dict in zip(
-                embeddings, routes, utterances, function_schemas, metadata_list, sparse_embeddings  # type: ignore
+                embeddings,
+                routes,
+                utterances,
+                function_schemas,
+                metadata_list,
+                sparse_embeddings,
             )
         ]
 
@@ -449,7 +454,7 @@ class PineconeIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
-        **kwargs: Any,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """
         Asynchronously search the index for the query vector and return the top_k results.
@@ -475,9 +480,17 @@ class PineconeIndex(BaseIndex):
             filter_query = {"sr_route": {"$in": route_filter}}
         else:
             filter_query = None
+        # set sparse_vector_obj
+        sparse_vector_obj: dict[str, Any] | None = None
+        if sparse_vector is not None:
+            if isinstance(sparse_vector, dict):
+                sparse_vector_obj = SparseEmbedding.from_dict(sparse_vector)
+            if isinstance(sparse_vector, SparseEmbedding):
+                # unnecessary if-statement but mypy didn't like this otherwise
+                sparse_vector_obj = sparse_vector.to_pinecone()
         results = await self._async_query(
             vector=query_vector_list,
-            sparse_vector=kwargs.get("sparse_vector", None),
+            sparse_vector=sparse_vector_obj,
             namespace=self.namespace or "",
             filter=filter_query,
             top_k=top_k,
@@ -507,7 +520,7 @@ class PineconeIndex(BaseIndex):
     async def _async_query(
         self,
         vector: list[float],
-        sparse_vector: Optional[dict] = None,
+        sparse_vector: dict[str, Any] | None = None,
         namespace: str = "",
         filter: Optional[dict] = None,
         top_k: int = 5,
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 0d80d745..67de0ada 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -8,7 +8,7 @@ import psycopg2
 from pydantic import BaseModel
 
 from semantic_router.index.base import BaseIndex
-from semantic_router.schema import ConfigParameter, Metric
+from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding
 from semantic_router.utils.logger import logger
 
 
@@ -340,6 +340,7 @@ class PostgresIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         """
         Searches the index for the query vector and returns the top_k results.
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index eb7dc688..0da5c25e 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -4,7 +4,7 @@ import numpy as np
 from pydantic.v1 import Field
 
 from semantic_router.index.base import BaseIndex
-from semantic_router.schema import ConfigParameter, Metric, Utterance
+from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding, Utterance
 from semantic_router.utils.logger import logger
 
 DEFAULT_COLLECTION_NAME = "semantic-router-index"
@@ -259,6 +259,7 @@ class QdrantIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         from qdrant_client import QdrantClient, models
 
@@ -292,6 +293,7 @@ class QdrantIndex(BaseIndex):
         vector: np.ndarray,
         top_k: int = 5,
         route_filter: Optional[List[str]] = None,
+        sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> Tuple[np.ndarray, List[str]]:
         from qdrant_client import AsyncQdrantClient, models
 
diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index 5d59ae58..ad2971a0 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -1027,6 +1027,15 @@ class BaseRouter(BaseModel):
             )
 
     def _semantic_classify(self, query_results: List[Dict]) -> Tuple[str, List[float]]:
+        """Classify the query results into a single class based on the highest total score.
+        If no classification is found, return an empty string and an empty list.
+
+        :param query_results: The query results to classify. Expected format is a list of
+        dictionaries with "route" and "score" keys.
+        :type query_results: List[Dict]
+        :return: A tuple containing the top class and its associated scores.
+        :rtype: Tuple[str, List[float]]
+        """
         scores_by_class = self.group_scores_by_class(query_results)
 
         if self.aggregation_method is None:
@@ -1049,6 +1058,15 @@ class BaseRouter(BaseModel):
     async def _async_semantic_classify(
         self, query_results: List[Dict]
     ) -> Tuple[str, List[float]]:
+        """Classify the query results into a single class based on the highest total score.
+        If no classification is found, return an empty string and an empty list.
+
+        :param query_results: The query results to classify. Expected format is a list of
+        dictionaries with "route" and "score" keys.
+        :type query_results: List[Dict]
+        :return: A tuple containing the top class and its associated scores.
+        :rtype: Tuple[str, List[float]]
+        """
         scores_by_class = await self.async_group_scores_by_class(query_results)
 
         if self.aggregation_method is None:
@@ -1125,8 +1143,8 @@ class BaseRouter(BaseModel):
         return scores_by_class
 
     def _pass_threshold(self, scores: List[float], threshold: float | None) -> bool:
-        """Test if the route score passes the minimum threshold. If a threshold of None is
-        set, then the route will always pass no matter how low it scores.
+        """Test if the route score passes the minimum threshold. A threshold of None defaults
+        to 0.0, so the route will always pass no matter how low it scores.
 
         :param scores: The scores to test.
         :type scores: List[float]
@@ -1168,9 +1186,9 @@ class BaseRouter(BaseModel):
             for route in self.routes:
                 route.score_threshold = threshold
         else:
-            route = self.get(route_name)
-            if route is not None:
-                route.score_threshold = threshold
+            route_get: Route | None = self.get(route_name)
+            if route_get is not None:
+                route_get.score_threshold = threshold
             else:
                 logger.error(f"Route `{route_name}` not found")
 
@@ -1190,9 +1208,8 @@ class BaseRouter(BaseModel):
         config.to_file(file_path)
 
     def get_thresholds(self) -> Dict[str, float]:
-        # TODO: float() below is hacky fix for lint, fix this with new type?
         thresholds = {
-            route.name: float(route.score_threshold or self.score_threshold)
+            route.name: route.score_threshold or self.score_threshold or 0.0
             for route in self.routes
         }
         return thresholds
diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index 91ecf2ef..f07429c3 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -124,24 +124,34 @@ class HybridRouter(BaseRouter):
         route_filter: Optional[List[str]] = None,
         sparse_vector: dict[int, float] | SparseEmbedding | None = None,
     ) -> RouteChoice:
+        vector_arr: np.ndarray | None = None
+        potential_sparse_vector: List[SparseEmbedding] | None = None
         # if no vector provided, encode text to get vector
         if vector is None:
             if text is None:
                 raise ValueError("Either text or vector must be provided")
-            vector, potential_sparse_vector = self._encode(text=[text])
+            vector_arr, potential_sparse_vector = self._encode(text=[text])
         if sparse_vector is None:
             if text is None:
                 raise ValueError("Either text or sparse_vector must be provided")
-            sparse_vector = potential_sparse_vector
+            sparse_vector = (
+                potential_sparse_vector[0] if potential_sparse_vector else None
+            )
+        if sparse_vector is None:
+            raise ValueError("Sparse vector is required for HybridLocalIndex.")
+        vector_arr = vector_arr if vector_arr else np.array(vector)
         # TODO: add alpha as a parameter
         scores, route_names = self.index.query(
-            vector=np.array(vector) if isinstance(vector, list) else vector,
+            vector=vector_arr,
             top_k=self.top_k,
             route_filter=route_filter,
-            sparse_vector=sparse_vector[0],
+            sparse_vector=sparse_vector,
         )
         top_class, top_class_scores = self._semantic_classify(
-            list(zip(scores, route_names))
+            [
+                {"score": score, "route": route}
+                for score, route in zip(scores, route_names)
+            ]
         )
         passed = self._pass_threshold(top_class_scores, self.score_threshold)
         if passed:
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 7fcc8371..c9e63943 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -444,6 +444,11 @@ class SparseEmbedding(BaseModel):
         arr = np.array([list(sparse_dict.keys()), list(sparse_dict.values())]).T
         return cls.from_compact_array(arr)
 
+    @classmethod
+    def from_pinecone_dict(cls, sparse_dict: dict):
+        arr = np.array([sparse_dict["indices"], sparse_dict["values"]]).T
+        return cls.from_compact_array(arr)
+
     def to_dict(self):
         return {
             i: v for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1])
diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py
index 564bd3a1..f3fbe6da 100644
--- a/tests/unit/test_hybrid_layer.py
+++ b/tests/unit/test_hybrid_layer.py
@@ -75,7 +75,15 @@ def routes():
 
 
 sparse_encoder = BM25Encoder(use_default_params=False)
-sparse_encoder.fit(["The quick brown fox", "jumps over the lazy dog", "Hello, world!"])
+sparse_encoder.fit(
+    [
+        Route(
+            name="Route 1",
+            utterances=["The quick brown fox", "jumps over the lazy dog"],
+        ),
+        Route(name="Route 2", utterances=["Hello, world!"]),
+    ]
+)
 
 
 class TestHybridRouter:
-- 
GitLab