Skip to content
Snippets Groups Projects
Unverified Commit 20133d36 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent b0d164c7
No related branches found
No related tags found
No related merge requests found
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
class BaseIndex(BaseModel): class BaseIndex(BaseModel):
# Currently just a placedholder until more indexing methods are added and common attributes/methods are identified. # Currently just a placedholder until more indexing methods are added and common attributes/methods are identified.
pass pass
\ No newline at end of file
import numpy as np import numpy as np
from typing import List, Any from typing import List, Any, Tuple, Optional
from semantic_router.linear import similarity_matrix, top_scores from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.indices.base import BaseIndex from semantic_router.indices.base import BaseIndex
import numpy as np
from typing import List, Any, Tuple, Optional
class LocalIndex(BaseIndex): class LocalIndex(BaseIndex):
index: Optional[np.ndarray] = None index: Optional[np.ndarray] = None
class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints. class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints.
arbitrary_types_allowed = True arbitrary_types_allowed = True
def add(self, embeds: List[Any]): def add(self, embeds: List[Any]):
...@@ -27,7 +26,9 @@ class LocalIndex(BaseIndex): ...@@ -27,7 +26,9 @@ class LocalIndex(BaseIndex):
def is_index_populated(self): def is_index_populated(self):
return self.index is not None and len(self.index) > 0 return self.index is not None and len(self.index) > 0
def search(self, query_vector: Any, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]: def search(
self, query_vector: Any, top_k: int = 5
) -> Tuple[np.ndarray, np.ndarray]:
""" """
Search the index for the query and return top_k results. Search the index for the query and return top_k results.
""" """
...@@ -35,4 +36,3 @@ class LocalIndex(BaseIndex): ...@@ -35,4 +36,3 @@ class LocalIndex(BaseIndex):
raise ValueError("Index is not populated.") raise ValueError("Index is not populated.")
sim = similarity_matrix(query_vector, self.index) sim = similarity_matrix(query_vector, self.index)
return top_scores(sim, top_k) return top_scores(sim, top_k)
...@@ -8,7 +8,6 @@ import yaml ...@@ -8,7 +8,6 @@ import yaml
from tqdm.auto import tqdm from tqdm.auto import tqdm
from semantic_router.encoders import BaseEncoder, OpenAIEncoder from semantic_router.encoders import BaseEncoder, OpenAIEncoder
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
...@@ -17,6 +16,7 @@ from semantic_router.indices.local_index import LocalIndex ...@@ -17,6 +16,7 @@ from semantic_router.indices.local_index import LocalIndex
IndexType = Union[LocalIndex, None] IndexType = Union[LocalIndex, None]
def is_valid(layer_config: str) -> bool: def is_valid(layer_config: str) -> bool:
"""Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]""" """Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
try: try:
......
...@@ -84,4 +84,4 @@ class Index: ...@@ -84,4 +84,4 @@ class Index:
return LocalIndex() return LocalIndex()
# TODO: Later we'll add more index options. # TODO: Later we'll add more index options.
else: else:
raise ValueError(f"Invalid index name: {index_name}") raise ValueError(f"Invalid index name: {index_name}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment