From d23817225c496ecb305cc9ccea34d5e57098515d Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Wed, 8 May 2024 04:33:03 +0400 Subject: [PATCH] More pytests. --- tests/unit/llms/test_llm_base.py | 100 ++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 3892afe8..680b5d2d 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -1,6 +1,6 @@ import pytest - from semantic_router.llms import BaseLLM +from unittest.mock import patch class TestBaseLLM: @@ -71,3 +71,101 @@ class TestBaseLLM: } test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) + + def test_is_valid_inputs_multiple_inputs(self, base_llm, mocker): + # Mock the logger to capture the error messages + mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") + + # Prepare multiple sets of inputs + test_inputs = [{"timezone": "America/New_York"}, {"timezone": "Europe/London"}] + test_schemas = [ + { + "name": "get_time", + "description": "Finds the current time in a specific timezone.", + "signature": "(timezone: str) -> str", + "output": "<class 'str'>", + } + ] + + # Call the method with multiple inputs + result = base_llm._is_valid_inputs(test_inputs, test_schemas) + + # Assert that the method returns False + assert ( + not result + ), "Method should return False when multiple inputs are provided" + + # Check that the appropriate error message was logged + mocked_logger.assert_called_once_with( + "Only one set of function inputs is allowed." + ) + + def test_is_valid_inputs_exception_handling(self, base_llm, mocker): + # Mock the logger to capture the error messages + mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") + + # Use patch on the method's full path + with patch( + "semantic_router.llms.base.BaseLLM._validate_single_function_inputs", + side_effect=Exception("Test Exception"), + ): + test_inputs = [{"timezone": "America/New_York"}] + test_schemas = [ + { + "name": "get_time", + "description": "Finds the current time in a specific timezone.", + "signature": "(timezone: str) -> str", + "output": "<class 'str'>", + } + ] + + # Call the method and expect it to return False due to the exception + result = base_llm._is_valid_inputs(test_inputs, test_schemas) + + # Assert that the method returns False + assert not result, "Method should return False when an exception occurs" + + # Check that the appropriate error message was logged + mocked_logger.assert_called_once_with( + "Input validation error: Test Exception" + ) + + def test_validate_single_function_inputs_exception_handling(self, base_llm, mocker): + # Mock the logger to capture the error messages + mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") + + # Prepare inputs and a malformed function schema + test_inputs = {"timezone": "America/New_York"} + malformed_function_schema = { + "name": "get_time", + "description": "Finds the current time in a specific timezone.", + "signature": "(timezone str)", # Malformed signature missing colon + "output": "<class 'str'>", + } + + # Call the method and expect it to return False due to the exception + result = base_llm._validate_single_function_inputs( + test_inputs, malformed_function_schema + ) + + # Assert that the method returns False + assert not result, "Method should return False when an exception occurs" + + # Check that the appropriate error message was logged + expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message + mocked_logger.assert_called_once_with(expected_error_message) + + def test_extract_parameter_info_valid(self, base_llm): + # Test with a valid signature + signature = "(param1: int, param2: str = 'default')" + expected_names = ["param1", "param2"] + expected_types = ["int", "str"] + param_names, param_types = base_llm._extract_parameter_info(signature) + assert param_names == expected_names, "Parameter names did not match expected" + assert param_types == expected_types, "Parameter types did not match expected" + + def test_extract_parameter_info_malformed(self, base_llm): + # Test with a malformed signature + signature = "(param1 int, param2: str = 'default')" + with pytest.raises(IndexError): + base_llm._extract_parameter_info(signature) -- GitLab