Skip to content
Snippets Groups Projects
Unverified Commit 0586c670 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

More pytests.

parent 0866dccd
No related branches found
No related tags found
No related merge requests found
...@@ -811,6 +811,36 @@ class TestRouteLayer: ...@@ -811,6 +811,36 @@ class TestRouteLayer:
): ):
route_layer._refresh_routes() route_layer._refresh_routes()
def test_update_threshold(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
route_name = "Route 1"
new_threshold = 0.8
route_layer.update(name=route_name, threshold=new_threshold)
updated_route = route_layer.get(route_name)
assert updated_route.score_threshold == new_threshold, f"Expected threshold to be updated to {new_threshold}, but got {updated_route.score_threshold}"
def test_update_non_existent_route(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
non_existent_route = "Non-existent Route"
with pytest.raises(ValueError, match=f"Route '{non_existent_route}' not found. Nothing updated."):
route_layer.update(name=non_existent_route, threshold=0.7)
def test_update_without_parameters(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
with pytest.raises(ValueError, match="At least one of 'threshold' or 'utterances' must be provided."):
route_layer.update(name="Route 1")
def test_update_utterances_not_implemented(self, openai_encoder, routes, index_cls):
index = init_index(index_cls)
route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index)
with pytest.raises(NotImplementedError, match="The update method cannot be used for updating utterances yet."):
route_layer.update(name="Route 1", utterances=["New utterance"])
class TestLayerFit: class TestLayerFit:
def test_eval(self, openai_encoder, routes, test_data): def test_eval(self, openai_encoder, routes, test_data):
...@@ -948,3 +978,4 @@ class TestLayerConfig: ...@@ -948,3 +978,4 @@ class TestLayerConfig:
elif agg == "max": elif agg == "max":
assert classification == "Route 3" assert classification == "Route 3"
assert score == [0.1, 1.0] assert score == [0.1, 1.0]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment