From b7a3dd69bf9d137354e46e4d3bad75f3f94c2d1e Mon Sep 17 00:00:00 2001
From: Simonas <20096648+simjak@users.noreply.github.com>
Date: Wed, 20 Dec 2023 16:17:30 +0200
Subject: [PATCH] RouteConfig tests

---
 docs/examples/function_calling.ipynb | 17 +++---
 semantic_router/layer.py             |  5 +-
 semantic_router/route.py             |  5 +-
 test_output.json                     |  1 +
 test_output.txt                      |  0
 test_output.yaml                     |  4 ++
 tests/unit/test_route_config.py      | 80 ++++++++++++++++++++++++++++
 7 files changed, 100 insertions(+), 12 deletions(-)
 create mode 100644 test_output.json
 create mode 100644 test_output.txt
 create mode 100644 test_output.yaml
 create mode 100644 tests/unit/test_route_config.py

diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index d86eba84..14581286 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -9,7 +9,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -37,15 +37,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
-      "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n",
+      "\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
+      "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n",
       "{\n",
       "    \"name\": \"get_time\",\n",
       "    \"utterances\": [\n",
@@ -56,8 +56,8 @@
       "        \"What's the time in Paris?\"\n",
       "    ]\n",
       "}\u001b[0m\n",
-      "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
-      "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n",
+      "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
+      "\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n",
       "{\n",
       "    \"name\": \"get_news\",\n",
       "    \"utterances\": [\n",
@@ -67,7 +67,10 @@
       "        \"Get me the breaking news from the UK\",\n",
       "        \"What's the latest in Germany?\"\n",
       "    ]\n",
-      "}\u001b[0m\n"
+      "}\u001b[0m\n",
+      "/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n",
+      "  route_config = RouteConfig(routes=routes)\n",
+      "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n"
      ]
     }
    ],
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index c0670b91..a161e353 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -19,9 +19,8 @@ class RouteLayer:
     categories = None
     score_threshold = 0.82
 
-    def __init__(
-        self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = []
-    ):
+    def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []):
+        self.encoder = encoder if encoder is not None else CohereEncoder()
         self.routes: list[Route] = routes
         self.encoder = encoder
         # decide on default threshold based on encoder
diff --git a/semantic_router/route.py b/semantic_router/route.py
index f46c005c..69f9d4e6 100644
--- a/semantic_router/route.py
+++ b/semantic_router/route.py
@@ -170,11 +170,12 @@ class RouteConfig:
         self.routes.append(route)
         logger.info(f"Added route `{route.name}`")
 
-    def get(self, name: str):
+    def get(self, name: str) -> Route | None:
         for route in self.routes:
             if route.name == name:
                 return route
-        raise Exception(f"Route `{name}` not found")
+        logger.error(f"Route `{name}` not found")
+        return None
 
     def remove(self, name: str):
         if name not in [route.name for route in self.routes]:
diff --git a/test_output.json b/test_output.json
new file mode 100644
index 00000000..1f930085
--- /dev/null
+++ b/test_output.json
@@ -0,0 +1 @@
+[{"name": "test", "utterances": ["utterance"], "description": null}]
diff --git a/test_output.txt b/test_output.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/test_output.yaml b/test_output.yaml
new file mode 100644
index 00000000..b7167647
--- /dev/null
+++ b/test_output.yaml
@@ -0,0 +1,4 @@
+- description: null
+  name: test
+  utterances:
+  - utterance
diff --git a/tests/unit/test_route_config.py b/tests/unit/test_route_config.py
new file mode 100644
index 00000000..0c964d82
--- /dev/null
+++ b/tests/unit/test_route_config.py
@@ -0,0 +1,80 @@
+import os
+from unittest.mock import mock_open, patch
+
+import pytest
+
+from semantic_router.route import Route, RouteConfig
+
+
+class TestRouteConfig:
+    def test_init(self):
+        route_config = RouteConfig()
+        assert route_config.routes == []
+
+    def test_to_file_json(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        with patch("builtins.open", mock_open()) as mocked_open:
+            route_config.to_file("data/test_output.json")
+            mocked_open.assert_called_once_with("data/test_output.json", "w")
+
+    def test_to_file_yaml(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        with patch("builtins.open", mock_open()) as mocked_open:
+            route_config.to_file("data/test_output.yaml")
+            mocked_open.assert_called_once_with("data/test_output.yaml", "w")
+
+    def test_to_file_invalid(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        with pytest.raises(ValueError):
+            route_config.to_file("test_output.txt")
+
+    def test_from_file_json(self):
+        mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]'
+        with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open:
+            route_config = RouteConfig.from_file("data/test.json")
+            mocked_open.assert_called_once_with("data/test.json", "r")
+            assert isinstance(route_config, RouteConfig)
+
+    def test_from_file_yaml(self):
+        mock_yaml_data = "- name: test\n  utterances:\n  - utterance"
+        with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open:
+            route_config = RouteConfig.from_file("data/test.yaml")
+            mocked_open.assert_called_once_with("data/test.yaml", "r")
+            assert isinstance(route_config, RouteConfig)
+
+    def test_from_file_invalid(self):
+        with open("test.txt", "w") as f:
+            f.write("dummy content")
+        with pytest.raises(ValueError):
+            RouteConfig.from_file("test.txt")
+        os.remove("test.txt")
+
+    def test_to_dict(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        assert route_config.to_dict() == [route.to_dict()]
+
+    def test_add(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig()
+        route_config.add(route)
+        assert route_config.routes == [route]
+
+    def test_get(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        assert route_config.get("test") == route
+
+    def test_get_not_found(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        assert route_config.get("not_found") is None
+
+    def test_remove(self):
+        route = Route(name="test", utterances=["utterance"])
+        route_config = RouteConfig(routes=[route])
+        route_config.remove("test")
+        assert route_config.routes == []
-- 
GitLab