Skip to content
Snippets Groups Projects
Unverified Commit 3c30145e authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Proper Index Base Class and Linting.

parent 20133d36
No related branches found
No related tags found
No related merge requests found
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
from typing import Any, List, Tuple, Optional
import numpy as np
class BaseIndex(BaseModel): class BaseIndex(BaseModel):
# Currently just a placedholder until more indexing methods are added and common attributes/methods are identified. """
pass Base class for indices using Pydantic's BaseModel.
This class outlines the expected interface for index classes.
Actual method implementations should be provided in subclasses.
"""
# You can define common attributes here if there are any.
# For example, a placeholder for the index attribute:
index: Optional[Any] = None
def add(self, embeds: List[Any]):
"""
Add embeddings to the index.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def remove(self, indices_to_remove: List[int]):
"""
Remove items from the index by their indices.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def is_index_populated(self) -> bool:
"""
Check if the index is populated.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
"""
Search the index for the query_vector and return top_k results.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
class Config:
arbitrary_types_allowed = True
...@@ -11,9 +11,9 @@ class LocalIndex(BaseIndex): ...@@ -11,9 +11,9 @@ class LocalIndex(BaseIndex):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def add(self, embeds: List[Any]): def add(self, embeds: List[Any]):
embeds = np.array(embeds) embeds = np.array(embeds) # type: ignore
if self.index is None: if self.index is None:
self.index = embeds self.index = embeds # type: ignore
else: else:
self.index = np.concatenate([self.index, embeds]) self.index = np.concatenate([self.index, embeds])
...@@ -21,14 +21,13 @@ class LocalIndex(BaseIndex): ...@@ -21,14 +21,13 @@ class LocalIndex(BaseIndex):
""" """
Remove all items of a specific category from the index. Remove all items of a specific category from the index.
""" """
self.index = np.delete(self.index, indices_to_remove, axis=0) if self.index is not None:
self.index = np.delete(self.index, indices_to_remove, axis=0)
def is_index_populated(self): def is_index_populated(self):
return self.index is not None and len(self.index) > 0 return self.index is not None and len(self.index) > 0
def search( def query(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
self, query_vector: Any, top_k: int = 5
) -> Tuple[np.ndarray, np.ndarray]:
""" """
Search the index for the query and return top_k results. Search the index for the query and return top_k results.
""" """
......
...@@ -12,9 +12,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM ...@@ -12,9 +12,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
from semantic_router.indices.local_index import LocalIndex from semantic_router.indices.base import BaseIndex
IndexType = Union[LocalIndex, None]
def is_valid(layer_config: str) -> bool: def is_valid(layer_config: str) -> bool:
...@@ -153,11 +151,10 @@ class LayerConfig: ...@@ -153,11 +151,10 @@ class LayerConfig:
class RouteLayer: class RouteLayer:
index: Optional[np.ndarray] = None
categories: Optional[np.ndarray] = None categories: Optional[np.ndarray] = None
score_threshold: float score_threshold: float
encoder: BaseEncoder encoder: BaseEncoder
index: IndexType = None index: BaseIndex
def __init__( def __init__(
self, self,
...@@ -167,7 +164,7 @@ class RouteLayer: ...@@ -167,7 +164,7 @@ class RouteLayer:
index_name: Optional[str] = "local", index_name: Optional[str] = "local",
): ):
logger.info("local") logger.info("local")
self.index = Index.get_by_name(index_name=index_name) self.index: BaseIndex = Index.get_by_name(index_name=index_name)
self.categories = None self.categories = None
if encoder is None: if encoder is None:
logger.warning( logger.warning(
...@@ -338,7 +335,7 @@ class RouteLayer: ...@@ -338,7 +335,7 @@ class RouteLayer:
"""Given a query vector, retrieve the top_k most similar records.""" """Given a query vector, retrieve the top_k most similar records."""
if self.index.is_index_populated(): if self.index.is_index_populated():
# calculate similarity matrix # calculate similarity matrix
scores, idx = self.index.search(xq, top_k) scores, idx = self.index.query(xq, top_k)
# get the utterance categories (route names) # get the utterance categories (route names)
routes = self.categories[idx] if self.categories is not None else [] routes = self.categories[idx] if self.categories is not None else []
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)] return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
......
...@@ -79,7 +79,7 @@ class DocumentSplit(BaseModel): ...@@ -79,7 +79,7 @@ class DocumentSplit(BaseModel):
class Index: class Index:
@classmethod @classmethod
def get_by_name(cls, index_name: str): def get_by_name(cls, index_name: Optional[str] = None):
if index_name == "local" or index_name is None: if index_name == "local" or index_name is None:
return LocalIndex() return LocalIndex()
# TODO: Later we'll add more index options. # TODO: Later we'll add more index options.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment