diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 38d2ee731971f4d559db273ee2380d97d3d776ac..a52a341551f3c95ae1a9c5e8b3b49834c082d480 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -11,11 +11,11 @@ class TestBaseLLM: @pytest.fixture def mixed_function_schema(self): - return { + return [{ "name": "test_function", "description": "A test function with mixed mandatory and optional parameters.", "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", - } + }] @pytest.fixture def mandatory_params(self): @@ -90,34 +90,34 @@ class TestBaseLLM: base_llm.extract_function_inputs(test_schema, test_query) def test_mandatory_args_only(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "mandatory2": 42} + inputs = [{"mandatory1": "value1", "mandatory2": 42}] assert base_llm._is_valid_inputs( inputs, mixed_function_schema ) # True is implied def test_all_args_provided(self, base_llm, mixed_function_schema): - inputs = { + inputs = [{ "mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", - } + }] assert base_llm._is_valid_inputs( inputs, mixed_function_schema ) # True is implied def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} + inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}] assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_extra_arg_provided(self, base_llm, mixed_function_schema): - inputs = { + inputs = [{ "mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value", - } + }] assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): @@ -201,11 +201,11 @@ class TestBaseLLM: mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") # Prepare inputs and a malformed function schema - test_inputs = {"timezone": "America/New_York"} + test_inputs = {"timezone": "America/New_York"} malformed_function_schema = { "name": "get_time", "description": "Finds the current time in a specific timezone.", - "signature": "(timezone str)", # Malformed signature missing colon + "signiture": "(timezone: str)", # Malformed key name "output": "<class 'str'>", } @@ -218,7 +218,7 @@ class TestBaseLLM: assert not result, "Method should return False when an exception occurs" # Check that the appropriate error message was logged - expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message + expected_error_message = "Single input validation error: 'signature'" # Adjust based on the actual exception message mocked_logger.assert_called_once_with(expected_error_message) def test_extract_parameter_info_valid(self, base_llm):