From 92e8f8423bd944cd5eea4802c6af34c9e033bb56 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Wed, 27 Dec 2023 09:30:31 +0100
Subject: [PATCH] update tests and added type to encoder

---
 semantic_router/encoders/base.py |   4 +-
 tests/unit/test_layer.py         | 138 +++++++++++++++++++++++++++++--
 tests/unit/test_route.py         |  77 +----------------
 3 files changed, 136 insertions(+), 83 deletions(-)

diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index 4e9d02a0..bd952403 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -1,9 +1,9 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 class BaseEncoder(BaseModel):
     name: str
-    type: str
+    type: str = Field(default="base")
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 21b48917..873e488a 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -1,7 +1,9 @@
+import os
 import pytest
+from unittest.mock import mock_open, patch
 
 from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
-from semantic_router.layer import RouteLayer
+from semantic_router.layer import LayerConfig, RouteLayer
 from semantic_router.route import Route
 
 
@@ -17,6 +19,50 @@ def mock_encoder_call(utterances):
     return [mock_responses.get(u, [0, 0, 0]) for u in utterances]
 
 
+def layer_json():
+    return """{
+    "encoder_type": "cohere",
+    "encoder_name": "embed-english-v3.0",
+    "routes": [
+        {
+            "name": "politics",
+            "utterances": [
+                "isn't politics the best thing ever",
+                "why don't you tell me about your political opinions"
+            ],
+            "description": null,
+            "function_schema": null
+        },
+        {
+            "name": "chitchat",
+            "utterances": [
+                "how's the weather today?",
+                "how are things going?"
+            ],
+            "description": null,
+            "function_schema": null
+        }
+    ]
+}"""
+
+def layer_yaml():
+    return """encoder_name: embed-english-v3.0
+encoder_type: cohere
+routes:
+- description: null
+  function_schema: null
+  name: politics
+  utterances:
+  - isn't politics the best thing ever
+  - why don't you tell me about your political opinions
+- description: null
+  function_schema: null
+  name: chitchat
+  utterances:
+  - how's the weather today?
+  - how are things going?
+    """
+
 @pytest.fixture
 def base_encoder():
     return BaseEncoder(name="test-encoder")
@@ -67,30 +113,31 @@ class TestRouteLayer:
 
         route_layer.add(route=route1)
         assert route_layer.index is not None and route_layer.categories is not None
-        assert len(route_layer.index) == 2
+        assert route_layer.index.shape[0] == 2
         assert len(set(route_layer.categories)) == 1
         assert set(route_layer.categories) == {"Route 1"}
 
         route_layer.add(route=route2)
-        assert len(route_layer.index) == 4
+        assert route_layer.index.shape[0] == 4
         assert len(set(route_layer.categories)) == 2
         assert set(route_layer.categories) == {"Route 1", "Route 2"}
+        del route_layer
 
     def test_add_multiple_routes(self, openai_encoder, routes):
         route_layer = RouteLayer(encoder=openai_encoder)
         route_layer._add_routes(routes=routes)
         assert route_layer.index is not None and route_layer.categories is not None
-        assert len(route_layer.index) == 5
+        assert route_layer.index.shape[0] == 5
         assert len(set(route_layer.categories)) == 2
 
     def test_query_and_classification(self, openai_encoder, routes):
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
-        query_result = route_layer("Hello")
+        query_result = route_layer("Hello").name
         assert query_result in ["Route 1", "Route 2"]
 
     def test_query_with_no_index(self, openai_encoder):
         route_layer = RouteLayer(encoder=openai_encoder)
-        assert route_layer("Anything") is None
+        assert route_layer("Anything").name is None
 
     def test_semantic_classify(self, openai_encoder, routes):
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
@@ -126,3 +173,82 @@ class TestRouteLayer:
 
 
 # Add more tests for edge cases and error handling as needed.
