From dd01f68c9beab5f62b0e4dcdcf0a342975c174f8 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sat, 4 Jan 2025 12:37:52 +0400 Subject: [PATCH] fix: hybrid router encoder score tweak --- semantic_router/routers/hybrid.py | 19 +++++++++++++++++++ tests/unit/test_router.py | 25 ++++++++++++++++++++----- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py index 8ab31285..e2f0a23e 100644 --- a/semantic_router/routers/hybrid.py +++ b/semantic_router/routers/hybrid.py @@ -65,6 +65,21 @@ class HybridRouter(BaseRouter): if self.auto_sync: self._init_index_state() + def _set_score_threshold(self): + """Set the score threshold for the HybridRouter. Unlike the base router the + encoder score threshold is not used directly. Instead, the dense encoder + score threshold is multiplied by the alpha value, resulting in a lower + score threshold. This is done to account for the difference in returned + scores from the hybrid router. + """ + if self.encoder.score_threshold is not None: + self.score_threshold = self.encoder.score_threshold * self.alpha + if self.score_threshold is None: + logger.warning( + "No score threshold value found in encoder. Using the default " + "'None' value can lead to unexpected results." + ) + def add(self, routes: List[Route] | Route): """Add a route to the local HybridRouter and index. @@ -226,6 +241,8 @@ class HybridRouter(BaseRouter): route_filter=route_filter, sparse_vector=sparse_vector, ) + logger.warning(f"JBTEMP: {scores}") + logger.warning(f"JBTEMP: {route_names}") query_results = [ {"route": d, "score": s.item()} for d, s in zip(route_names, scores) ] @@ -234,6 +251,8 @@ class HybridRouter(BaseRouter): top_class, top_class_scores = self._semantic_classify( query_results=query_results ) + logger.warning(f"JBTEMP: {top_class}") + logger.warning(f"JBTEMP: {top_class_scores}") passed = self._pass_threshold(top_class_scores, self.score_threshold) if passed: return RouteChoice(name=top_class, similarity_score=max(top_class_scores)) diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index 15743e7d..c460f121 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -269,7 +269,10 @@ class TestIndexEncoders: if index_cls is PineconeIndex: time.sleep(PINECONE_SLEEP) # allow for index to be populated - assert route_layer.score_threshold == encoder.score_threshold + if isinstance(route_layer, HybridRouter): + assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha + else: + assert route_layer.score_threshold == encoder.score_threshold assert route_layer.top_k == 10 assert len(route_layer.index) == 5 assert ( @@ -289,7 +292,10 @@ class TestIndexEncoders: def test_initialization_no_encoder(self, index_cls, encoder_cls, router_cls): os.environ["OPENAI_API_KEY"] = "test_api_key" route_layer_none = router_cls(encoder=None) - assert route_layer_none.score_threshold == 0.3 + if isinstance(route_layer_none, HybridRouter): + assert route_layer_none.score_threshold == 0.3 * route_layer_none.alpha + else: + assert route_layer_none.score_threshold == 0.3 class TestRouterConfig: @@ -525,7 +531,10 @@ class TestSemanticRouter: index=index, auto_sync="local", ) - assert route_layer.score_threshold == encoder.score_threshold + if isinstance(route_layer, HybridRouter): + assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha + else: + assert route_layer.score_threshold == encoder.score_threshold def test_add_single_utterance( self, routes, route_single_utterance, index_cls, encoder_cls, router_cls @@ -539,7 +548,10 @@ class TestSemanticRouter: auto_sync="local", ) route_layer.add(routes=route_single_utterance) - assert route_layer.score_threshold == encoder.score_threshold + if isinstance(route_layer, HybridRouter): + assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha + else: + assert route_layer.score_threshold == encoder.score_threshold if index_cls is PineconeIndex: time.sleep(PINECONE_SLEEP) # allow for index to be updated _ = route_layer("Hello") @@ -558,7 +570,10 @@ class TestSemanticRouter: if index_cls is PineconeIndex: time.sleep(PINECONE_SLEEP) # allow for index to be updated route_layer.add(routes=route_single_utterance) - assert route_layer.score_threshold == encoder.score_threshold + if isinstance(route_layer, HybridRouter): + assert route_layer.score_threshold == encoder.score_threshold * route_layer.alpha + else: + assert route_layer.score_threshold == encoder.score_threshold _ = route_layer("Hello") assert len(route_layer.index.get_utterances()) == 1 -- GitLab