Skip to content
Snippets Groups Projects
Unverified Commit 52770a55 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Simplified tests.

parent 90041092
No related branches found
No related tags found
No related merge requests found
......@@ -503,17 +503,13 @@ class TestRouteLayer:
)
assert route_layer.get_thresholds() == {"Route 1": 0.82, "Route 2": 0.82}
@pytest.fixture
def route_layer(self, openai_encoder, routes, index_cls):
# Initialize RouteLayer with mocked routes and a mock index class
def test_with_multiple_routes_passing_threshold(
self, openai_encoder, routes, index_cls
):
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index_cls()
)
# Manually set the score_threshold for testing
route_layer.score_threshold = 0.5
return route_layer
def test_with_multiple_routes_passing_threshold(self, route_layer, index_cls):
route_layer.score_threshold = 0.5 # Set the score_threshold if needed
# Assuming route_layer is already set up with routes "Route 1" and "Route 2"
query_results = [
{"route": "Route 1", "score": 0.6},
......@@ -528,7 +524,11 @@ class TestRouteLayer:
expected
), "Should classify and return routes above their thresholds"
def test_with_no_routes_passing_threshold(self, route_layer, index_cls):
def test_with_no_routes_passing_threshold(self, openai_encoder, routes, index_cls):
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index_cls()
)
route_layer.score_threshold = 0.5
# Override _pass_threshold to always return False for this test
route_layer._pass_threshold = lambda scores, threshold: False
query_results = [
......@@ -541,7 +541,11 @@ class TestRouteLayer:
results == expected
), "Should return an empty list when no routes pass their thresholds"
def test_with_no_query_results(self, route_layer, index_cls):
def test_with_no_query_results(self, openai_encoder, routes, index_cls):
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index_cls()
)
route_layer.score_threshold = 0.5
query_results = []
expected = []
results = route_layer._semantic_classify_multiple_routes(query_results)
......@@ -549,7 +553,11 @@ class TestRouteLayer:
results == expected
), "Should return an empty list when there are no query results"
def test_with_unrecognized_route(self, route_layer, index_cls):
def test_with_unrecognized_route(self, openai_encoder, routes, index_cls):
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=index_cls()
)
route_layer.score_threshold = 0.5
# Test with a route name that does not exist in the route_layer's routes
query_results = [{"route": "UnrecognizedRoute", "score": 0.9}]
expected = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment