diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 9ddb1fe0f7a14644a8df7ca1a07354154fc3a10b..e5ba76a2f2fbd997cd1db9f1d593aa6e9a834cbe 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -157,14 +157,18 @@ class TestRouteLayer: # Initially, the routes list should be empty assert route_layer.routes == [] + # Add route1 and check route_layer.add(route=route1) assert route_layer.routes == [route1] - # TODO add length check + assert route_layer.index is not None + # Use the describe method to get the number of vectors + assert route_layer.index.describe()["vectors"] == 2 + # Add route2 and check route_layer.add(route=route2) assert route_layer.routes == [route1, route2] - # TODO add length check + assert route_layer.index.describe()["vectors"] == 4 def test_list_route_names(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -202,7 +206,7 @@ class TestRouteLayer: route_layer = RouteLayer(encoder=openai_encoder) route_layer._add_routes(routes=routes) assert route_layer.index is not None - # TODO add length check + assert route_layer.index.describe()["vectors"] == 5 def test_query_and_classification(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -211,7 +215,8 @@ class TestRouteLayer: def test_query_with_no_index(self, openai_encoder): route_layer = RouteLayer(encoder=openai_encoder) - assert route_layer(text="Anything").name is None + with pytest.raises(ValueError): + assert route_layer(text="Anything").name is None def test_query_with_vector(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -296,7 +301,7 @@ class TestRouteLayer: assert layer_config.routes == routes # now load from config and confirm it's the same route_layer_from_config = RouteLayer.from_config(layer_config) - assert route_layer_from_config.index == route_layer.index + assert (route_layer_from_config.index.index == route_layer.index.index).all() assert route_layer_from_config._get_route_names() == route_layer._get_route_names() assert route_layer_from_config.score_threshold == route_layer.score_threshold