From 172a486cafc4631ea23aeac1d3be65003841be75 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Tue, 6 Feb 2024 23:31:54 +0400 Subject: [PATCH] LocalIndex Now Using Placeholder BaseIndex Class Also fixed bug in layer.py where we weren't: a) Setting the index_name to a default value. b) Weren't using index_name. --- semantic_router/indices/base.py | 1 - semantic_router/indices/local_index.py | 4 ++-- semantic_router/layer.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/semantic_router/indices/base.py b/semantic_router/indices/base.py index f5c3e8d6..bfaa903e 100644 --- a/semantic_router/indices/base.py +++ b/semantic_router/indices/base.py @@ -1,5 +1,4 @@ from pydantic.v1 import BaseModel class BaseIndex(BaseModel): - pass \ No newline at end of file diff --git a/semantic_router/indices/local_index.py b/semantic_router/indices/local_index.py index eef81b91..eb5d63ef 100644 --- a/semantic_router/indices/local_index.py +++ b/semantic_router/indices/local_index.py @@ -1,11 +1,11 @@ import numpy as np from typing import List, Any from semantic_router.linear import similarity_matrix, top_scores -from pydantic import BaseModel +from semantic_router.indices.base import BaseIndex import numpy as np from typing import List, Any, Tuple, Optional -class LocalIndex(BaseModel): +class LocalIndex(BaseIndex): index: Optional[np.ndarray] = None class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints. diff --git a/semantic_router/layer.py b/semantic_router/layer.py index ff1168f2..8b7c485d 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -164,10 +164,10 @@ class RouteLayer: encoder: Optional[BaseEncoder] = None, llm: Optional[BaseLLM] = None, routes: Optional[List[Route]] = None, - index_name: Optional[str] = None, + index_name: Optional[str] = "local", ): logger.info("local") - self.index = Index.get_by_name(index_name="local") + self.index = Index.get_by_name(index_name=index_name) self.categories = None if encoder is None: logger.warning( -- GitLab