+
+
+class TestLayerConfig:
+    def test_init(self):
+        layer_config = LayerConfig()
+        assert layer_config.routes == []
+
+    def test_to_file_json(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        with patch("builtins.open", mock_open()) as mocked_open:
+            layer_config.to_file("data/test_output.json")
+            mocked_open.assert_called_once_with("data/test_output.json", "w")
+
+    def test_to_file_yaml(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        with patch("builtins.open", mock_open()) as mocked_open:
+            layer_config.to_file("data/test_output.yaml")
+            mocked_open.assert_called_once_with("data/test_output.yaml", "w")
+
+    def test_to_file_invalid(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        with pytest.raises(ValueError):
+            layer_config.to_file("test_output.txt")
+
+    def test_from_file_json(self):
+        mock_json_data = layer_json()
+        with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open:
+            layer_config = LayerConfig.from_file("data/test.json")
+            mocked_open.assert_called_once_with("data/test.json", "r")
+            assert isinstance(layer_config, LayerConfig)
+
+    def test_from_file_yaml(self):
+        mock_yaml_data = layer_yaml()
+        with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open:
+            layer_config = LayerConfig.from_file("data/test.yaml")
+            mocked_open.assert_called_once_with("data/test.yaml", "r")
+            assert isinstance(layer_config, LayerConfig)
+
+    def test_from_file_invalid(self):
+        with open("test.txt", "w") as f:
+            f.write("dummy content")
+        with pytest.raises(ValueError):
+            LayerConfig.from_file("test.txt")
+        os.remove("test.txt")
+
+    def test_to_dict(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        assert layer_config.to_dict()["routes"] == [route.to_dict()]
+
+    def test_add(self):
+        route = Route(name="test", utterances=["utterance"])
+        route2 = Route(name="test2", utterances=["utterance2"])
+        layer_config = LayerConfig()
+        layer_config.add(route)
+        # confirm route added
+        assert layer_config.routes == [route]
+        # add second route and check updates
+        layer_config.add(route2)
+        assert layer_config.routes == [route, route2]
+
+    def test_get(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        assert layer_config.get("test") == route
+
+    def test_get_not_found(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        assert layer_config.get("not_found") is None
+
+    def test_remove(self):
+        route = Route(name="test", utterances=["utterance"])
+        layer_config = LayerConfig(routes=[route])
+        layer_config.remove("test")
+        assert layer_config.routes == []
diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py
index 1de3f0e5..4e19db24 100644
--- a/tests/unit/test_route.py
+++ b/tests/unit/test_route.py
@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, mock_open, patch
 
 import pytest
 
-from semantic_router.route import Route, RouteConfig, is_valid
+from semantic_router.route import Route, is_valid
 
 
 # Is valid test:
@@ -78,6 +78,7 @@ class TestRoute:
             "name": "test",
             "utterances": ["utterance"],
             "description": None,
+            "function_schema": None,
         }
         assert route.to_dict() == expected_dict
 
@@ -146,77 +147,3 @@ class TestRoute:
         }
         """
         assert Route._parse_route_config(config).strip() == expected_config.strip()
-
-
-class TestRouteConfig:
-    def test_init(self):
-        route_config = RouteConfig()
-        assert route_config.routes == []
-
-    def test_to_file_json(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        with patch("builtins.open", mock_open()) as mocked_open:
-            route_config.to_file("data/test_output.json")
-            mocked_open.assert_called_once_with("data/test_output.json", "w")
-
-    def test_to_file_yaml(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        with patch("builtins.open", mock_open()) as mocked_open:
-            route_config.to_file("data/test_output.yaml")
-            mocked_open.assert_called_once_with("data/test_output.yaml", "w")
-
-    def test_to_file_invalid(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        with pytest.raises(ValueError):
-            route_config.to_file("test_output.txt")
-
-    def test_from_file_json(self):
-        mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]'
-        with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open:
-            route_config = RouteConfig.from_file("data/test.json")
-            mocked_open.assert_called_once_with("data/test.json", "r")
-            assert isinstance(route_config, RouteConfig)
-
-    def test_from_file_yaml(self):
-        mock_yaml_data = "- name: test\n  utterances:\n  - utterance"
-        with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open:
-            route_config = RouteConfig.from_file("data/test.yaml")
-            mocked_open.assert_called_once_with("data/test.yaml", "r")
-            assert isinstance(route_config, RouteConfig)
-
-    def test_from_file_invalid(self):
-        with open("test.txt", "w") as f:
-            f.write("dummy content")
-        with pytest.raises(ValueError):
-            RouteConfig.from_file("test.txt")
-        os.remove("test.txt")
-
-    def test_to_dict(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        assert route_config.to_dict() == [route.to_dict()]
-
-    def test_add(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig()
-        route_config.add(route)
-        assert route_config.routes == [route]
-
-    def test_get(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        assert route_config.get("test") == route
-
-    def test_get_not_found(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        assert route_config.get("not_found") is None
-
-    def test_remove(self):
-        route = Route(name="test", utterances=["utterance"])
-        route_config = RouteConfig(routes=[route])
-        route_config.remove("test")
-        assert route_config.routes == []
-- 
GitLab