Skip to content
Snippets Groups Projects
Unverified Commit 7b74a91d authored by James Briggs's avatar James Briggs Committed by GitHub
Browse files

Merge branch 'main' into james/v0.0.41

parents ccd98a79 d3ad005a
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,30 @@ class BaseLLM(BaseModel):
def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method")
def _check_for_mandatory_inputs(
self, inputs: dict[str, Any], mandatory_params: List[str]
) -> bool:
"""Check for mandatory parameters in inputs"""
for name in mandatory_params:
if name not in inputs:
logger.error(f"Mandatory input {name} missing from query")
return False
return True
def _check_for_extra_inputs(
self, inputs: dict[str, Any], all_params: List[str]
) -> bool:
"""Check for extra parameters not defined in the signature"""
input_keys = set(inputs.keys())
param_keys = set(all_params)
if not input_keys.issubset(param_keys):
extra_keys = input_keys - param_keys
logger.error(
f"Extra inputs provided that are not in the signature: {extra_keys}"
)
return False
return True
def _is_valid_inputs(
self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
) -> bool:
......@@ -48,17 +72,33 @@ class BaseLLM(BaseModel):
) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
# Extract parameter names and determine if they are optional
signature = function_schema["signature"]
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]
for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
return False
mandatory_params = []
all_params = []
for info in param_info:
parts = info.split("=")
name_type_pair = parts[0].strip()
if ":" in name_type_pair:
name, _ = name_type_pair.split(":")
else:
name = name_type_pair
all_params.append(name)
# If there is no default value, it's a mandatory parameter
if len(parts) == 1:
mandatory_params.append(name)
# Check for mandatory parameters
if not self._check_for_mandatory_inputs(inputs, mandatory_params):
return False
# Check for extra parameters not defined in the signature
if not self._check_for_extra_inputs(inputs, all_params):
return False
return True
except Exception as e:
logger.error(f"Single input validation error: {str(e)}")
......@@ -124,7 +164,7 @@ Return only JSON, stating the argument names and their corresponding values.
=== EXAMPLE_OUTPUT End ===
### EXAMPLE End ###
Note: I will tip $500 for and accurate JSON output. You will be penalized for an inaccurate JSON output.
Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output.
Provide JSON output now:
"""
......
......@@ -4,10 +4,29 @@ from unittest.mock import patch
class TestBaseLLM:
@pytest.fixture
def base_llm(self):
return BaseLLM(name="TestLLM")
@pytest.fixture
def mixed_function_schema(self):
return [
{
"name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
}
]
@pytest.fixture
def mandatory_params(self):
return ["param1", "param2"]
@pytest.fixture
def all_params(self):
return ["param1", "param2", "optional1"]
def test_base_llm_initialization(self, base_llm):
assert base_llm.name == "TestLLM", "Initialization of name failed"
......@@ -72,6 +91,59 @@ class TestBaseLLM:
test_query = "What time is it in America/New_York?"
base_llm.extract_function_inputs(test_schema, test_query)
def test_mandatory_args_only(self, base_llm, mixed_function_schema):
inputs = [{"mandatory1": "value1", "mandatory2": 42}]
assert base_llm._is_valid_inputs(
inputs, mixed_function_schema
) # True is implied
def test_all_args_provided(self, base_llm, mixed_function_schema):
inputs = [
{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
}
]
assert base_llm._is_valid_inputs(
inputs, mixed_function_schema
) # True is implied
def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)
def test_extra_arg_provided(self, base_llm, mixed_function_schema):
inputs = [
{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
"extra": "value",
}
]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)
def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
inputs = {"param1": "value1", "param2": "value2"}
assert base_llm._check_for_mandatory_inputs(
inputs, mandatory_params
) # True is implied
def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params):
inputs = {"param1": "value1"}
assert not base_llm._check_for_mandatory_inputs(inputs, mandatory_params)
def test_check_for_extra_inputs_no_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2"}
assert base_llm._check_for_extra_inputs(inputs, all_params) # True is implied
def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
assert not base_llm._check_for_extra_inputs(inputs, all_params)
def test_is_valid_inputs_multiple_inputs(self, base_llm, mocker):
# Mock the logger to capture the error messages
mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
......@@ -139,7 +211,7 @@ class TestBaseLLM:
malformed_function_schema = {
"name": "get_time",
"description": "Finds the current time in a specific timezone.",
"signature": "(timezone str)", # Malformed signature missing colon
"signiture": "(timezone: str)", # Malformed key name
"output": "<class 'str'>",
}
......@@ -152,7 +224,7 @@ class TestBaseLLM:
assert not result, "Method should return False when an exception occurs"
# Check that the appropriate error message was logged
expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message
expected_error_message = "Single input validation error: 'signature'" # Adjust based on the actual exception message
mocked_logger.assert_called_once_with(expected_error_message)
def test_extract_parameter_info_valid(self, base_llm):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment