Skip to content
Snippets Groups Projects
Unverified Commit 9615e08d authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

is_valid_inputs changed so it calls a function _validate_single_function_inputs

_validate_single_function_inputs  has much the same logic as the original is_valid_inputs function.
parent a9af91a7
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ class BaseLLM(BaseModel):
def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method")
def _is_valid_inputs(
self, inputs: list[dict[str, Any]], function_schemas: list[dict[str, Any]]
) -> bool:
......@@ -40,20 +40,34 @@ class BaseLLM(BaseModel):
logger.error(f"No matching function schema found for function name: {function_name}")
return False
# Extract parameter names and types from the signature string of the matching schema
param_names, param_types = self._extract_parameter_info(matching_schema["signature"])
# Validate that all required parameters are present in the arguments
for name, type_str in zip(param_names, param_types):
if name not in arguments:
logger.error(f"Input {name} missing from arguments")
return False
# Validate the inputs against the function schema
if not self._validate_single_function_inputs(arguments, matching_schema):
return False
return True
except Exception as e:
logger.error(f"Input validation error: {str(e)}")
return False
def _validate_single_function_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
signature = function_schema["signature"]
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]
for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
return False
return True
except Exception as e:
logger.error(f"Single input validation error: {str(e)}")
return False
def _extract_parameter_info(self, signature: str) -> tuple[list[str], list[str]]:
"""Extract parameter names and types from the function signature."""
param_info = [param.strip() for param in signature[1:-1].split(",")]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment