diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index e42f16305781583163001d276cfc908f74b5cedd..37408ecde152591fdbd5865240a20562b15b94c1 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -1242,6 +1242,7 @@ class BaseRouter(BaseModel): if threshold is None: return True if scores: + # TODO JB is this correct? return max(scores) > threshold else: return False diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index c4e5dab524d51da90aab8f519571ace68ccfff61..87e889381fe0c84f2d9ece0a14a5c8e2cb1a3fa4 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -904,8 +904,8 @@ class TestSemanticRouter: "index_cls,encoder_cls,router_cls", [ (index, encoder, router) - for index in [LocalIndex] - for encoder in [OpenAIEncoder] + for index in [LocalIndex] # no need to test with multiple indexes + for encoder in [OpenAIEncoder] # no need to test with multiple encoders for router in get_test_routers() ], ) @@ -984,7 +984,10 @@ class TestRouterOnly: index=index, auto_sync="local", ) - assert route_layer.score_threshold == 0.3 + if router_cls is HybridRouter: + assert route_layer.score_threshold == 0.3 * route_layer.alpha + else: + assert route_layer.score_threshold == 0.3 def test_json(self, routes, index_cls, encoder_cls, router_cls): temp = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False)