From 5fe6b553d4fe9ec764d015cc4980d9dc2406d5c3 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 8 May 2024 03:50:30 +0400
Subject: [PATCH] More pytests for openai.py.

---
 tests/unit/llms/test_llm_openai.py | 253 +++++++++++++++++++++++++++--
 1 file changed, 235 insertions(+), 18 deletions(-)

diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py
index eeca55cb..04b811a0 100644
--- a/tests/unit/llms/test_llm_openai.py
+++ b/tests/unit/llms/test_llm_openai.py
@@ -10,6 +10,27 @@ def openai_llm(mocker):
     return OpenAILLM(openai_api_key="test_api_key")
 
 
+get_user_data_schema = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_user_data",
+            "description": "Function to fetch user data.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "user_id": {
+                        "type": "string",
+                        "description": "The ID of the user.",
+                    }
+                },
+                "required": ["user_id"],
+            },
+        },
+    }
+]
+
+
 class TestOpenAILLM:
     def test_openai_llm_init_with_api_key(self, openai_llm):
         assert openai_llm.client is not None, "Client should be initialized"
@@ -166,6 +187,170 @@ class TestOpenAILLM:
 
     def test_extract_function_inputs(self, openai_llm, mocker):
         query = "fetch user data"
+        function_schemas = get_user_data_schema
+
+        # Mock the __call__ method to return a JSON string as expected
+        mocker.patch.object(
+            OpenAILLM,
+            "__call__",
+            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+        )
+        result = openai_llm.extract_function_inputs(query, function_schemas)
+
+        # Ensure the __call__ method is called with the correct parameters
+        expected_messages = [
+            Message(
+                role="system",
+                content="You are an intelligent AI. Given a command or request from the user, call the function to complete the request.",
+            ),
+            Message(role="user", content=query),
+        ]
+        openai_llm.__call__.assert_called_once_with(
+            messages=expected_messages, function_schemas=function_schemas
+        )
+
+        # Check if the result is as expected
+        assert result == [
+            {"function_name": "get_user_data", "arguments": {"user_id": "123"}}
+        ], "The function inputs should match the expected dictionary."
+
+    def test_openai_llm_call_with_no_tool_calls_specified(self, openai_llm, mocker):
+        # Mocking the completion object to simulate no tool calls being specified
+        mock_completion = mocker.MagicMock()
+        mock_completion.choices[0].message.tool_calls = []  # Empty list of tool calls
+
+        # Patching the completions.create method to return the mocked completion
+        mocker.patch.object(
+            openai_llm.client.chat.completions, "create", return_value=mock_completion
+        )
+
+        # Input message list
+        llm_input = [Message(role="user", content="test")]
+        # Example function schema
+        function_schemas = [{"type": "function", "name": "sample_function"}]
+
+        # Expecting a generic Exception to be raised due to no tool calls being specified
+        with pytest.raises(Exception) as exc_info:
+            openai_llm(llm_input, function_schemas)
+
+        # Check if the raised Exception contains the expected message
+        expected_error_message = (
+            "LLM error: Invalid output, expected at least one tool to be specified."
+        )
+        assert (
+            str(exc_info.value) == expected_error_message
+        ), f"Expected error message: '{expected_error_message}', but got: '{str(exc_info.value)}'"
+
+    def test_extract_function_inputs_no_output(self, openai_llm, mocker):
+        query = "fetch user data"
+        function_schemas = [{"type": "function", "name": "get_user_data"}]
+
+        # Mock the __call__ method to return an empty string
+        mocker.patch.object(OpenAILLM, "__call__", return_value="")
+
+        # Expecting an Exception due to no output
+        with pytest.raises(Exception) as exc_info:
+            openai_llm.extract_function_inputs(query, function_schemas)
+
+        assert (
+            str(exc_info.value) == "No output generated for extract function input"
+        ), "Expected exception message not found"
+
+    def test_extract_function_inputs_invalid_output(self, openai_llm, mocker):
+        query = "fetch user data"
+        function_schemas = [{"type": "function", "name": "get_user_data"}]
+
+        # Mock the __call__ method to return a JSON string
+        mocker.patch.object(
+            OpenAILLM,
+            "__call__",
+            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+        )
+
+        # Mock _is_valid_inputs to return False
+        mocker.patch.object(OpenAILLM, "_is_valid_inputs", return_value=False)
+
+        # Expecting a ValueError due to invalid inputs
+        with pytest.raises(ValueError) as exc_info:
+            openai_llm.extract_function_inputs(query, function_schemas)
+
+        assert (
+            str(exc_info.value) == "Invalid inputs"
+        ), "Expected exception message not found"
+
+    def test_is_valid_inputs_missing_function_name(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'function_name' is missing
+        inputs = [{"arguments": {"user_id": "123"}}]
+        function_schemas = get_user_data_schema
+
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to missing 'function_name'
+        assert (
+            not result
+        ), "The method should return False when 'function_name' is missing"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with(
+            "Missing 'function_name' or 'arguments' in inputs"
+        )
+
+    def test_is_valid_inputs_missing_arguments(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'arguments' is missing but 'function_name' is present
+        inputs = [{"function_name": "get_user_data"}]
+        function_schemas = get_user_data_schema
+
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to missing 'arguments'
+        assert not result, "The method should return False when 'arguments' is missing"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with(
+            "Missing 'function_name' or 'arguments' in inputs"
+        )
+
+    def test_is_valid_inputs_no_matching_schema(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'function_name' does not match any schema
+        inputs = [
+            {
+                "function_name": "name_that_does_not_exist_in_schema",
+                "arguments": {"user_id": "123"},
+            }
+        ]
+        function_schemas = get_user_data_schema
+
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to no matching function schema
+        assert (
+            not result
+        ), "The method should return False when no matching function schema is found"
+
+        # Check that the appropriate error message was logged
+        expected_error_message = "No matching function schema found for function name: name_that_does_not_exist_in_schema"
+        mocked_logger.assert_called_once_with(expected_error_message)
+
+    def test_is_valid_inputs_validation_failed(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'arguments' do not meet the schema requirements
+        inputs = [
+            {"function_name": "get_user_data", "arguments": {"user_id": 123}}
+        ]  # user_id should be a string, not an integer
         function_schemas = [
             {
                 "type": "function",
@@ -186,27 +371,59 @@ class TestOpenAILLM:
             }
         ]
 
-        # Mock the __call__ method to return a JSON string as expected
+        # Mock the _validate_single_function_inputs method to return False
         mocker.patch.object(
-            OpenAILLM,
-            "__call__",
-            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+            OpenAILLM, "_validate_single_function_inputs", return_value=False
         )
-        result = openai_llm.extract_function_inputs(query, function_schemas)
 
-        # Ensure the __call__ method is called with the correct parameters
-        expected_messages = [
-            Message(
-                role="system",
-                content="You are an intelligent AI. Given a command or request from the user, call the function to complete the request.",
-            ),
-            Message(role="user", content=query),
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to validation failure
+        assert not result, "The method should return False when validation fails"
+
+        # Check that the appropriate error message was logged
+        expected_error_message = "Validation failed for function name: get_user_data"
+        mocked_logger.assert_called_once_with(expected_error_message)
+
+    def test_is_valid_inputs_exception_handling(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Create test inputs that are valid but mock an internal method to raise an exception
+        inputs = [{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]
+        function_schemas = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_user_data",
+                    "description": "Function to fetch user data.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "user_id": {
+                                "type": "string",
+                                "description": "The ID of the user.",
+                            }
+                        },
+                        "required": ["user_id"],
+                    },
+                },
+            }
         ]
-        openai_llm.__call__.assert_called_once_with(
-            messages=expected_messages, function_schemas=function_schemas
+
+        # Mock a method used within _is_valid_inputs to raise an Exception
+        mocker.patch.object(
+            OpenAILLM,
+            "_validate_single_function_inputs",
+            side_effect=Exception("Test exception"),
         )
 
-        # Check if the result is as expected
-        assert result == [
-            {"function_name": "get_user_data", "arguments": {"user_id": "123"}}
-        ], "The function inputs should match the expected dictionary."
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to exception
+        assert not result, "The method should return False when an exception occurs"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with("Input validation error: Test exception")
-- 
GitLab