diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 5ee56213d9b2482752f167b8a28895098ca41585..bbd39b4ef4fdfe2fa6527cba6419d91137fd7946 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -19,6 +19,30 @@ class BaseLLM(BaseModel):
     def __call__(self, messages: List[Message]) -> Optional[str]:
         raise NotImplementedError("Subclasses must implement this method")
 
+    def _check_for_mandatory_inputs(
+        self, inputs: dict[str, Any], mandatory_params: List[str]
+    ) -> bool:
+        """Check for mandatory parameters in inputs"""
+        for name in mandatory_params:
+            if name not in inputs:
+                logger.error(f"Mandatory input {name} missing from query")
+                return False
+        return True
+
+    def _check_for_extra_inputs(
+        self, inputs: dict[str, Any], all_params: List[str]
+    ) -> bool:
+        """Check for extra parameters not defined in the signature"""
+        input_keys = set(inputs.keys())
+        param_keys = set(all_params)
+        if not input_keys.issubset(param_keys):
+            extra_keys = input_keys - param_keys
+            logger.error(
+                f"Extra inputs provided that are not in the signature: {extra_keys}"
+            )
+            return False
+        return True
+
     def _is_valid_inputs(
         self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
     ) -> bool:
@@ -48,17 +72,33 @@ class BaseLLM(BaseModel):
     ) -> bool:
         """Validate the extracted inputs against the function schema"""
         try:
-            # Extract parameter names and types from the signature string
+            # Extract parameter names and determine if they are optional
             signature = function_schema["signature"]
             param_info = [param.strip() for param in signature[1:-1].split(",")]
-            param_names = [info.split(":")[0].strip() for info in param_info]
-            param_types = [
-                info.split(":")[1].strip().split("=")[0].strip() for info in param_info
-            ]
-            for name, type_str in zip(param_names, param_types):
-                if name not in inputs:
-                    logger.error(f"Input {name} missing from query")
-                    return False
+            mandatory_params = []
+            all_params = []
+
+            for info in param_info:
+                parts = info.split("=")
+                name_type_pair = parts[0].strip()
+                if ":" in name_type_pair:
+                    name, _ = name_type_pair.split(":")
+                else:
+                    name = name_type_pair
+                all_params.append(name)
+
+                # If there is no default value, it's a mandatory parameter
+                if len(parts) == 1:
+                    mandatory_params.append(name)
+
+            # Check for mandatory parameters
+            if not self._check_for_mandatory_inputs(inputs, mandatory_params):
+                return False
+
+            # Check for extra parameters not defined in the signature
+            if not self._check_for_extra_inputs(inputs, all_params):
+                return False
+
             return True
         except Exception as e:
             logger.error(f"Single input validation error: {str(e)}")
@@ -124,7 +164,7 @@ Return only JSON, stating the argument names and their corresponding values.
 	=== EXAMPLE_OUTPUT End ===
 ### EXAMPLE End ###
 
-Note: I will tip $500 for and accurate JSON output. You will be penalized for an inaccurate JSON output.
+Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output.
 
 Provide JSON output now:
 """
diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index 680b5d2d0ccd3098272230a8a7f2ebcf79a703d4..3699ded0590f2798633d14305c1f1b3409655685 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -4,10 +4,29 @@ from unittest.mock import patch
 
 
 class TestBaseLLM:
+
     @pytest.fixture
     def base_llm(self):
         return BaseLLM(name="TestLLM")
 
+    @pytest.fixture
+    def mixed_function_schema(self):
+        return [
+            {
+                "name": "test_function",
+                "description": "A test function with mixed mandatory and optional parameters.",
+                "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
+            }
+        ]
+
+    @pytest.fixture
+    def mandatory_params(self):
+        return ["param1", "param2"]
+
+    @pytest.fixture
+    def all_params(self):
+        return ["param1", "param2", "optional1"]
+
     def test_base_llm_initialization(self, base_llm):
         assert base_llm.name == "TestLLM", "Initialization of name failed"
 
@@ -72,6 +91,59 @@ class TestBaseLLM:
             test_query = "What time is it in America/New_York?"
             base_llm.extract_function_inputs(test_schema, test_query)
 
+    def test_mandatory_args_only(self, base_llm, mixed_function_schema):
+        inputs = [{"mandatory1": "value1", "mandatory2": 42}]
+        assert base_llm._is_valid_inputs(
+            inputs, mixed_function_schema
+        )  # True is implied
+
+    def test_all_args_provided(self, base_llm, mixed_function_schema):
+        inputs = [
+            {
+                "mandatory1": "value1",
+                "mandatory2": 42,
+                "optional1": "opt1",
+                "optional2": "opt2",
+            }
+        ]
+        assert base_llm._is_valid_inputs(
+            inputs, mixed_function_schema
+        )  # True is implied
+
+    def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
+        inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}]
+        assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)
+
+    def test_extra_arg_provided(self, base_llm, mixed_function_schema):
+        inputs = [
+            {
+                "mandatory1": "value1",
+                "mandatory2": 42,
+                "optional1": "opt1",
+                "optional2": "opt2",
+                "extra": "value",
+            }
+        ]
+        assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)
+
+    def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
+        inputs = {"param1": "value1", "param2": "value2"}
+        assert base_llm._check_for_mandatory_inputs(
+            inputs, mandatory_params
+        )  # True is implied
+
+    def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params):
+        inputs = {"param1": "value1"}
+        assert not base_llm._check_for_mandatory_inputs(inputs, mandatory_params)
+
+    def test_check_for_extra_inputs_no_extras(self, base_llm, all_params):
+        inputs = {"param1": "value1", "param2": "value2"}
+        assert base_llm._check_for_extra_inputs(inputs, all_params)  # True is implied
+
+    def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
+        inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
+        assert not base_llm._check_for_extra_inputs(inputs, all_params)
+
     def test_is_valid_inputs_multiple_inputs(self, base_llm, mocker):
         # Mock the logger to capture the error messages
         mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
@@ -139,7 +211,7 @@ class TestBaseLLM:
         malformed_function_schema = {
             "name": "get_time",
             "description": "Finds the current time in a specific timezone.",
-            "signature": "(timezone str)",  # Malformed signature missing colon
+            "signiture": "(timezone: str)",  # Malformed key name
             "output": "<class 'str'>",
         }
 
@@ -152,7 +224,7 @@ class TestBaseLLM:
         assert not result, "Method should return False when an exception occurs"
 
         # Check that the appropriate error message was logged
-        expected_error_message = "Single input validation error: list index out of range"  # Adjust based on the actual exception message
+        expected_error_message = "Single input validation error: 'signature'"  # Adjust based on the actual exception message
         mocked_logger.assert_called_once_with(expected_error_message)
 
     def test_extract_parameter_info_valid(self, base_llm):