diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index a52a341551f3c95ae1a9c5e8b3b49834c082d480..3699ded0590f2798633d14305c1f1b3409655685 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -11,11 +11,13 @@ class TestBaseLLM: @pytest.fixture def mixed_function_schema(self): - return [{ - "name": "test_function", - "description": "A test function with mixed mandatory and optional parameters.", - "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", - }] + return [ + { + "name": "test_function", + "description": "A test function with mixed mandatory and optional parameters.", + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", + } + ] @pytest.fixture def mandatory_params(self): @@ -96,12 +98,14 @@ class TestBaseLLM: ) # True is implied def test_all_args_provided(self, base_llm, mixed_function_schema): - inputs = [{ - "mandatory1": "value1", - "mandatory2": 42, - "optional1": "opt1", - "optional2": "opt2", - }] + inputs = [ + { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + } + ] assert base_llm._is_valid_inputs( inputs, mixed_function_schema ) # True is implied @@ -111,13 +115,15 @@ class TestBaseLLM: assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_extra_arg_provided(self, base_llm, mixed_function_schema): - inputs = [{ - "mandatory1": "value1", - "mandatory2": 42, - "optional1": "opt1", - "optional2": "opt2", - "extra": "value", - }] + inputs = [ + { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + "extra": "value", + } + ] assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): @@ -201,7 +207,7 @@ class TestBaseLLM: mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") # Prepare inputs and a malformed function schema - test_inputs = {"timezone": "America/New_York"} + test_inputs = {"timezone": "America/New_York"} malformed_function_schema = { "name": "get_time", "description": "Finds the current time in a specific timezone.",