From d551201404a2ed2c67e989bb55487d2f5a2f976e Mon Sep 17 00:00:00 2001 From: Ismail Ashraq <issey1455@gmail.com> Date: Sun, 7 Jan 2024 19:15:03 +0500 Subject: [PATCH] Update hybrid layer to use score_threshold from encoders --- semantic_router/encoders/bm25.py | 6 +++++- semantic_router/hybrid_layer.py | 12 ++---------- tests/unit/test_hybrid_layer.py | 6 ++++-- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 68150cb7..69ca58ec 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -19,7 +19,11 @@ class BM25Encoder(BaseEncoder): "You can install it with: `pip install semantic-router[hybrid]`" ) logger.info("Downloading and initializing BM25 model parameters.") - self.model = encoder.default() + # self.model = encoder.default() + self.model = encoder() + self.model.fit( + corpus=["test test", "this is another message", "hello how are you"] + ) params = self.model.get_params() doc_freq = params["doc_freq"] diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index fc63cfa6..cd9f7ccb 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -4,8 +4,6 @@ from numpy.linalg import norm from semantic_router.encoders import ( BaseEncoder, BM25Encoder, - CohereEncoder, - OpenAIEncoder, ) from semantic_router.route import Route from semantic_router.utils.logger import logger @@ -15,21 +13,15 @@ class HybridRouteLayer: index = None sparse_index = None categories = None - score_threshold = 0.82 + score_threshold: float def __init__( self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3 ): self.encoder = encoder + self.score_threshold = self.encoder.score_threshold self.sparse_encoder = BM25Encoder() self.alpha = alpha - # decide on default threshold based on encoder - if isinstance(encoder, OpenAIEncoder): - self.score_threshold = 0.82 - elif isinstance(encoder, CohereEncoder): - self.score_threshold = 0.3 - else: - self.score_threshold = 0.82 # if routes list has been passed, we initialize index now if routes: # initialize index now diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index f87cb1d2..6896c4de 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -19,7 +19,7 @@ def mock_encoder_call(utterances): @pytest.fixture def base_encoder(): - return BaseEncoder(name="test-encoder") + return BaseEncoder(name="test-encoder", score_threshold=0.5) @pytest.fixture @@ -46,6 +46,7 @@ class TestHybridRouteLayer: def test_initialization(self, openai_encoder, routes): route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) assert route_layer.index is not None and route_layer.categories is not None + assert openai_encoder.score_threshold == 0.82 assert route_layer.score_threshold == 0.82 assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 @@ -112,7 +113,8 @@ class TestHybridRouteLayer: def test_failover_score_threshold(self, base_encoder): route_layer = HybridRouteLayer(encoder=base_encoder) - assert route_layer.score_threshold == 0.82 + assert base_encoder.score_threshold == 0.50 + assert route_layer.score_threshold == 0.50 # Add more tests for edge cases and error handling as needed. -- GitLab