From 04ae1a9f8909af4d3ef5f7c6f3f48ec8bb418134 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 8 May 2024 02:37:33 +0400
Subject: [PATCH] Linting.

---
 tests/unit/llms/test_llm_base.py   | 14 ++++++++------
 tests/unit/llms/test_llm_openai.py | 26 +++++++++++++++-----------
 tests/unit/test_layer.py           |  4 +++-
 tests/unit/test_route.py           |  6 ++++--
 4 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index 322435e9..3892afe8 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -16,12 +16,14 @@ class TestBaseLLM:
             base_llm("test")
 
     def test_base_llm_is_valid_inputs_valid_input_pass(self, base_llm):
-        test_schemas = [{
-            "name": "get_time",
-            "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n    be a valid timezone from the IANA Time Zone Database like\n    "America/New_York" or "Europe/London". Do NOT put the place\n    name itself like "rome", or "new york", you must provide\n    the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.',
-            "signature": "(timezone: str) -> str",
-            "output": "<class 'str'>",
-        }]
+        test_schemas = [
+            {
+                "name": "get_time",
+                "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n    be a valid timezone from the IANA Time Zone Database like\n    "America/New_York" or "Europe/London". Do NOT put the place\n    name itself like "rome", or "new york", you must provide\n    the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.',
+                "signature": "(timezone: str) -> str",
+                "output": "<class 'str'>",
+            }
+        ]
         test_inputs = [{"timezone": "America/New_York"}]
 
         assert base_llm._is_valid_inputs(test_inputs, test_schemas) is True
diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py
index bc014bed..eeca55cb 100644
--- a/tests/unit/llms/test_llm_openai.py
+++ b/tests/unit/llms/test_llm_openai.py
@@ -122,7 +122,8 @@ class TestOpenAILLM:
         function_schemas = [{"type": "function", "name": "sample_function"}]
         output = openai_llm(llm_input, function_schemas)
         assert (
-            output == "[{'function_name': 'sample_function', 'arguments': {'timezone': 'America/New_York'}}]"
+            output
+            == "[{'function_name': 'sample_function', 'arguments': {'timezone': 'America/New_York'}}]"
         ), "Output did not match expected result with function schema"
 
     def test_openai_llm_call_with_invalid_tool_calls(self, openai_llm, mocker):
@@ -157,15 +158,12 @@ class TestOpenAILLM:
         with pytest.raises(Exception) as exc_info:
             openai_llm(llm_input, function_schemas)
 
-        expected_error_message = (
-            "LLM error: Invalid output, expected arguments to be specified for each tool call."
-        )
+        expected_error_message = "LLM error: Invalid output, expected arguments to be specified for each tool call."
         actual_error_message = str(exc_info.value)
         assert (
             expected_error_message in actual_error_message
         ), f"Expected error message: '{expected_error_message}', but got: '{actual_error_message}'"
 
-
     def test_extract_function_inputs(self, openai_llm, mocker):
         query = "fetch user data"
         function_schemas = [
@@ -179,17 +177,21 @@ class TestOpenAILLM:
                         "properties": {
                             "user_id": {
                                 "type": "string",
-                                "description": "The ID of the user."
+                                "description": "The ID of the user.",
                             }
                         },
-                        "required": ["user_id"]
-                    }
-                }
+                        "required": ["user_id"],
+                    },
+                },
             }
         ]
 
         # Mock the __call__ method to return a JSON string as expected
-        mocker.patch.object(OpenAILLM, "__call__", return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]')
+        mocker.patch.object(
+            OpenAILLM,
+            "__call__",
+            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+        )
         result = openai_llm.extract_function_inputs(query, function_schemas)
 
         # Ensure the __call__ method is called with the correct parameters
@@ -205,4 +207,6 @@ class TestOpenAILLM:
         )
 
         # Check if the result is as expected
-        assert result == [{"function_name": "get_user_data", "arguments": {"user_id": "123"}}], "The function inputs should match the expected dictionary."
\ No newline at end of file
+        assert result == [
+            {"function_name": "get_user_data", "arguments": {"user_id": "123"}}
+        ], "The function inputs should match the expected dictionary."
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index d18f90a8..3490baa4 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -118,7 +118,9 @@ def routes_3():
 def dynamic_routes():
     return [
         Route(
-            name="Route 1", utterances=["Hello", "Hi"], function_schemas=[{"name": "test"}]
+            name="Route 1",
+            utterances=["Hello", "Hi"],
+            function_schemas=[{"name": "test"}],
         ),
         Route(
             name="Route 2",
diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py
index 59050dbb..b21ffd87 100644
--- a/tests/unit/test_route.py
+++ b/tests/unit/test_route.py
@@ -76,7 +76,7 @@ class TestRoute:
 
     def test_generate_dynamic_route(self):
         mock_llm = MockLLM(name="test")
-        function_schemas = {"name": "test_function", "type": "function"}#
+        function_schemas = {"name": "test_function", "type": "function"}  #
         route = Route._generate_dynamic_route(
             llm=mock_llm, function_schemas=function_schemas, route_name="test_route"
         )
@@ -144,7 +144,9 @@ class TestRoute:
             """Test function docstring"""
             pass
 
-        dynamic_route = Route.from_dynamic_route(llm=mock_llm, entities=[test_function], route_name="test_route")
+        dynamic_route = Route.from_dynamic_route(
+            llm=mock_llm, entities=[test_function], route_name="test_route"
+        )
 
         assert dynamic_route.name == "test_function"
         assert dynamic_route.utterances == [
-- 
GitLab