diff --git a/semantic_router/indices/local_index.py b/semantic_router/indices/local_index.py index 784c99673c38c45adf71f01d41438a7e961ced4e..eef81b910434b0f48bb9f7c0f6ae210ecb0b06c4 100644 --- a/semantic_router/indices/local_index.py +++ b/semantic_router/indices/local_index.py @@ -1,21 +1,17 @@ import numpy as np from typing import List, Any -from .base import BaseIndex from semantic_router.linear import similarity_matrix, top_scores -from typing import Tuple +from pydantic import BaseModel +import numpy as np +from typing import List, Any, Tuple, Optional -class LocalIndex(BaseIndex): - """ - Local index implementation using numpy arrays. - """ +class LocalIndex(BaseModel): + index: Optional[np.ndarray] = None - def __init__(self): - self.index = None + class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints. + arbitrary_types_allowed = True def add(self, embeds: List[Any]): - """ - Add items to the index. - """ embeds = np.array(embeds) if self.index is None: self.index = embeds diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 487d8b24554bc059e9898d8610fac39ece5186e2..ff1168f2ed9cd2c26a790f6463cc725384d198bf 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -13,6 +13,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index from semantic_router.utils.logger import logger +from semantic_router.indices.local_index import LocalIndex IndexType = Union[LocalIndex, None] @@ -166,7 +167,7 @@ class RouteLayer: index_name: Optional[str] = None, ): logger.info("local") - self.index = Index.get_by_name(index_name="index") + self.index = Index.get_by_name(index_name="local") self.categories = None if encoder is None: logger.warning(