diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 3d1d9b3737c80a3efa702a979de740c9d023ea4e..3170ccc2b8b438d0ddc4b8ae922df8456bc252e2 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -18,7 +18,7 @@ class BaseLLM(BaseModel):
 
     def __call__(self, messages: List[Message]) -> Optional[str]:
         raise NotImplementedError("Subclasses must implement this method")
-
+    
     def _is_valid_inputs(
         self, inputs: list[dict[str, Any]], function_schemas: list[dict[str, Any]]
     ) -> bool:
@@ -40,20 +40,34 @@ class BaseLLM(BaseModel):
                     logger.error(f"No matching function schema found for function name: {function_name}")
                     return False
 
-                # Extract parameter names and types from the signature string of the matching schema
-                param_names, param_types = self._extract_parameter_info(matching_schema["signature"])
-
-                # Validate that all required parameters are present in the arguments
-                for name, type_str in zip(param_names, param_types):
-                    if name not in arguments:
-                        logger.error(f"Input {name} missing from arguments")
-                        return False
+                # Validate the inputs against the function schema
+                if not self._validate_single_function_inputs(arguments, matching_schema):
+                    return False
 
             return True
         except Exception as e:
             logger.error(f"Input validation error: {str(e)}")
             return False
 
+    def _validate_single_function_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
+        """Validate the extracted inputs against the function schema"""
+        try:
+            # Extract parameter names and types from the signature string
+            signature = function_schema["signature"]
+            param_info = [param.strip() for param in signature[1:-1].split(",")]
+            param_names = [info.split(":")[0].strip() for info in param_info]
+            param_types = [
+                info.split(":")[1].strip().split("=")[0].strip() for info in param_info
+            ]
+            for name, type_str in zip(param_names, param_types):
+                if name not in inputs:
+                    logger.error(f"Input {name} missing from query")
+                    return False
+            return True
+        except Exception as e:
+            logger.error(f"Single input validation error: {str(e)}")
+            return False
+
     def _extract_parameter_info(self, signature: str) -> tuple[list[str], list[str]]:
         """Extract parameter names and types from the function signature."""
         param_info = [param.strip() for param in signature[1:-1].split(",")]