Skip to content
Snippets Groups Projects
Unverified Commit 5c60f73b authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

More PyTests.

parent d756c869
No related branches found
No related tags found
No related merge requests found
...@@ -4,7 +4,7 @@ from semantic_router.llms import BaseLLM ...@@ -4,7 +4,7 @@ from semantic_router.llms import BaseLLM
class TestBaseLLM: class TestBaseLLM:
@pytest.fixture @pytest.fixture
def base_llm(self): def base_llm(self):
return BaseLLM(name="TestLLM") return BaseLLM(name="TestLLM")
...@@ -16,6 +16,14 @@ class TestBaseLLM: ...@@ -16,6 +16,14 @@ class TestBaseLLM:
"description": "A test function with mixed mandatory and optional parameters.", "description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')"
} }
@pytest.fixture
def mandatory_params(self):
return ["param1", "param2"]
@pytest.fixture
def all_params(self):
return ["param1", "param2", "optional1"]
def test_base_llm_initialization(self, base_llm): def test_base_llm_initialization(self, base_llm):
assert base_llm.name == "TestLLM", "Initialization of name failed" assert base_llm.name == "TestLLM", "Initialization of name failed"
...@@ -96,3 +104,18 @@ class TestBaseLLM: ...@@ -96,3 +104,18 @@ class TestBaseLLM:
inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
inputs = {"param1": "value1", "param2": "value2"}
assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == True
def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params):
inputs = {"param1": "value1"}
assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == False
def test_check_for_extra_inputs_no_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2"}
assert base_llm._check_for_extra_inputs(inputs, all_params) == True
def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
assert base_llm._check_for_extra_inputs(inputs, all_params) == False
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment