Skip to content
Snippets Groups Projects
Unverified Commit 041341e3 authored by James Briggs's avatar James Briggs
Browse files

add new tests

parent ac2a527e
No related branches found
No related tags found
No related merge requests found
......@@ -204,10 +204,13 @@ class RouteLayer:
if text is None:
raise ValueError("Either text or vector must be provided")
vector = self._encode(text=text)
else:
vector = np.array(vector)
# get relevant utterances
results = self._retrieve(xq=vector)
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
if route is None:
return RouteChoice()
......@@ -218,6 +221,10 @@ class RouteLayer:
)
passed = self._pass_threshold(top_class_scores, threshold)
if passed:
if route.function_schema and text is None:
raise ValueError(
"Route has a function schema, but no text was provided."
)
if route.function_schema and not isinstance(route.llm, BaseLLM):
if not self.llm:
logger.warning(
......@@ -228,10 +235,6 @@ class RouteLayer:
self.llm = OpenAILLM()
route.llm = self.llm
elif text is None:
raise ValueError(
"Text must be provided to use dynamic route with function_schema"
)
else:
route.llm = self.llm
return route(text)
......@@ -382,7 +385,7 @@ class RouteLayer:
config = self.to_config()
config.to_file(file_path)
def get_route_thresholds(self) -> Dict[str, float]:
def get_thresholds(self) -> Dict[str, float]:
# TODO: float() below is hacky fix for lint, fix this with new type?
thresholds = {
route.name: float(route.score_threshold or self.score_threshold)
......@@ -400,7 +403,7 @@ class RouteLayer:
Xq: Any = np.array(self.encoder(X))
# initial eval (we will iterate from here)
best_acc = self._vec_evaluate(Xq=Xq, y=y)
best_thresholds = self.get_route_thresholds()
best_thresholds = self.get_thresholds()
# begin fit
for _ in (pbar := tqdm(range(max_iter))):
pbar.set_postfix({"acc": round(best_acc, 2)})
......@@ -447,7 +450,7 @@ def threshold_random_search(
) -> Dict[str, float]:
"""Performs a random search iteration given a route layer and a search range."""
# extract the route names
routes = route_layer.get_route_thresholds()
routes = route_layer.get_thresholds()
route_names = list(routes.keys())
route_thresholds = list(routes.values())
# generate search range for each
......
......@@ -95,8 +95,18 @@ def routes():
@pytest.fixture
def dynamic_routes():
return [
Route(name="Route 1", utterances=["Hello", "Hi"], function_schema="test"),
Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]),
Route(name="Route 1", utterances=["Hello", "Hi"], function_schema={"name": "test"}),
Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"], function_schema={"name": "test"}),
]
@pytest.fixture
def test_data():
return [
("What's your opinion on the current government?", "politics"),
("what's the weather like today?", "chitchat"),
("what is the Pythagorean theorem?", "mathematics"),
("what is photosynthesis?", "biology"),
("tell me an interesting fact", None)
]
......@@ -124,10 +134,10 @@ class TestRouteLayer:
route_layer_none = RouteLayer(encoder=None)
assert route_layer_none.score_threshold == openai_encoder.score_threshold
def test_initialization_dynamic_route(self, cohere_encoder, openai_encoder):
route_layer_cohere = RouteLayer(encoder=cohere_encoder)
def test_initialization_dynamic_route(self, cohere_encoder, openai_encoder, dynamic_routes):
route_layer_cohere = RouteLayer(encoder=cohere_encoder, routes=dynamic_routes)
assert route_layer_cohere.score_threshold == 0.3
route_layer_openai = RouteLayer(encoder=openai_encoder)
route_layer_openai = RouteLayer(encoder=openai_encoder, routes=dynamic_routes)
assert openai_encoder.score_threshold == 0.82
assert route_layer_openai.score_threshold == 0.82
......@@ -157,12 +167,23 @@ class TestRouteLayer:
def test_query_and_classification(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
query_result = route_layer("Hello").name
query_result = route_layer(text="Hello").name
assert query_result in ["Route 1", "Route 2"]
def test_query_with_no_index(self, openai_encoder):
route_layer = RouteLayer(encoder=openai_encoder)
assert route_layer("Anything").name is None
assert route_layer(text="Anything").name is None
def test_query_with_vector(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
vector = [0.1, 0.2, 0.3]
query_result = route_layer(vector=vector).name
assert query_result in ["Route 1", "Route 2"]
def test_query_with_no_text_or_vector(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
with pytest.raises(ValueError):
route_layer()
def test_semantic_classify(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
......@@ -186,6 +207,12 @@ class TestRouteLayer:
)
assert classification == "Route 1"
assert score == [0.9, 0.8]
def test_query_no_text_dynamic_route(self, openai_encoder, dynamic_routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=dynamic_routes)
vector = [0.1, 0.2, 0.3]
with pytest.raises(ValueError):
route_layer(vector=vector)
def test_pass_threshold(self, openai_encoder):
route_layer = RouteLayer(encoder=openai_encoder)
......@@ -234,6 +261,24 @@ class TestRouteLayer:
assert (route_layer_from_config.categories == route_layer.categories).all()
assert route_layer_from_config.score_threshold == route_layer.score_threshold
def test_get_thresholds(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
assert route_layer.get_thresholds() == {'Route 1': 0.82, 'Route 2': 0.82}
class TestLayerFit:
def test_eval(self, openai_encoder, routes, test_data):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
# unpack test data
X, y = zip(*test_data)
# evaluate
route_layer.evaluate(X=X, y=y)
def test_fit(self, openai_encoder, routes, test_data):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
# unpack test data
X, y = zip(*test_data)
route_layer.fit(X=X, y=y)
# Add more tests for edge cases and error handling as needed.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment