From d551201404a2ed2c67e989bb55487d2f5a2f976e Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <issey1455@gmail.com>
Date: Sun, 7 Jan 2024 19:15:03 +0500
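Subject: [PATCH] Update hybrid layer to use score_threshold from encoders

HybridRouteLayer now takes its score_threshold from the encoder it is
given instead of choosing a hard-coded default based on the encoder class,
so the CohereEncoder and OpenAIEncoder imports are no longer needed there.

BM25Encoder no longer calls encoder.default(); it now instantiates the
encoder and fits it on a small hard-coded three-sentence corpus.

The unit tests are updated so the base encoder fixture sets
score_threshold=0.5 and the layer is asserted to pick that value up.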
---
 semantic_router/encoders/bm25.py |  6 +++++-
 semantic_router/hybrid_layer.py  | 12 ++----------
 tests/unit/test_hybrid_layer.py  |  6 ++++--
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py
index 68150cb7..69ca58ec 100644
--- a/semantic_router/encoders/bm25.py
+++ b/semantic_router/encoders/bm25.py
@@ -19,7 +19,11 @@ class BM25Encoder(BaseEncoder):
                 "You can install it with: `pip install semantic-router[hybrid]`"
             )
         logger.info("Downloading and initializing BM25 model parameters.")
-        self.model = encoder.default()
+        # self.model = encoder.default()
+        self.model = encoder()
+        self.model.fit(
+            corpus=["test test", "this is another message", "hello how are you"]
+        )
 
         params = self.model.get_params()
         doc_freq = params["doc_freq"]
diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index fc63cfa6..cd9f7ccb 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -4,8 +4,6 @@ from numpy.linalg import norm
 from semantic_router.encoders import (
     BaseEncoder,
     BM25Encoder,
-    CohereEncoder,
-    OpenAIEncoder,
 )
 from semantic_router.route import Route
 from semantic_router.utils.logger import logger
@@ -15,21 +13,15 @@ class HybridRouteLayer:
     index = None
     sparse_index = None
     categories = None
-    score_threshold = 0.82
+    score_threshold: float
 
     def __init__(
         self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3
     ):
         self.encoder = encoder
+        self.score_threshold = self.encoder.score_threshold
         self.sparse_encoder = BM25Encoder()
         self.alpha = alpha
-        # decide on default threshold based on encoder
-        if isinstance(encoder, OpenAIEncoder):
-            self.score_threshold = 0.82
-        elif isinstance(encoder, CohereEncoder):
-            self.score_threshold = 0.3
-        else:
-            self.score_threshold = 0.82
         # if routes list has been passed, we initialize index now
         if routes:
             # initialize index now
diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py
index f87cb1d2..6896c4de 100644
--- a/tests/unit/test_hybrid_layer.py
+++ b/tests/unit/test_hybrid_layer.py
@@ -19,7 +19,7 @@ def mock_encoder_call(utterances):
 
 @pytest.fixture
 def base_encoder():
-    return BaseEncoder(name="test-encoder")
+    return BaseEncoder(name="test-encoder", score_threshold=0.5)
 
 
 @pytest.fixture
@@ -46,6 +46,7 @@ class TestHybridRouteLayer:
     def test_initialization(self, openai_encoder, routes):
         route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
         assert route_layer.index is not None and route_layer.categories is not None
+        assert openai_encoder.score_threshold == 0.82
         assert route_layer.score_threshold == 0.82
         assert len(route_layer.index) == 5
         assert len(set(route_layer.categories)) == 2
@@ -112,7 +113,8 @@ class TestHybridRouteLayer:
 
     def test_failover_score_threshold(self, base_encoder):
         route_layer = HybridRouteLayer(encoder=base_encoder)
-        assert route_layer.score_threshold == 0.82
+        assert base_encoder.score_threshold == 0.50
+        assert route_layer.score_threshold == 0.50
 
 
 # Add more tests for edge cases and error handling as needed.
-- 
GitLab