diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index 4698e476a9bf1cd70336d8f790bdc4899f30cf45..5b0312dda93a1e4266486470e82f0ef5780b7902 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -218,9 +218,9 @@ def get_test_routers(): ], ) class TestIndexEncoders: - def test_initialization(self, routes, openai_encoder, index_cls, encoder_cls): + def test_initialization(self, routes, openai_encoder, index_cls, encoder_cls, router_cls): index = init_index(index_cls) - route_layer = SemanticRouter( + route_layer = router_cls( encoder=encoder_cls(), routes=routes, index=index, @@ -240,15 +240,15 @@ class TestIndexEncoders: else 0 == 2 ) - def test_initialization_different_encoders(self, encoder_cls, index_cls): + def test_initialization_different_encoders(self, encoder_cls, index_cls, router_cls): index = init_index(index_cls) encoder = encoder_cls() - route_layer = SemanticRouter(encoder=encoder, index=index) + route_layer = router_cls(encoder=encoder, index=index) assert route_layer.score_threshold == encoder.score_threshold - def test_initialization_no_encoder(self, openai_encoder, index_cls, encoder_cls): + def test_initialization_no_encoder(self, openai_encoder, index_cls, router_cls): os.environ["OPENAI_API_KEY"] = "test_api_key" - route_layer_none = SemanticRouter(encoder=None) + route_layer_none = router_cls(encoder=None) assert route_layer_none.score_threshold == openai_encoder.score_threshold