Skip to content
Snippets Groups Projects
Unverified Commit 75ee1b7d authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent db6cd601
No related branches found
No related tags found
No related merge requests found
......@@ -18,26 +18,31 @@ class BaseLLM(BaseModel):
def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method")
def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool:
def _check_for_mandatory_inputs(
self, inputs: dict[str, Any], mandatory_params: List[str]
) -> bool:
"""Check for mandatory parameters in inputs"""
for name in mandatory_params:
if name not in inputs:
logger.error(f"Mandatory input {name} missing from query")
return False
return True
def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool:
def _check_for_extra_inputs(
self, inputs: dict[str, Any], all_params: List[str]
) -> bool:
"""Check for extra parameters not defined in the signature"""
input_keys = set(inputs.keys())
param_keys = set(all_params)
if not input_keys.issubset(param_keys):
extra_keys = input_keys - param_keys
logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}")
logger.error(
f"Extra inputs provided that are not in the signature: {extra_keys}"
)
return False
return True
def _is_valid_inputs(
self, inputs: dict[str, Any], function_schema: dict[str, Any]
) -> bool:
......@@ -52,7 +57,7 @@ class BaseLLM(BaseModel):
for info in param_info:
parts = info.split("=")
name_type_pair = parts[0].strip()
if ':' in name_type_pair:
if ":" in name_type_pair:
name, _ = name_type_pair.split(":")
else:
name = name_type_pair
......
......@@ -8,15 +8,15 @@ class TestBaseLLM:
@pytest.fixture
def base_llm(self):
return BaseLLM(name="TestLLM")
@pytest.fixture
def mixed_function_schema(self):
return {
"name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')"
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
}
@pytest.fixture
def mandatory_params(self):
return ["param1", "param2"]
......@@ -87,13 +87,17 @@ class TestBaseLLM:
test_query = "What time is it in America/New_York?"
base_llm.extract_function_inputs(test_schema, test_query)
def test_mandatory_args_only(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
def test_all_args_provided(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"}
inputs = {
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
......@@ -101,7 +105,13 @@ class TestBaseLLM:
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
def test_extra_arg_provided(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"}
inputs = {
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
"extra": "value",
}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
......@@ -118,4 +128,4 @@ class TestBaseLLM:
def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
assert base_llm._check_for_extra_inputs(inputs, all_params) == False
\ No newline at end of file
assert base_llm._check_for_extra_inputs(inputs, all_params) == False
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment