diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index d5f14a109ab35b5eb2089bf21d387ffe07b9c2f5..5e9bbe9f8a3a9d2419e23b92d1a0fcd1c3d43616 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -4,7 +4,7 @@ from semantic_router.llms import BaseLLM class TestBaseLLM: - + @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") @@ -16,6 +16,14 @@ class TestBaseLLM: "description": "A test function with mixed mandatory and optional parameters.", "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" } + + @pytest.fixture + def mandatory_params(self): + return ["param1", "param2"] + + @pytest.fixture + def all_params(self): + return ["param1", "param2", "optional1"] def test_base_llm_initialization(self, base_llm): assert base_llm.name == "TestLLM", "Initialization of name failed" @@ -96,3 +104,18 @@ class TestBaseLLM: inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): + inputs = {"param1": "value1", "param2": "value2"} + assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == True + + def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params): + inputs = {"param1": "value1"} + assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == False + + def test_check_for_extra_inputs_no_extras(self, base_llm, all_params): + inputs = {"param1": "value1", "param2": "value2"} + assert base_llm._check_for_extra_inputs(inputs, all_params) == True + + def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): + inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} + assert base_llm._check_for_extra_inputs(inputs, all_params) == False \ No newline at end of file