class BaseIndex(BaseModel):
    """
    Common interface shared by every index backend.

    The class is a Pydantic model that holds the backing store in ``index``
    and declares the operations an index is expected to offer. None of the
    operations are implemented here: each one raises ``NotImplementedError``
    until a subclass overrides it.
    """

    class Config:
        # Let fields hold values of types Pydantic has no built-in validator for.
        arbitrary_types_allowed = True

    # Backing store for the index; defaults to None.
    index: Optional[Any] = None

    def add(self, embeds: List[Any]):
        """
        Insert the given embeddings into the index.
        Subclasses must override this method.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

    def remove(self, indices_to_remove: List[int]):
        """
        Drop the entries at the given positions from the index.
        Subclasses must override this method.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

    def is_index_populated(self) -> bool:
        """
        Report whether the index currently holds any entries.
        Subclasses must override this method.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

    def query(
        self, query_vector: Any, top_k: int = 5
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Look up ``query_vector`` in the index and return the ``top_k`` results.
        Subclasses must override this method.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")
""" diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 3061f29e78f53c0cd263cc01553025fd105d3438..9794e2e2dd8127116e17d12e7a1e03cfe9d0845a 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -12,9 +12,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index from semantic_router.utils.logger import logger -from semantic_router.indices.local_index import LocalIndex - -IndexType = Union[LocalIndex, None] +from semantic_router.indices.base import BaseIndex def is_valid(layer_config: str) -> bool: @@ -153,11 +151,10 @@ class LayerConfig: class RouteLayer: - index: Optional[np.ndarray] = None categories: Optional[np.ndarray] = None score_threshold: float encoder: BaseEncoder - index: IndexType = None + index: BaseIndex def __init__( self, @@ -167,7 +164,7 @@ class RouteLayer: index_name: Optional[str] = "local", ): logger.info("local") - self.index = Index.get_by_name(index_name=index_name) + self.index: BaseIndex = Index.get_by_name(index_name=index_name) self.categories = None if encoder is None: logger.warning( @@ -338,7 +335,7 @@ class RouteLayer: """Given a query vector, retrieve the top_k most similar records.""" if self.index.is_index_populated(): # calculate similarity matrix - scores, idx = self.index.search(xq, top_k) + scores, idx = self.index.query(xq, top_k) # get the utterance categories (route names) routes = self.categories[idx] if self.categories is not None else [] return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 95e29bfe43b26732da5a0e0a0898ad70a56ff46d..40c492820c0445893d3de3705c75de756431a59e 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -79,7 +79,7 @@ class DocumentSplit(BaseModel): class Index: @classmethod - def get_by_name(cls, index_name: str): + def get_by_name(cls, 
index_name: Optional[str] = None): if index_name == "local" or index_name is None: return LocalIndex() # TODO: Later we'll add more index options.