diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index d86eba846d39cc5f4eff681dba80180f0c5e5944..14581286c87668c415e49a51b6081bbea2d1cab5 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -37,15 +37,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_time\",\n", " \"utterances\": [\n", @@ -56,8 +56,8 @@ " \"What's the time in Paris?\"\n", " ]\n", "}\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_news\",\n", " \"utterances\": [\n", @@ -67,7 +67,10 @@ " \"Get me the breaking news from the UK\",\n", " \"What's the latest in Germany?\"\n", " ]\n", - "}\u001b[0m\n" + "}\u001b[0m\n", + "/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n", + " route_config = RouteConfig(routes=routes)\n", + "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n" ] } ], diff --git a/semantic_router/layer.py b/semantic_router/layer.py index c0670b916eed760ff524a049c288e766da03796d..a161e353cf2545c291b687493e30d8995dbc233f 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -19,9 +19,8 @@ class RouteLayer: categories = None score_threshold = 0.82 - def __init__( - self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = [] - ): + def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []): + self.encoder = encoder if encoder is not None else CohereEncoder() self.routes: list[Route] = routes self.encoder = encoder # decide on default threshold based on encoder diff --git a/semantic_router/route.py b/semantic_router/route.py index f46c005cc4d5021391acc9196b63e96113edd5b7..69f9d4e6305de0f3ecea308fec556d73efc3f869 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -170,11 +170,12 @@ class RouteConfig: self.routes.append(route) logger.info(f"Added route `{route.name}`") - def get(self, name: str): + def get(self, name: str) -> Route | None: for route in self.routes: if route.name == name: return route - raise Exception(f"Route `{name}` not found") + logger.error(f"Route `{name}` not found") + return None def remove(self, name: str): if name not in [route.name for route in self.routes]: diff --git a/test_output.json b/test_output.json new file mode 100644 index 0000000000000000000000000000000000000000..1f93008593dc770f1f001a47b819d652c14af179 --- /dev/null +++ b/test_output.json @@ -0,0 +1 @@ +[{"name": "test", "utterances": ["utterance"], "description": null}] diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test_output.yaml b/test_output.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b71676477f7a48fff6174221c48d3c0595dbf14d --- /dev/null +++ b/test_output.yaml @@ -0,0 +1,4 @@ +- description: null + name: test + utterances: + - utterance diff --git a/tests/unit/test_route_config.py b/tests/unit/test_route_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0c964d82bd0ee9f4e2612f26f9d0ef88cc01a2c9 --- /dev/null +++ b/tests/unit/test_route_config.py @@ -0,0 +1,80 @@ +import os +from unittest.mock import mock_open, patch + +import pytest + +from semantic_router.route import Route, RouteConfig + + +class TestRouteConfig: + def test_init(self): + route_config = RouteConfig() + assert route_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with pytest.raises(ValueError): + route_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_yaml(self): + mock_yaml_data = "- name: test\n utterances:\n - utterance" + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + RouteConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.to_dict() == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig() + route_config.add(route) + assert route_config.routes == [route] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + route_config.remove("test") + assert route_config.routes == []