diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 68150cb767faec1a8c3c5417f9d5f6db3a977cda..69ca58eca7842e31507ca989b02ea2da4ba8ee5b 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -19,7 +19,11 @@ class BM25Encoder(BaseEncoder): "You can install it with: `pip install semantic-router[hybrid]`" ) logger.info("Downloading and initializing BM25 model parameters.") - self.model = encoder.default() + # self.model = encoder.default() + self.model = encoder() + self.model.fit( + corpus=["test test", "this is another message", "hello how are you"] + ) params = self.model.get_params() doc_freq = params["doc_freq"] diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index fc63cfa6ec2dc618849ffb54c0a49b88288470dd..cd9f7ccb65f50a2ac3eb627278a474d78eea9c41 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -4,8 +4,6 @@ from numpy.linalg import norm from semantic_router.encoders import ( BaseEncoder, BM25Encoder, - CohereEncoder, - OpenAIEncoder, ) from semantic_router.route import Route from semantic_router.utils.logger import logger @@ -15,21 +13,15 @@ class HybridRouteLayer: index = None sparse_index = None categories = None - score_threshold = 0.82 + score_threshold: float def __init__( self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 ): self.encoder = encoder + self.score_threshold = self.encoder.score_threshold self.sparse_encoder = BM25Encoder() self.alpha = alpha - # decide on default threshold based on encoder - if isinstance(encoder, OpenAIEncoder): - self.score_threshold = 0.82 - elif isinstance(encoder, CohereEncoder): - self.score_threshold = 0.3 - else: - self.score_threshold = 0.82 # if routes list has been passed, we initialize index now if routes: # initialize index now diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index f87cb1d281b2884a10a9817b9a838c21e64a9881..6896c4de1cb1e13196d209455f2bd39e8e14915d 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -19,7 +19,7 @@ def mock_encoder_call(utterances): @pytest.fixture def base_encoder(): - return BaseEncoder(name="test-encoder") + return BaseEncoder(name="test-encoder", score_threshold=0.5) @pytest.fixture @@ -46,6 +46,7 @@ class TestHybridRouteLayer: def test_initialization(self, openai_encoder, routes): route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) assert route_layer.index is not None and route_layer.categories is not None + assert openai_encoder.score_threshold == 0.82 assert route_layer.score_threshold == 0.82 assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 @@ -112,7 +113,8 @@ class TestHybridRouteLayer: def test_failover_score_threshold(self, base_encoder): route_layer = HybridRouteLayer(encoder=base_encoder) - assert route_layer.score_threshold == 0.82 + assert base_encoder.score_threshold == 0.50 + assert route_layer.score_threshold == 0.50 # Add more tests for edge cases and error handling as needed.