diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 25cb8b6b29bfc02b703a40eece793a865c3aeefc..fc789351bd453e459417f696e1439bb8622d2374 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -11,7 +11,12 @@ import numpy as np import yaml # type: ignore from tqdm.auto import tqdm -from semantic_router.encoders import AutoEncoder, DenseEncoder, OpenAIEncoder +from semantic_router.encoders import ( + AutoEncoder, + DenseEncoder, + OpenAIEncoder, + SparseEncoder, +) from semantic_router.index.base import BaseIndex from semantic_router.index.local import LocalIndex from semantic_router.index.pinecone import PineconeIndex @@ -298,6 +303,7 @@ def xq_reshape(xq: List[float] | np.ndarray) -> np.ndarray: class BaseRouter(BaseModel): encoder: DenseEncoder = Field(default_factory=OpenAIEncoder) + sparse_encoder: Optional[SparseEncoder] = Field(default=None) index: BaseIndex = Field(default_factory=BaseIndex) score_threshold: Optional[float] = Field(default=None) routes: List[Route] = Field(default_factory=list) @@ -313,6 +319,7 @@ class BaseRouter(BaseModel): def __init__( self, encoder: Optional[DenseEncoder] = None, + sparse_encoder: Optional[SparseEncoder] = None, llm: Optional[BaseLLM] = None, routes: List[Route] = [], index: Optional[BaseIndex] = None, # type: ignore @@ -322,6 +329,7 @@ class BaseRouter(BaseModel): ): super().__init__( encoder=encoder, + sparse_encoder=sparse_encoder, llm=llm, routes=routes, index=index, @@ -330,6 +338,7 @@ class BaseRouter(BaseModel): auto_sync=auto_sync, ) self.encoder = self._get_encoder(encoder=encoder) + self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder) self.llm = llm self.routes = routes.copy() if routes else [] # initialize index @@ -370,6 +379,15 @@ class BaseRouter(BaseModel): encoder = encoder return encoder + def _get_sparse_encoder( + self, sparse_encoder: Optional[SparseEncoder] + ) -> Optional[SparseEncoder]: + if sparse_encoder is None: + return None + raise NotImplementedError( + f"Sparse encoder not implemented for {self.__class__.__name__}" + ) + def _init_index_state(self): """Initializes an index (where required) and runs auto_sync if active.""" print("JBTEMP _init_index_state") diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index 70e0044022d423b91421f01d79b6b790d06fb5e4..7a1591d4f3a3d4c286506539f07d18b80c15d351 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -42,8 +42,11 @@ class HybridRouter(BaseRouter): logger.warning("No index provided. Using default HybridLocalIndex.") index = HybridLocalIndex() encoder = self._get_encoder(encoder=encoder) + # initialize sparse encoder + sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder) super().__init__( encoder=encoder, + sparse_encoder=sparse_encoder, llm=llm, routes=routes, index=index, @@ -51,8 +54,6 @@ class HybridRouter(BaseRouter): aggregation=aggregation, auto_sync=auto_sync, ) - # initialize sparse encoder - self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder) # set alpha self.alpha = alpha # fit sparse encoder if needed @@ -162,7 +163,7 @@ class HybridRouter(BaseRouter): def _get_sparse_encoder( self, sparse_encoder: Optional[SparseEncoder] - ) -> SparseEncoder: + ) -> Optional[SparseEncoder]: if sparse_encoder is None: logger.warning("No sparse_encoder provided. Using default BM25Encoder.") sparse_encoder = BM25Encoder()