diff --git a/tests/unit/llms/test_llm_openai.py b/tests/unit/llms/test_llm_openai.py index 04b811a05b39ad6124214ba6a4d27b1f58b8af1d..c5f987104e5fd097a46c92e057867f3bae406869 100644 --- a/tests/unit/llms/test_llm_openai.py +++ b/tests/unit/llms/test_llm_openai.py @@ -30,6 +30,15 @@ get_user_data_schema = [ } ] +example_function_schema = { + "parameters": { + "properties": { + "user_id": {"type": "string", "description": "The ID of the user."} + }, + "required": ["user_id"], + } +} + class TestOpenAILLM: def test_openai_llm_init_with_api_key(self, openai_llm): @@ -351,25 +360,7 @@ class TestOpenAILLM: inputs = [ {"function_name": "get_user_data", "arguments": {"user_id": 123}} ] # user_id should be a string, not an integer - function_schemas = [ - { - "type": "function", - "function": { - "name": "get_user_data", - "description": "Function to fetch user data.", - "parameters": { - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The ID of the user.", - } - }, - "required": ["user_id"], - }, - }, - } - ] + function_schemas = get_user_data_schema # Mock the _validate_single_function_inputs method to return False mocker.patch.object( @@ -392,25 +383,7 @@ class TestOpenAILLM: # Create test inputs that are valid but mock an internal method to raise an exception inputs = [{"function_name": "get_user_data", "arguments": {"user_id": "123"}}] - function_schemas = [ - { - "type": "function", - "function": { - "name": "get_user_data", - "description": "Function to fetch user data.", - "parameters": { - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The ID of the user.", - } - }, - "required": ["user_id"], - }, - }, - } - ] + function_schemas = get_user_data_schema # Mock a method used within _is_valid_inputs to raise an Exception mocker.patch.object( @@ -427,3 +400,74 @@ class TestOpenAILLM: # Check that the appropriate error message was logged mocked_logger.assert_called_once_with("Input validation error: Test exception") + + def test_validate_single_function_inputs_missing_required_param( + self, openai_llm, mocker + ): + # Mock the logger to capture the error messages + mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") + + # Define the function schema with a required parameter + function_schema = example_function_schema + + # Input dictionary missing the required 'user_id' parameter + inputs = {} + + # Call the method with the test inputs + result = openai_llm._validate_single_function_inputs(inputs, function_schema) + + # Assert that the method returns False due to missing required parameter + assert ( + not result + ), "The method should return False when a required parameter is missing" + + # Check that the appropriate error message was logged + expected_error_message = "Required input 'user_id' missing from query" + mocked_logger.assert_called_once_with(expected_error_message) + + def test_validate_single_function_inputs_incorrect_type(self, openai_llm, mocker): + # Mock the logger to capture the error messages + mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") + + # Define the function schema with type specifications + function_schema = example_function_schema + + # Input dictionary with incorrect type for 'user_id' + inputs = {"user_id": 123} # user_id should be a string, not an integer + + # Call the method with the test inputs + result = openai_llm._validate_single_function_inputs(inputs, function_schema) + + # Assert that the method returns False due to incorrect type + assert not result, "The method should return False when input type is incorrect" + + # Check that the appropriate error message was logged + expected_error_message = "Input type for 'user_id' is not string" + mocked_logger.assert_called_once_with(expected_error_message) + + def test_validate_single_function_inputs_exception_handling( + self, openai_llm, mocker + ): + # Mock the logger to capture the error messages + mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") + + # Create a custom class that raises an exception when any attribute is accessed + class SchemaSimulator: + def __getitem__(self, item): + raise Exception("Test exception") + + # Replace the function_schema with an instance of this custom class + function_schema = SchemaSimulator() + + # Call the method with the test inputs + result = openai_llm._validate_single_function_inputs( + {"user_id": "123"}, function_schema + ) + + # Assert that the method returns False due to exception + assert not result, "The method should return False when an exception occurs" + + # Check that the appropriate error message was logged + mocked_logger.assert_called_once_with( + "Single input validation error: Test exception" + )