Skip to content
Snippets Groups Projects
Unverified Commit 3c30145e authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Proper Index Base Class and Linting.

parent 20133d36
No related branches found
No related tags found
No related merge requests found
from pydantic.v1 import BaseModel
from typing import Any, List, Tuple, Optional
import numpy as np
class BaseIndex(BaseModel):
# Currently just a placedholder until more indexing methods are added and common attributes/methods are identified.
pass
"""
Base class for indices using Pydantic's BaseModel.
This class outlines the expected interface for index classes.
Actual method implementations should be provided in subclasses.
"""
# You can define common attributes here if there are any.
# For example, a placeholder for the index attribute:
index: Optional[Any] = None
def add(self, embeds: List[Any]):
"""
Add embeddings to the index.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def remove(self, indices_to_remove: List[int]):
"""
Remove items from the index by their indices.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def is_index_populated(self) -> bool:
"""
Check if the index is populated.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
"""
Search the index for the query_vector and return top_k results.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
class Config:
arbitrary_types_allowed = True
......@@ -11,9 +11,9 @@ class LocalIndex(BaseIndex):
arbitrary_types_allowed = True
def add(self, embeds: List[Any]):
embeds = np.array(embeds)
embeds = np.array(embeds) # type: ignore
if self.index is None:
self.index = embeds
self.index = embeds # type: ignore
else:
self.index = np.concatenate([self.index, embeds])
......@@ -21,14 +21,13 @@ class LocalIndex(BaseIndex):
"""
Remove all items of a specific category from the index.
"""
self.index = np.delete(self.index, indices_to_remove, axis=0)
if self.index is not None:
self.index = np.delete(self.index, indices_to_remove, axis=0)
def is_index_populated(self):
return self.index is not None and len(self.index) > 0
def search(
self, query_vector: Any, top_k: int = 5
) -> Tuple[np.ndarray, np.ndarray]:
def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
"""
Search the index for the query and return top_k results.
"""
......
......@@ -12,9 +12,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
from semantic_router.utils.logger import logger
from semantic_router.indices.local_index import LocalIndex
IndexType = Union[LocalIndex, None]
from semantic_router.indices.base import BaseIndex
def is_valid(layer_config: str) -> bool:
......@@ -153,11 +151,10 @@ class LayerConfig:
class RouteLayer:
index: Optional[np.ndarray] = None
categories: Optional[np.ndarray] = None
score_threshold: float
encoder: BaseEncoder
index: IndexType = None
index: BaseIndex
def __init__(
self,
......@@ -167,7 +164,7 @@ class RouteLayer:
index_name: Optional[str] = "local",
):
logger.info("local")
self.index = Index.get_by_name(index_name=index_name)
self.index: BaseIndex = Index.get_by_name(index_name=index_name)
self.categories = None
if encoder is None:
logger.warning(
......@@ -338,7 +335,7 @@ class RouteLayer:
"""Given a query vector, retrieve the top_k most similar records."""
if self.index.is_index_populated():
# calculate similarity matrix
scores, idx = self.index.search(xq, top_k)
scores, idx = self.index.query(xq, top_k)
# get the utterance categories (route names)
routes = self.categories[idx] if self.categories is not None else []
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
......
......@@ -79,7 +79,7 @@ class DocumentSplit(BaseModel):
class Index:
@classmethod
def get_by_name(cls, index_name: str):
def get_by_name(cls, index_name: Optional[str] = None):
if index_name == "local" or index_name is None:
return LocalIndex()
# TODO: Later we'll add more index options.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment