From 3c30145e1e8682455d061660a34d2720bd825fb2 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 7 Feb 2024 00:30:29 +0400
Subject: [PATCH] Proper Index Base Class and Linting.

---
 semantic_router/indices/base.py        | 44 ++++++++++++++++++++++++--
 semantic_router/indices/local_index.py | 11 +++----
 semantic_router/layer.py               | 11 +++----
 semantic_router/schema.py              |  2 +-
 4 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/semantic_router/indices/base.py b/semantic_router/indices/base.py
index baa5a204..795aa685 100644
--- a/semantic_router/indices/base.py
+++ b/semantic_router/indices/base.py
@@ -1,6 +1,46 @@
 from pydantic.v1 import BaseModel
+from typing import Any, List, Tuple, Optional
+import numpy as np
 
 
 class BaseIndex(BaseModel):
-    # Currently just a placedholder until more indexing methods are added and common attributes/methods are identified.
-    pass
+    """
+    Base class for indices using Pydantic's BaseModel.
+    This class outlines the expected interface for index classes.
+    Actual method implementations should be provided in subclasses.
+    """
+
+    # Attributes shared by every index implementation are declared here.
+    # `index` holds the subclass's underlying storage; it defaults to None.
+    index: Optional[Any] = None
+
+    def add(self, embeds: List[Any]):
+        """
+        Add embeddings to the index.
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    def remove(self, indices_to_remove: List[int]):
+        """
+        Remove the items at the given positions from the index.
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    def is_index_populated(self) -> bool:
+        """
+        Return True if the index contains at least one item.
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Search the index for query_vector and return the top_k matches as a
+        (scores, indices) tuple. This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
+    class Config:
+        arbitrary_types_allowed = True
diff --git a/semantic_router/indices/local_index.py b/semantic_router/indices/local_index.py
index 3c0d32dc..0f60953d 100644
--- a/semantic_router/indices/local_index.py
+++ b/semantic_router/indices/local_index.py
@@ -11,9 +11,9 @@ class LocalIndex(BaseIndex):
         arbitrary_types_allowed = True
 
     def add(self, embeds: List[Any]):
-        embeds = np.array(embeds)
+        embeds = np.array(embeds)  # type: ignore
         if self.index is None:
-            self.index = embeds
+            self.index = embeds  # type: ignore
         else:
             self.index = np.concatenate([self.index, embeds])
 
@@ -21,14 +21,13 @@ class LocalIndex(BaseIndex):
         """
         Remove all items of a specific category from the index.
         """
-        self.index = np.delete(self.index, indices_to_remove, axis=0)
+        if self.index is not None:
+            self.index = np.delete(self.index, indices_to_remove, axis=0)
 
     def is_index_populated(self):
         return self.index is not None and len(self.index) > 0
 
-    def search(
-        self, query_vector: Any, top_k: int = 5
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
         """
         Search the index for the query and return top_k results.
         """
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 3061f29e..9794e2e2 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -12,9 +12,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
 from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
 from semantic_router.utils.logger import logger
-from semantic_router.indices.local_index import LocalIndex
-
-IndexType = Union[LocalIndex, None]
+from semantic_router.indices.base import BaseIndex
 
 
 def is_valid(layer_config: str) -> bool:
@@ -153,11 +151,10 @@ class LayerConfig:
 
 
 class RouteLayer:
-    index: Optional[np.ndarray] = None
     categories: Optional[np.ndarray] = None
     score_threshold: float
     encoder: BaseEncoder
-    index: IndexType = None
+    index: BaseIndex
 
     def __init__(
         self,
@@ -167,7 +164,7 @@ class RouteLayer:
         index_name: Optional[str] = "local",
     ):
         logger.info("local")
-        self.index = Index.get_by_name(index_name=index_name)
+        self.index: BaseIndex = Index.get_by_name(index_name=index_name)
         self.categories = None
         if encoder is None:
             logger.warning(
@@ -338,7 +335,7 @@ class RouteLayer:
         """Given a query vector, retrieve the top_k most similar records."""
         if self.index.is_index_populated():
             # calculate similarity matrix
-            scores, idx = self.index.search(xq, top_k)
+            scores, idx = self.index.query(xq, top_k)
             # get the utterance categories (route names)
             routes = self.categories[idx] if self.categories is not None else []
             return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 95e29bfe..40c49282 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -79,7 +79,7 @@ class DocumentSplit(BaseModel):
 
 class Index:
     @classmethod
-    def get_by_name(cls, index_name: str):
+    def get_by_name(cls, index_name: Optional[str] = None):
         if index_name == "local" or index_name is None:
             return LocalIndex()
         # TODO: Later we'll add more index options.
-- 
GitLab