From b7a3dd69bf9d137354e46e4d3bad75f3f94c2d1e Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Wed, 20 Dec 2023 16:17:30 +0200 Subject: [PATCH] RouteConfig tests --- docs/examples/function_calling.ipynb | 17 +++--- semantic_router/layer.py | 5 +- semantic_router/route.py | 5 +- test_output.json | 1 + test_output.txt | 0 test_output.yaml | 4 ++ tests/unit/test_route_config.py | 80 ++++++++++++++++++++++++++++ 7 files changed, 100 insertions(+), 12 deletions(-) create mode 100644 test_output.json create mode 100644 test_output.txt create mode 100644 test_output.yaml create mode 100644 tests/unit/test_route_config.py diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index d86eba84..14581286 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -37,15 +37,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_time\",\n", " \"utterances\": [\n", @@ -56,8 +56,8 @@ " \"What's the time in Paris?\"\n", " ]\n", "}\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_news\",\n", " \"utterances\": [\n", @@ -67,7 +67,10 @@ " \"Get me the breaking news from the UK\",\n", " \"What's the latest in Germany?\"\n", " ]\n", - "}\u001b[0m\n" + "}\u001b[0m\n", + "/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n", + " route_config = RouteConfig(routes=routes)\n", + "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n" ] } ], diff --git a/semantic_router/layer.py b/semantic_router/layer.py index c0670b91..a161e353 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -19,9 +19,8 @@ class RouteLayer: categories = None score_threshold = 0.82 - def __init__( - self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = [] - ): + def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []): + self.encoder = encoder if encoder is not None else CohereEncoder() self.routes: list[Route] = routes self.encoder = encoder # decide on default threshold based on encoder diff --git a/semantic_router/route.py b/semantic_router/route.py index f46c005c..69f9d4e6 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -170,11 +170,12 @@ class RouteConfig: self.routes.append(route) logger.info(f"Added route `{route.name}`") - def get(self, name: str): + def get(self, name: str) -> Route | None: for route in self.routes: if route.name == name: return route - raise Exception(f"Route `{name}` not found") + logger.error(f"Route `{name}` not found") + return None def remove(self, name: str): if name not in [route.name for route in self.routes]: diff --git a/test_output.json b/test_output.json new file mode 100644 index 00000000..1f930085 --- /dev/null +++ b/test_output.json @@ -0,0 +1 @@ +[{"name": "test", "utterances": ["utterance"], "description": null}] diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 00000000..e69de29b diff --git a/test_output.yaml b/test_output.yaml new file mode 100644 index 00000000..b7167647 --- /dev/null +++ b/test_output.yaml @@ -0,0 +1,4 @@ +- description: null + name: test + utterances: + - utterance diff --git a/tests/unit/test_route_config.py b/tests/unit/test_route_config.py new file mode 100644 index 00000000..0c964d82 --- /dev/null +++ b/tests/unit/test_route_config.py @@ -0,0 +1,80 @@ +import os +from unittest.mock import mock_open, patch + +import pytest + +from semantic_router.route import Route, RouteConfig + + +class TestRouteConfig: + def test_init(self): + route_config = RouteConfig() + assert route_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with pytest.raises(ValueError): + route_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_yaml(self): + mock_yaml_data = "- name: test\n utterances:\n - utterance" + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + RouteConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.to_dict() == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig() + route_config.add(route) + assert route_config.routes == [route] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + route_config.remove("test") + assert route_config.routes == [] -- GitLab