From 5fe6b553d4fe9ec764d015cc4980d9dc2406d5c3 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 8 May 2024 03:50:30 +0400
Subject: [PATCH] More pytests for openai.py.

---
 tests/unit/llms/test_llm_openai.py | 253 +++++++++++++++++++++++++++--
 1 file changed, 235 insertions(+), 18 deletions(-)

diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py
index eeca55cb..04b811a0 100644
--- a/tests/unit/llms/test_llm_openai.py
+++ b/tests/unit/llms/test_llm_openai.py
@@ -10,6 +10,27 @@ def openai_llm(mocker):
     return OpenAILLM(openai_api_key="test_api_key")
 
 
+get_user_data_schema = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_user_data",
+            "description": "Function to fetch user data.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "user_id": {
+                        "type": "string",
+                        "description": "The ID of the user.",
+                    }
+                },
+                "required": ["user_id"],
+            },
+        },
+    }
+]
+
+
 class TestOpenAILLM:
     def test_openai_llm_init_with_api_key(self, openai_llm):
         assert openai_llm.client is not None, "Client should be initialized"
@@ -166,6 +187,170 @@ class TestOpenAILLM:
     def test_extract_function_inputs(self, openai_llm, mocker):
         query = "fetch user data"
+        function_schemas = get_user_data_schema
+
+        # Mock the __call__ method to return a JSON string as expected
+        mocker.patch.object(
+            OpenAILLM,
+            "__call__",
+            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+        )
+        result = openai_llm.extract_function_inputs(query, function_schemas)
+
+        # Ensure the __call__ method is called with the correct parameters
+        expected_messages = [
+            Message(
+                role="system",
+                content="You are an intelligent AI. Given a command or request from the user, call the function to complete the request.",
+            ),
+            Message(role="user", content=query),
+        ]
+        openai_llm.__call__.assert_called_once_with(
+            messages=expected_messages, function_schemas=function_schemas
+        )
+
+        # Check if the result is as expected
+        assert result == [
+            {"function_name": "get_user_data", "arguments": {"user_id": "123"}}
+        ], "The function inputs should match the expected dictionary."
+
+    def test_openai_llm_call_with_no_tool_calls_specified(self, openai_llm, mocker):
+        # Mocking the completion object to simulate no tool calls being specified
+        mock_completion = mocker.MagicMock()
+        mock_completion.choices[0].message.tool_calls = []  # Empty list of tool calls
+
+        # Patching the completions.create method to return the mocked completion
+        mocker.patch.object(
+            openai_llm.client.chat.completions, "create", return_value=mock_completion
+        )
+
+        # Input message list
+        llm_input = [Message(role="user", content="test")]
+        # Example function schema
+        function_schemas = [{"type": "function", "name": "sample_function"}]
+
+        # Expecting a generic Exception to be raised due to no tool calls being specified
+        with pytest.raises(Exception) as exc_info:
+            openai_llm(llm_input, function_schemas)
+
+        # Check if the raised Exception contains the expected message
+        expected_error_message = (
+            "LLM error: Invalid output, expected at least one tool to be specified."
+        )
+        assert (
+            str(exc_info.value) == expected_error_message
+        ), f"Expected error message: '{expected_error_message}', but got: '{str(exc_info.value)}'"
+
+    def test_extract_function_inputs_no_output(self, openai_llm, mocker):
+        query = "fetch user data"
+        function_schemas = [{"type": "function", "name": "get_user_data"}]
+
+        # Mock the __call__ method to return an empty string
+        mocker.patch.object(OpenAILLM, "__call__", return_value="")
+
+        # Expecting an Exception due to no output
+        with pytest.raises(Exception) as exc_info:
+            openai_llm.extract_function_inputs(query, function_schemas)
+
+        assert (
+            str(exc_info.value) == "No output generated for extract function input"
+        ), "Expected exception message not found"
+
+    def test_extract_function_inputs_invalid_output(self, openai_llm, mocker):
+        query = "fetch user data"
+        function_schemas = [{"type": "function", "name": "get_user_data"}]
+
+        # Mock the __call__ method to return a JSON string
+        mocker.patch.object(
+            OpenAILLM,
+            "__call__",
+            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+        )
+
+        # Mock _is_valid_inputs to return False
+        mocker.patch.object(OpenAILLM, "_is_valid_inputs", return_value=False)
+
+        # Expecting a ValueError due to invalid inputs
+        with pytest.raises(ValueError) as exc_info:
+            openai_llm.extract_function_inputs(query, function_schemas)
+
+        assert (
+            str(exc_info.value) == "Invalid inputs"
+        ), "Expected exception message not found"
+
+    def test_is_valid_inputs_missing_function_name(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'function_name' is missing
+        inputs = [{"arguments": {"user_id": "123"}}]
+        function_schemas = get_user_data_schema
+
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to missing 'function_name'
+        assert (
+            not result
+        ), "The method should return False when 'function_name' is missing"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with(
+            "Missing 'function_name' or 'arguments' in inputs"
+        )
+
+    def test_is_valid_inputs_missing_arguments(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'arguments' is missing but 'function_name' is present
+        inputs = [{"function_name": "get_user_data"}]
+        function_schemas = get_user_data_schema
+
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to missing 'arguments'
+        assert not result, "The method should return False when 'arguments' is missing"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with(
+            "Missing 'function_name' or 'arguments' in inputs"
+        )
+
+    def test_is_valid_inputs_no_matching_schema(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'function_name' does not match any schema
+        inputs = [
+            {
+                "function_name": "name_that_does_not_exist_in_schema",
+                "arguments": {"user_id": "123"},
+            }
+        ]
+        function_schemas = get_user_data_schema
+
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to no matching function schema
+        assert (
+            not result
+        ), "The method should return False when no matching function schema is found"
+
+        # Check that the appropriate error message was logged
+        expected_error_message = "No matching function schema found for function name: name_that_does_not_exist_in_schema"
+        mocked_logger.assert_called_once_with(expected_error_message)
+
+    def test_is_valid_inputs_validation_failed(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Input where 'arguments' do not meet the schema requirements
+        inputs = [
+            {"function_name": "get_user_data", "arguments": {"user_id": 123}}
+        ]  # user_id should be a string, not an integer
         function_schemas = [
             {
                 "type": "function",
                 "function": {
                     "name": "get_user_data",
                     "description": "Function to fetch user data.",
                     "parameters": {
                         "type": "object",
                         "properties": {
                             "user_id": {
                                 "type": "string",
                                 "description": "The ID of the user.",
                             }
                         },
                         "required": ["user_id"],
                     },
                 },
             }
@@ -186,27 +371,59 @@ class TestOpenAILLM:
         ]
 
-        # Mock the __call__ method to return a JSON string as expected
+        # Mock the _validate_single_function_inputs method to return False
         mocker.patch.object(
-            OpenAILLM,
-            "__call__",
-            return_value='[{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]',
+            OpenAILLM, "_validate_single_function_inputs", return_value=False
         )
-        result = openai_llm.extract_function_inputs(query, function_schemas)
 
-        # Ensure the __call__ method is called with the correct parameters
-        expected_messages = [
-            Message(
-                role="system",
-                content="You are an intelligent AI. Given a command or request from the user, call the function to complete the request.",
-            ),
-            Message(role="user", content=query),
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to validation failure
+        assert not result, "The method should return False when validation fails"
+
+        # Check that the appropriate error message was logged
+        expected_error_message = "Validation failed for function name: get_user_data"
+        mocked_logger.assert_called_once_with(expected_error_message)
+
+    def test_is_valid_inputs_exception_handling(self, openai_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Create test inputs that are valid but mock an internal method to raise an exception
+        inputs = [{"function_name": "get_user_data", "arguments": {"user_id": "123"}}]
+        function_schemas = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_user_data",
+                    "description": "Function to fetch user data.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "user_id": {
+                                "type": "string",
+                                "description": "The ID of the user.",
+                            }
+                        },
+                        "required": ["user_id"],
+                    },
+                },
+            }
         ]
-        openai_llm.__call__.assert_called_once_with(
-            messages=expected_messages, function_schemas=function_schemas
+
+        # Mock a method used within _is_valid_inputs to raise an Exception
+        mocker.patch.object(
+            OpenAILLM,
+            "_validate_single_function_inputs",
+            side_effect=Exception("Test exception"),
         )
 
-        # Check if the result is as expected
-        assert result == [
-            {"function_name": "get_user_data", "arguments": {"user_id": "123"}}
-        ], "The function inputs should match the expected dictionary."
+        # Call the method with the test inputs
+        result = openai_llm._is_valid_inputs(inputs, function_schemas)
+
+        # Assert that the method returns False due to exception
+        assert not result, "The method should return False when an exception occurs"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with("Input validation error: Test exception")
-- 
GitLab
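
Reviewer note: the sketch below distills the validation contract these tests pin down, namely that every input needs both a 'function_name' and an 'arguments' key, the name must match one of the schemas, a per-function argument check runs next, and any internal exception is logged and converted to a False return. It is an illustrative stand-in, not semantic-router's actual OpenAILLM._is_valid_inputs. The helper name validate_inputs and the required-key check are assumptions; in the tests above the real per-argument step (_validate_single_function_inputs) is only exercised through mocks.

import logging

logger = logging.getLogger(__name__)


def validate_inputs(inputs: list, function_schemas: list) -> bool:
    # Hypothetical stand-in mirroring the error messages asserted in the
    # tests above; the real implementation lives in OpenAILLM.
    try:
        for inp in inputs:
            # test_is_valid_inputs_missing_function_name / _missing_arguments
            if "function_name" not in inp or "arguments" not in inp:
                logger.error("Missing 'function_name' or 'arguments' in inputs")
                return False

            # test_is_valid_inputs_no_matching_schema; assumes the nested
            # OpenAI tool format used by get_user_data_schema.
            matching = [
                schema
                for schema in function_schemas
                if schema["function"]["name"] == inp["function_name"]
            ]
            if not matching:
                logger.error(
                    "No matching function schema found for function name: "
                    + inp["function_name"]
                )
                return False

            # test_is_valid_inputs_validation_failed; a required-key check
            # stands in for the mocked _validate_single_function_inputs.
            required = matching[0]["function"]["parameters"].get("required", [])
            if not all(key in inp["arguments"] for key in required):
                logger.error(
                    "Validation failed for function name: " + inp["function_name"]
                )
                return False
        return True
    except Exception as e:
        # test_is_valid_inputs_exception_handling: internal failures are
        # logged and reported as False rather than propagated.
        logger.error(f"Input validation error: {e}")
        return False

Run against get_user_data_schema, this returns True for {"user_id": "123"} and reproduces each logged failure the tests assert, except the integer user_id case: catching that type mismatch is the job of the mocked per-argument check, which this sketch only approximates.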