Skip to content
Snippets Groups Projects
Unverified Commit b9a4f2c4 authored by James Briggs's avatar James Briggs
Browse files

fix tests

parent cbf7e977
No related branches found
No related tags found
No related merge requests found
......@@ -157,14 +157,18 @@ class TestRouteLayer:
# Initially, the routes list should be empty
assert route_layer.routes == []
# Add route1 and check
route_layer.add(route=route1)
assert route_layer.routes == [route1]
# TODO add length check
assert route_layer.index is not None
# Use the describe method to get the number of vectors
assert route_layer.index.describe()["vectors"] == 2
# Add route2 and check
route_layer.add(route=route2)
assert route_layer.routes == [route1, route2]
# TODO add length check
assert route_layer.index.describe()["vectors"] == 4
def test_list_route_names(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
......@@ -202,7 +206,7 @@ class TestRouteLayer:
route_layer = RouteLayer(encoder=openai_encoder)
route_layer._add_routes(routes=routes)
assert route_layer.index is not None
# TODO add length check
assert route_layer.index.describe()["vectors"] == 5
def test_query_and_classification(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
......@@ -211,7 +215,8 @@ class TestRouteLayer:
def test_query_with_no_index(self, openai_encoder):
route_layer = RouteLayer(encoder=openai_encoder)
assert route_layer(text="Anything").name is None
with pytest.raises(ValueError):
assert route_layer(text="Anything").name is None
def test_query_with_vector(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
......@@ -296,7 +301,7 @@ class TestRouteLayer:
assert layer_config.routes == routes
# now load from config and confirm it's the same
route_layer_from_config = RouteLayer.from_config(layer_config)
assert route_layer_from_config.index == route_layer.index
assert (route_layer_from_config.index.index == route_layer.index.index).all()
assert route_layer_from_config._get_route_names() == route_layer._get_route_names()
assert route_layer_from_config.score_threshold == route_layer.score_threshold
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment