diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index cd4d14204d0cb58d3fdd8bdbdea709d2e33d1a6f..cfdc1f6bacc79e0d3b60b7f492837f542d082453 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -295,7 +295,13 @@ class TestIndexEncoders: encoder = encoder_cls() index = init_index(index_cls, index_name=encoder.__class__.__name__) route_layer = router_cls(encoder=encoder, index=index) - assert route_layer.score_threshold == encoder.score_threshold + if isinstance(route_layer, HybridRouter): + assert ( + route_layer.score_threshold + == encoder.score_threshold * route_layer.alpha + ) + else: + assert route_layer.score_threshold == encoder.score_threshold def test_initialization_no_encoder(self, index_cls, encoder_cls, router_cls): os.environ["OPENAI_API_KEY"] = "test_api_key"