From dd01f68c9beab5f62b0e4dcdcf0a342975c174f8 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sat, 4 Jan 2025 12:37:52 +0400
Subject: [PATCH] fix: hybrid router encoder score tweak

---
 semantic_router/routers/hybrid.py | 19 +++++++++++++++++++
 tests/unit/test_router.py         | 25 ++++++++++++++++++++-----
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index 8ab31285..e2f0a23e 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -65,6 +65,21 @@ class HybridRouter(BaseRouter):
         if self.auto_sync:
             self._init_index_state()
 
+    def _set_score_threshold(self):
+        """Set the router score threshold from the encoder threshold and alpha.
+
+        Unlike the base router, the encoder score threshold is not used directly:
+        it is multiplied by `self.alpha`, because hybrid scores differ from the
+        dense-only scores. Warns when no threshold could be determined.
+        """
+        if self.encoder.score_threshold is not None:
+            self.score_threshold = self.encoder.score_threshold * self.alpha
+        elif self.score_threshold is None:  # no encoder value and none preset
+            logger.warning(
+                "No score threshold value found in encoder. Using the default "
+                "'None' value can lead to unexpected results."
+            )
+
     def add(self, routes: List[Route] | Route):
         """Add a route to the local HybridRouter and index.
 
@@ -226,6 +241,8 @@ class HybridRouter(BaseRouter):
             route_filter=route_filter,
             sparse_vector=sparse_vector,
         )
+        logger.warning(f"JBTEMP: {scores}")
+        logger.warning(f"JBTEMP: {route_names}")
         query_results = [
             {"route": d, "score": s.item()} for d, s in zip(route_names, scores)
         ]
@@ -234,6 +251,8 @@ class HybridRouter(BaseRouter):
         top_class, top_class_scores = self._semantic_classify(
             query_results=query_results
         )
+        logger.warning(f"JBTEMP: {top_class}")
+        logger.warning(f"JBTEMP: {top_class_scores}")
         passed = self._pass_threshold(top_class_scores, self.score_threshold)
         if passed:
             return RouteChoice(name=top_class, similarity_score=max(top_class_scores))
diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py
index 15743e7d..c460f121 100644
--- a/tests/unit/test_router.py
+++ b/tests/unit/test_router.py
@@ -269,7 +269,10 @@ class TestIndexEncoders:
         if index_cls is PineconeIndex:
             time.sleep(PINECONE_SLEEP)  # allow for index to be populated
 
-        assert route_layer.score_threshold == encoder.score_threshold
+        if isinstance(route_layer, HybridRouter):
+            assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha
+        else:
+            assert route_layer.score_threshold == encoder.score_threshold
         assert route_layer.top_k == 10
         assert len(route_layer.index) == 5
         assert (
@@ -289,7 +292,10 @@ class TestIndexEncoders:
     def test_initialization_no_encoder(self, index_cls, encoder_cls, router_cls):
         os.environ["OPENAI_API_KEY"] = "test_api_key"
         route_layer_none = router_cls(encoder=None)
-        assert route_layer_none.score_threshold == 0.3
+        if isinstance(route_layer_none, HybridRouter):  # threshold scaled by alpha
+            assert route_layer_none.score_threshold == 0.3 * route_layer_none.alpha
+        else:
+            assert route_layer_none.score_threshold == 0.3
 
 
 class TestRouterConfig:
@@ -525,7 +531,10 @@ class TestSemanticRouter:
             index=index,
             auto_sync="local",
         )
-        assert route_layer.score_threshold == encoder.score_threshold
+        if isinstance(route_layer, HybridRouter):
+            assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha
+        else:
+            assert route_layer.score_threshold == encoder.score_threshold
 
     def test_add_single_utterance(
         self, routes, route_single_utterance, index_cls, encoder_cls, router_cls
@@ -539,7 +548,10 @@ class TestSemanticRouter:
             auto_sync="local",
         )
         route_layer.add(routes=route_single_utterance)
-        assert route_layer.score_threshold == encoder.score_threshold
+        if isinstance(route_layer, HybridRouter):
+            assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha
+        else:
+            assert route_layer.score_threshold == encoder.score_threshold
         if index_cls is PineconeIndex:
             time.sleep(PINECONE_SLEEP)  # allow for index to be updated
         _ = route_layer("Hello")
@@ -558,7 +570,10 @@ class TestSemanticRouter:
         if index_cls is PineconeIndex:
             time.sleep(PINECONE_SLEEP)  # allow for index to be updated
         route_layer.add(routes=route_single_utterance)
-        assert route_layer.score_threshold == encoder.score_threshold
+        if isinstance(route_layer, HybridRouter):
+            assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha
+        else:
+            assert route_layer.score_threshold == encoder.score_threshold
         _ = route_layer("Hello")
         assert len(route_layer.index.get_utterances()) == 1
 
-- 
GitLab