From 967fd96b70898db1506f15307fefcbd30791feee Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Thu, 13 Jun 2024 16:30:00 +0800
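Subject: [PATCH] chore: lint

Besides formatting changes, this patch:

* makes BaseEncoder.score_threshold Optional[float] (default None) and
  coerces non-None values to float via a validator; RouteLayer,
  HybridRouteLayer and RollingWindowSplitter (when not using a dynamic
  threshold) now raise ValueError if the encoder has no score threshold.
* annotates BaseEncoder.acall as returning a Coroutine.
* adds an async BaseIndex.aquery() stub that raises NotImplementedError.
* makes PineconeIndex._initialize_async_client fall back to self.api_key
  and raise ValueError when no API key is available, and passes an empty
  string as the namespace in the async query when self.namespace is falsy.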
---
 semantic_router/encoders/base.py            | 12 ++++++++----
 semantic_router/encoders/openai.py          | 21 ++++++++++++++-------
 semantic_router/hybrid_layer.py             |  5 +++++
 semantic_router/index/base.py               | 12 ++++++++++++
 semantic_router/index/pinecone.py           | 12 +++++++-----
 semantic_router/layer.py                    |  5 +++++
 semantic_router/splitters/rolling_window.py |  5 +++++
 7 files changed, 56 insertions(+), 16 deletions(-)

diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index bb9ce207..fcc5734d 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -1,18 +1,22 @@
-from typing import Any, List
+from typing import Any, Coroutine, List, Optional
 
-from pydantic.v1 import BaseModel, Field
+from pydantic.v1 import BaseModel, Field, validator
 
 
 class BaseEncoder(BaseModel):
     name: str
-    score_threshold: float
+    score_threshold: Optional[float] = None
     type: str = Field(default="base")
 
     class Config:
         arbitrary_types_allowed = True
 
+    @validator("score_threshold", pre=True, always=True)
+    def set_score_threshold(cls, v):
+        return float(v) if v is not None else None
+
     def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
 
-    def acall(self, docs: List[Any]) -> List[List[float]]:
+    def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index ad328447..e4acc5a4 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -18,15 +18,18 @@ from semantic_router.utils.logger import logger
 
 model_configs = {
     "text-embedding-ada-002": EncoderInfo(
-        name="text-embedding-ada-002", token_limit=8192,
+        name="text-embedding-ada-002",
+        token_limit=8192,
         threshold=0.82,
     ),
     "text-embedding-3-small": EncoderInfo(
-        name="text-embedding-3-small", token_limit=8192,
+        name="text-embedding-3-small",
+        token_limit=8192,
         threshold=0.3,
     ),
     "text-embedding-3-large": EncoderInfo(
-        name="text-embedding-3-large", token_limit=8192,
+        name="text-embedding-3-large",
+        token_limit=8192,
         threshold=0.3,
     ),
 }
@@ -52,13 +55,17 @@ class OpenAIEncoder(BaseEncoder):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
         if score_threshold is None and name in model_configs:
-            score_threshold = model_configs[name].threshold
+            set_score_threshold = model_configs[name].threshold
         elif score_threshold is None:
-            logger.warning(f"Score threshold not set for model: {name}. Using default value.")
-            score_threshold = 0.82
+            logger.warning(
+                f"Score threshold not set for model: {name}. Using default value."
+            )
+            set_score_threshold = 0.82
+        else:
+            set_score_threshold = score_threshold
         super().__init__(
             name=name,
-            score_threshold=score_threshold,
+            score_threshold=set_score_threshold,
         )
         api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index 5f223384..2aaf6f17 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -28,6 +28,11 @@ class HybridRouteLayer:
         aggregation: str = "sum",
     ):
         self.encoder = encoder
+        if self.encoder.score_threshold is None:
+            raise ValueError(
+                "No score threshold provided for encoder. Please set the score threshold "
+                "in the encoder config."
+            )
         self.score_threshold = self.encoder.score_threshold
 
         if sparse_encoder is None:
diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index bc5adcca..e6638fb1 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -55,6 +55,18 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
+    async def aquery(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        """
+        Asynchronously search the index for `vector` and return top_k results.
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     def delete_index(self):
         """
         Deletes or resets the index.
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 3572144c..7d3828f4 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -102,6 +102,9 @@ class PineconeIndex(BaseIndex):
         return Pinecone(**pinecone_args)
 
     def _initialize_async_client(self, api_key: Optional[str] = None):
+        api_key = api_key or self.api_key
+        if api_key is None:
+            raise ValueError("Pinecone API key is required.")
         async_client = aiohttp.ClientSession(
             headers={
                 "Api-Key": api_key,
@@ -159,13 +162,12 @@ class PineconeIndex(BaseIndex):
             self.host = self.client.describe_index(self.index_name)["host"]
         return index
 
-    async def _init_async_index(self, force_create: bool = False) -> Union[Any, None]:
+    async def _init_async_index(self, force_create: bool = False):
         index_stats = None
         indexes = await self._async_list_indexes()
         index_names = [i["name"] for i in indexes["indexes"]]
         index_exists = self.index_name in index_names
-        dimensions_given = self.dimensions is not None
-        if dimensions_given and not index_exists:
+        if self.dimensions is not None and not index_exists:
             await self._async_create_index(
                 name=self.index_name,
                 dimension=self.dimensions,
@@ -183,7 +185,7 @@ class PineconeIndex(BaseIndex):
             index_stats = await self._async_describe_index(self.index_name)
             # grab dimensions for the index
             self.dimensions = index_stats["dimension"]
-        elif force_create and not dimensions_given:
+        elif force_create and self.dimensions is None:
             raise ValueError(
                 "Cannot create an index without specifying the dimensions."
             )
@@ -348,7 +350,7 @@ class PineconeIndex(BaseIndex):
             filter_query = None
         results = await self._async_query(
             vector=query_vector_list,
-            namespace=self.namespace,
+            namespace=self.namespace or "",
             filter=filter_query,
             top_k=top_k,
             include_metadata=True,
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 68043523..6b9ee1d2 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -196,6 +196,11 @@ class RouteLayer:
             self.encoder = encoder
         self.llm = llm
         self.routes: List[Route] = routes if routes is not None else []
+        if self.encoder.score_threshold is None:
+            raise ValueError(
+                "No score threshold provided for encoder. Please set the score threshold "
+                "in the encoder config."
+            )
         self.score_threshold = self.encoder.score_threshold
         self.top_k = top_k
         if self.top_k < 1:
diff --git a/semantic_router/splitters/rolling_window.py b/semantic_router/splitters/rolling_window.py
index dc393b55..7fdcc63e 100644
--- a/semantic_router/splitters/rolling_window.py
+++ b/semantic_router/splitters/rolling_window.py
@@ -95,6 +95,11 @@ class RollingWindowSplitter(BaseSplitter):
         if self.dynamic_threshold:
             self._find_optimal_threshold(docs, similarities)
         else:
+            if self.encoder.score_threshold is None:
+                raise ValueError(
+                    "No score threshold provided for encoder. Please set the score threshold "
+                    "in the encoder config."
+                )
             self.calculated_threshold = self.encoder.score_threshold
         split_indices = self._find_split_indices(similarities=similarities)
         splits = self._split_documents(docs, split_indices, similarities)
-- 
GitLab