From b9a4f2c4c9563a4a9ca7ce7ab8e876a2251f5f0a Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Wed, 14 Feb 2024 03:24:36 +0400 Subject: [PATCH] fix tests --- tests/unit/test_layer.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 9ddb1fe0..e5ba76a2 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -157,14 +157,18 @@ class TestRouteLayer: # Initially, the routes list should be empty assert route_layer.routes == [] + # Add route1 and check route_layer.add(route=route1) assert route_layer.routes == [route1] - # TODO add length check + assert route_layer.index is not None + # Use the describe method to get the number of vectors + assert route_layer.index.describe()["vectors"] == 2 + # Add route2 and check route_layer.add(route=route2) assert route_layer.routes == [route1, route2] - # TODO add length check + assert route_layer.index.describe()["vectors"] == 4 def test_list_route_names(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -202,7 +206,7 @@ class TestRouteLayer: route_layer = RouteLayer(encoder=openai_encoder) route_layer._add_routes(routes=routes) assert route_layer.index is not None - # TODO add length check + assert route_layer.index.describe()["vectors"] == 5 def test_query_and_classification(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -211,7 +215,8 @@ class TestRouteLayer: def test_query_with_no_index(self, openai_encoder): route_layer = RouteLayer(encoder=openai_encoder) - assert route_layer(text="Anything").name is None + with pytest.raises(ValueError): + assert route_layer(text="Anything").name is None def test_query_with_vector(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -296,7 +301,7 @@ class TestRouteLayer: assert layer_config.routes == routes # now load from config and confirm it's the same route_layer_from_config = RouteLayer.from_config(layer_config) - assert route_layer_from_config.index == route_layer.index + assert (route_layer_from_config.index.index == route_layer.index.index).all() assert route_layer_from_config._get_route_names() == route_layer._get_route_names() assert route_layer_from_config.score_threshold == route_layer.score_threshold -- GitLab