diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 3d1d9b3737c80a3efa702a979de740c9d023ea4e..3170ccc2b8b438d0ddc4b8ae922df8456bc252e2 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,7 +18,7 @@ class BaseLLM(BaseModel): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - + def _is_valid_inputs( self, inputs: list[dict[str, Any]], function_schemas: list[dict[str, Any]] ) -> bool: @@ -40,20 +40,34 @@ class BaseLLM(BaseModel): logger.error(f"No matching function schema found for function name: {function_name}") return False - # Extract parameter names and types from the signature string of the matching schema - param_names, param_types = self._extract_parameter_info(matching_schema["signature"]) - - # Validate that all required parameters are present in the arguments - for name, type_str in zip(param_names, param_types): - if name not in arguments: - logger.error(f"Input {name} missing from arguments") - return False + # Validate the inputs against the function schema + if not self._validate_single_function_inputs(arguments, matching_schema): + return False return True except Exception as e: logger.error(f"Input validation error: {str(e)}") return False + def _validate_single_function_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: + """Validate the extracted inputs against the function schema""" + try: + # Extract parameter names and types from the signature string + signature = function_schema["signature"] + param_info = [param.strip() for param in signature[1:-1].split(",")] + param_names = [info.split(":")[0].strip() for info in param_info] + param_types = [ + info.split(":")[1].strip().split("=")[0].strip() for info in param_info + ] + for name, type_str in zip(param_names, param_types): + if name not in inputs: + logger.error(f"Input {name} missing from query") + return False + return True + except Exception as e: + logger.error(f"Single input validation error: {str(e)}") + return False + def _extract_parameter_info(self, signature: str) -> tuple[list[str], list[str]]: """Extract parameter names and types from the function signature.""" param_info = [param.strip() for param in signature[1:-1].split(",")]