diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index cd9f7ccb65f50a2ac3eb627278a474d78eea9c41..5273f531c90bc28167017ae5b86366268408cb3b 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -16,11 +16,21 @@ class HybridRouteLayer: score_threshold: float def __init__( - self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 + self, + encoder: BaseEncoder, + sparse_encoder: BM25Encoder | None = None, + routes: list[Route] = [], + alpha: float = 0.3, ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold - self.sparse_encoder = BM25Encoder() + + if sparse_encoder is None: + logger.warning("No sparse_encoder provided. Using default BM25Encoder.") + self.sparse_encoder = BM25Encoder() + else: + self.sparse_encoder = sparse_encoder + self.alpha = alpha # if routes list has been passed, we initialize index now if routes: