diff --git a/semantic_router/indices/base.py b/semantic_router/indices/base.py index 7b48f65d70b18a57baa623c0fa9b0335c90f26af..baa5a2041e6a6a293274135012a5fa0555c48742 100644 --- a/semantic_router/indices/base.py +++ b/semantic_router/indices/base.py @@ -1,5 +1,6 @@ from pydantic.v1 import BaseModel + class BaseIndex(BaseModel): # Currently just a placedholder until more indexing methods are added and common attributes/methods are identified. - pass \ No newline at end of file + pass diff --git a/semantic_router/indices/local_index.py b/semantic_router/indices/local_index.py index eb5d63ef19ed17405fd324a7236e8ba03abfc035..3c0d32dcee605870feb6c4c5d8bf7df05c9ba517 100644 --- a/semantic_router/indices/local_index.py +++ b/semantic_router/indices/local_index.py @@ -1,14 +1,13 @@ import numpy as np -from typing import List, Any +from typing import List, Any, Tuple, Optional from semantic_router.linear import similarity_matrix, top_scores from semantic_router.indices.base import BaseIndex -import numpy as np -from typing import List, Any, Tuple, Optional + class LocalIndex(BaseIndex): index: Optional[np.ndarray] = None - class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints. + class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints. arbitrary_types_allowed = True def add(self, embeds: List[Any]): @@ -27,7 +26,9 @@ class LocalIndex(BaseIndex): def is_index_populated(self): return self.index is not None and len(self.index) > 0 - def search(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]: + def search( + self, query_vector: Any, top_k: int = 5 + ) -> Tuple[np.ndarray, np.ndarray]: """ Search the index for the query and return top_k results. """ @@ -35,4 +36,3 @@ class LocalIndex(BaseIndex): raise ValueError("Index is not populated.") sim = similarity_matrix(query_vector, self.index) return top_scores(sim, top_k) - diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 8b7c485de9844dd0ab9b25a78fccbaac29280aba..3061f29e78f53c0cd263cc01553025fd105d3438 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -8,7 +8,6 @@ import yaml from tqdm.auto import tqdm from semantic_router.encoders import BaseEncoder, OpenAIEncoder -from semantic_router.linear import similarity_matrix, top_scores from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index @@ -17,6 +16,7 @@ from semantic_router.indices.local_index import LocalIndex IndexType = Union[LocalIndex, None] + def is_valid(layer_config: str) -> bool: """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]""" try: diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 917055cb69afee66a2f1581b683c013d756b33e7..95e29bfe43b26732da5a0e0a0898ad70a56ff46d 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -84,4 +84,4 @@ class Index: return LocalIndex() # TODO: Later we'll add more index options. else: - raise ValueError(f"Invalid index name: {index_name}") \ No newline at end of file + raise ValueError(f"Invalid index name: {index_name}")