Skip to content
Snippets Groups Projects
Unverified Commit 37fe08cb authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Bug Fixes

Making pydantic leave me alone and giving the correct index name.
parent 9b9b1e4c
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from typing import List, Any from typing import List, Any
from .base import BaseIndex
from semantic_router.linear import similarity_matrix, top_scores from semantic_router.linear import similarity_matrix, top_scores
from typing import Tuple from pydantic import BaseModel
import numpy as np
from typing import List, Any, Tuple, Optional
class LocalIndex(BaseIndex): class LocalIndex(BaseModel):
""" index: Optional[np.ndarray] = None
Local index implementation using numpy arrays.
"""
def __init__(self): class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints.
self.index = None arbitrary_types_allowed = True
def add(self, embeds: List[Any]): def add(self, embeds: List[Any]):
"""
Add items to the index.
"""
embeds = np.array(embeds) embeds = np.array(embeds)
if self.index is None: if self.index is None:
self.index = embeds self.index = embeds
......
...@@ -13,6 +13,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM ...@@ -13,6 +13,7 @@ from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
from semantic_router.indices.local_index import LocalIndex
IndexType = Union[LocalIndex, None] IndexType = Union[LocalIndex, None]
...@@ -166,7 +167,7 @@ class RouteLayer: ...@@ -166,7 +167,7 @@ class RouteLayer:
index_name: Optional[str] = None, index_name: Optional[str] = None,
): ):
logger.info("local") logger.info("local")
self.index = Index.get_by_name(index_name="index") self.index = Index.get_by_name(index_name="local")
self.categories = None self.categories = None
if encoder is None: if encoder is None:
logger.warning( logger.warning(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment