From 92e8f8423bd944cd5eea4802c6af34c9e033bb56 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:30:31 +0100 Subject: [PATCH] update tests and added type to encoder --- semantic_router/encoders/base.py | 4 +- tests/unit/test_layer.py | 138 +++++++++++++++++++++++++++++-- tests/unit/test_route.py | 77 +---------------- 3 files changed, 136 insertions(+), 83 deletions(-) diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 4e9d02a0..bd952403 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -1,9 +1,9 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class BaseEncoder(BaseModel): name: str - type: str + type: str = Field(default="base") class Config: arbitrary_types_allowed = True diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 21b48917..873e488a 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -1,7 +1,9 @@ +import os import pytest +from unittest.mock import mock_open, patch from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder -from semantic_router.layer import RouteLayer +from semantic_router.layer import LayerConfig, RouteLayer from semantic_router.route import Route @@ -17,6 +19,50 @@ def mock_encoder_call(utterances): return [mock_responses.get(u, [0, 0, 0]) for u in utterances] +def layer_json(): + return """{ + "encoder_type": "cohere", + "encoder_name": "embed-english-v3.0", + "routes": [ + { + "name": "politics", + "utterances": [ + "isn't politics the best thing ever", + "why don't you tell me about your political opinions" + ], + "description": null, + "function_schema": null + }, + { + "name": "chitchat", + "utterances": [ + "how's the weather today?", + "how are things going?" + ], + "description": null, + "function_schema": null + } + ] +}""" + +def layer_yaml(): + return """encoder_name: embed-english-v3.0 +encoder_type: cohere +routes: +- description: null + function_schema: null + name: politics + utterances: + - isn't politics the best thing ever + - why don't you tell me about your political opinions +- description: null + function_schema: null + name: chitchat + utterances: + - how's the weather today? + - how are things going? + """ + @pytest.fixture def base_encoder(): return BaseEncoder(name="test-encoder") @@ -67,30 +113,31 @@ class TestRouteLayer: route_layer.add(route=route1) assert route_layer.index is not None and route_layer.categories is not None - assert len(route_layer.index) == 2 + assert route_layer.index.shape[0] == 2 assert len(set(route_layer.categories)) == 1 assert set(route_layer.categories) == {"Route 1"} route_layer.add(route=route2) - assert len(route_layer.index) == 4 + assert route_layer.index.shape[0] == 4 assert len(set(route_layer.categories)) == 2 assert set(route_layer.categories) == {"Route 1", "Route 2"} + del route_layer def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) route_layer._add_routes(routes=routes) assert route_layer.index is not None and route_layer.categories is not None - assert len(route_layer.index) == 5 + assert route_layer.index.shape[0] == 5 assert len(set(route_layer.categories)) == 2 def test_query_and_classification(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) - query_result = route_layer("Hello") + query_result = route_layer("Hello").name assert query_result in ["Route 1", "Route 2"] def test_query_with_no_index(self, openai_encoder): route_layer = RouteLayer(encoder=openai_encoder) - assert route_layer("Anything") is None + assert route_layer("Anything").name is None def test_semantic_classify(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) @@ -126,3 +173,82 @@ class TestRouteLayer: # Add more tests for edge cases and error handling as needed. + + +class TestLayerConfig: + def test_init(self): + layer_config = LayerConfig() + assert layer_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + layer_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + layer_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + with pytest.raises(ValueError): + layer_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = layer_json() + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + layer_config = LayerConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(layer_config, LayerConfig) + + def test_from_file_yaml(self): + mock_yaml_data = layer_yaml() + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + layer_config = LayerConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(layer_config, LayerConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + LayerConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + assert layer_config.to_dict()["routes"] == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route2 = Route(name="test2", utterances=["utterance2"]) + layer_config = LayerConfig() + layer_config.add(route) + # confirm route added + assert layer_config.routes == [route] + # add second route and check updates + layer_config.add(route2) + assert layer_config.routes == [route, route2] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + assert layer_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + assert layer_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + layer_config = LayerConfig(routes=[route]) + layer_config.remove("test") + assert layer_config.routes == [] diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py index 1de3f0e5..4e19db24 100644 --- a/tests/unit/test_route.py +++ b/tests/unit/test_route.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, mock_open, patch import pytest -from semantic_router.route import Route, RouteConfig, is_valid +from semantic_router.route import Route, is_valid # Is valid test: @@ -78,6 +78,7 @@ class TestRoute: "name": "test", "utterances": ["utterance"], "description": None, + "function_schema": None, } assert route.to_dict() == expected_dict @@ -146,77 +147,3 @@ class TestRoute: } """ assert Route._parse_route_config(config).strip() == expected_config.strip() - - -class TestRouteConfig: - def test_init(self): - route_config = RouteConfig() - assert route_config.routes == [] - - def test_to_file_json(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - with patch("builtins.open", mock_open()) as mocked_open: - route_config.to_file("data/test_output.json") - mocked_open.assert_called_once_with("data/test_output.json", "w") - - def test_to_file_yaml(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - with patch("builtins.open", mock_open()) as mocked_open: - route_config.to_file("data/test_output.yaml") - mocked_open.assert_called_once_with("data/test_output.yaml", "w") - - def test_to_file_invalid(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - with pytest.raises(ValueError): - route_config.to_file("test_output.txt") - - def test_from_file_json(self): - mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' - with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: - route_config = RouteConfig.from_file("data/test.json") - mocked_open.assert_called_once_with("data/test.json", "r") - assert isinstance(route_config, RouteConfig) - - def test_from_file_yaml(self): - mock_yaml_data = "- name: test\n utterances:\n - utterance" - with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: - route_config = RouteConfig.from_file("data/test.yaml") - mocked_open.assert_called_once_with("data/test.yaml", "r") - assert isinstance(route_config, RouteConfig) - - def test_from_file_invalid(self): - with open("test.txt", "w") as f: - f.write("dummy content") - with pytest.raises(ValueError): - RouteConfig.from_file("test.txt") - os.remove("test.txt") - - def test_to_dict(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - assert route_config.to_dict() == [route.to_dict()] - - def test_add(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig() - route_config.add(route) - assert route_config.routes == [route] - - def test_get(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - assert route_config.get("test") == route - - def test_get_not_found(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - assert route_config.get("not_found") is None - - def test_remove(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - route_config.remove("test") - assert route_config.routes == [] -- GitLab