From 37fe08cb4259eba102b5b1970c46bb9444224c16 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Tue, 6 Feb 2024 23:01:33 +0400 Subject: [PATCH] Bug Fixes Making pydantic leave me alone and giving the correct index name. --- semantic_router/indices/local_index.py | 18 +++++++----------- semantic_router/layer.py | 3 ++- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/semantic_router/indices/local_index.py b/semantic_router/indices/local_index.py index 784c9967..eef81b91 100644 --- a/semantic_router/indices/local_index.py +++ b/semantic_router/indices/local_index.py @@ -1,21 +1,17 @@ import numpy as np from typing import List, Any -from .base import BaseIndex from semantic_router.linear import similarity_matrix, top_scores -from typing import Tuple +from pydantic import BaseModel +import numpy as np +from typing import List, Any, Tuple, Optional -class LocalIndex(BaseIndex): - """ - Local index implementation using numpy arrays. - """ +class LocalIndex(BaseModel): + index: Optional[np.ndarray] = None - def __init__(self): - self.index = None + class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints. + arbitrary_types_allowed = True def add(self, embeds: List[Any]): - """ - Add items to the index. - """ embeds = np.array(embeds) if self.index is None: self.index = embeds diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 487d8b24..ff1168f2 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -13,6 +13,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index from semantic_router.utils.logger import logger +from semantic_router.indices.local_index import LocalIndex IndexType = Union[LocalIndex, None] @@ -166,7 +167,7 @@ class RouteLayer: index_name: Optional[str] = None, ): logger.info("local") - self.index = Index.get_by_name(index_name="index") + self.index = Index.get_by_name(index_name="local") self.categories = None if encoder is None: logger.warning( -- GitLab