Skip to content
Snippets Groups Projects
Unverified Commit b040c168 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Fixed pytests following resolution of conflicts with main.

parent f9902a7c
No related branches found
No related tags found
No related merge requests found
...@@ -11,11 +11,11 @@ class TestBaseLLM: ...@@ -11,11 +11,11 @@ class TestBaseLLM:
@pytest.fixture @pytest.fixture
def mixed_function_schema(self): def mixed_function_schema(self):
return { return [{
"name": "test_function", "name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.", "description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
} }]
@pytest.fixture @pytest.fixture
def mandatory_params(self): def mandatory_params(self):
...@@ -90,34 +90,34 @@ class TestBaseLLM: ...@@ -90,34 +90,34 @@ class TestBaseLLM:
base_llm.extract_function_inputs(test_schema, test_query) base_llm.extract_function_inputs(test_schema, test_query)
def test_mandatory_args_only(self, base_llm, mixed_function_schema): def test_mandatory_args_only(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42} inputs = [{"mandatory1": "value1", "mandatory2": 42}]
assert base_llm._is_valid_inputs( assert base_llm._is_valid_inputs(
inputs, mixed_function_schema inputs, mixed_function_schema
) # True is implied ) # True is implied
def test_all_args_provided(self, base_llm, mixed_function_schema): def test_all_args_provided(self, base_llm, mixed_function_schema):
inputs = { inputs = [{
"mandatory1": "value1", "mandatory1": "value1",
"mandatory2": 42, "mandatory2": 42,
"optional1": "opt1", "optional1": "opt1",
"optional2": "opt2", "optional2": "opt2",
} }]
assert base_llm._is_valid_inputs( assert base_llm._is_valid_inputs(
inputs, mixed_function_schema inputs, mixed_function_schema
) # True is implied ) # True is implied
def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)
def test_extra_arg_provided(self, base_llm, mixed_function_schema): def test_extra_arg_provided(self, base_llm, mixed_function_schema):
inputs = { inputs = [{
"mandatory1": "value1", "mandatory1": "value1",
"mandatory2": 42, "mandatory2": 42,
"optional1": "opt1", "optional1": "opt1",
"optional2": "opt2", "optional2": "opt2",
"extra": "value", "extra": "value",
} }]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)
def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
...@@ -201,11 +201,11 @@ class TestBaseLLM: ...@@ -201,11 +201,11 @@ class TestBaseLLM:
mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
# Prepare inputs and a malformed function schema # Prepare inputs and a malformed function schema
test_inputs = {"timezone": "America/New_York"} test_inputs = {"timezone": "America/New_York"}
malformed_function_schema = { malformed_function_schema = {
"name": "get_time", "name": "get_time",
"description": "Finds the current time in a specific timezone.", "description": "Finds the current time in a specific timezone.",
"signature": "(timezone str)", # Malformed signature missing colon "signiture": "(timezone: str)", # Malformed key name
"output": "<class 'str'>", "output": "<class 'str'>",
} }
...@@ -218,7 +218,7 @@ class TestBaseLLM: ...@@ -218,7 +218,7 @@ class TestBaseLLM:
assert not result, "Method should return False when an exception occurs" assert not result, "Method should return False when an exception occurs"
# Check that the appropriate error message was logged # Check that the appropriate error message was logged
expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message expected_error_message = "Single input validation error: 'signature'" # Adjust based on the actual exception message
mocked_logger.assert_called_once_with(expected_error_message) mocked_logger.assert_called_once_with(expected_error_message)
def test_extract_parameter_info_valid(self, base_llm): def test_extract_parameter_info_valid(self, base_llm):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment