Skip to content
Snippets Groups Projects
Unverified Commit 75ee1b7d authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent db6cd601
No related branches found
No related tags found
No related merge requests found
...@@ -18,26 +18,31 @@ class BaseLLM(BaseModel): ...@@ -18,26 +18,31 @@ class BaseLLM(BaseModel):
def __call__(self, messages: List[Message]) -> Optional[str]: def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool: def _check_for_mandatory_inputs(
self, inputs: dict[str, Any], mandatory_params: List[str]
) -> bool:
"""Check for mandatory parameters in inputs""" """Check for mandatory parameters in inputs"""
for name in mandatory_params: for name in mandatory_params:
if name not in inputs: if name not in inputs:
logger.error(f"Mandatory input {name} missing from query") logger.error(f"Mandatory input {name} missing from query")
return False return False
return True return True
def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool: def _check_for_extra_inputs(
self, inputs: dict[str, Any], all_params: List[str]
) -> bool:
"""Check for extra parameters not defined in the signature""" """Check for extra parameters not defined in the signature"""
input_keys = set(inputs.keys()) input_keys = set(inputs.keys())
param_keys = set(all_params) param_keys = set(all_params)
if not input_keys.issubset(param_keys): if not input_keys.issubset(param_keys):
extra_keys = input_keys - param_keys extra_keys = input_keys - param_keys
logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}") logger.error(
f"Extra inputs provided that are not in the signature: {extra_keys}"
)
return False return False
return True return True
def _is_valid_inputs( def _is_valid_inputs(
self, inputs: dict[str, Any], function_schema: dict[str, Any] self, inputs: dict[str, Any], function_schema: dict[str, Any]
) -> bool: ) -> bool:
...@@ -52,7 +57,7 @@ class BaseLLM(BaseModel): ...@@ -52,7 +57,7 @@ class BaseLLM(BaseModel):
for info in param_info: for info in param_info:
parts = info.split("=") parts = info.split("=")
name_type_pair = parts[0].strip() name_type_pair = parts[0].strip()
if ':' in name_type_pair: if ":" in name_type_pair:
name, _ = name_type_pair.split(":") name, _ = name_type_pair.split(":")
else: else:
name = name_type_pair name = name_type_pair
......
...@@ -8,15 +8,15 @@ class TestBaseLLM: ...@@ -8,15 +8,15 @@ class TestBaseLLM:
@pytest.fixture @pytest.fixture
def base_llm(self): def base_llm(self):
return BaseLLM(name="TestLLM") return BaseLLM(name="TestLLM")
@pytest.fixture @pytest.fixture
def mixed_function_schema(self): def mixed_function_schema(self):
return { return {
"name": "test_function", "name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.", "description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
} }
@pytest.fixture @pytest.fixture
def mandatory_params(self): def mandatory_params(self):
return ["param1", "param2"] return ["param1", "param2"]
...@@ -87,13 +87,17 @@ class TestBaseLLM: ...@@ -87,13 +87,17 @@ class TestBaseLLM:
test_query = "What time is it in America/New_York?" test_query = "What time is it in America/New_York?"
base_llm.extract_function_inputs(test_schema, test_query) base_llm.extract_function_inputs(test_schema, test_query)
def test_mandatory_args_only(self, base_llm, mixed_function_schema): def test_mandatory_args_only(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42} inputs = {"mandatory1": "value1", "mandatory2": 42}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
def test_all_args_provided(self, base_llm, mixed_function_schema): def test_all_args_provided(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"} inputs = {
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
...@@ -101,7 +105,13 @@ class TestBaseLLM: ...@@ -101,7 +105,13 @@ class TestBaseLLM:
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
def test_extra_arg_provided(self, base_llm, mixed_function_schema): def test_extra_arg_provided(self, base_llm, mixed_function_schema):
inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} inputs = {
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
"extra": "value",
}
assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
...@@ -118,4 +128,4 @@ class TestBaseLLM: ...@@ -118,4 +128,4 @@ class TestBaseLLM:
def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
assert base_llm._check_for_extra_inputs(inputs, all_params) == False assert base_llm._check_for_extra_inputs(inputs, all_params) == False
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment