From 5c60f73bd1d400cc999b7160823c4af62ff4dd52 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 13 May 2024 17:51:00 +0400
Subject: [PATCH] More PyTests.

---
 tests/unit/llms/test_llm_base.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index d5f14a10..5e9bbe9f 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -4,7 +4,7 @@ from semantic_router.llms import BaseLLM
 
 
 class TestBaseLLM:
-    
+
     @pytest.fixture
     def base_llm(self):
         return BaseLLM(name="TestLLM")
@@ -16,6 +16,14 @@ class TestBaseLLM:
             "description": "A test function with mixed mandatory and optional parameters.",
             "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')"
         }
+    
+    @pytest.fixture
+    def mandatory_params(self):
+        return ["param1", "param2"]
+
+    @pytest.fixture
+    def all_params(self):
+        return ["param1", "param2", "optional1"]
 
     def test_base_llm_initialization(self, base_llm):
         assert base_llm.name == "TestLLM", "Initialization of name failed"
@@ -96,3 +104,18 @@ class TestBaseLLM:
         inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"}
         assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
 
+    def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
+        inputs = {"param1": "value1", "param2": "value2"}
+        assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == True
+
+    def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params):
+        inputs = {"param1": "value1"}
+        assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == False
+
+    def test_check_for_extra_inputs_no_extras(self, base_llm, all_params):
+        inputs = {"param1": "value1", "param2": "value2"}
+        assert base_llm._check_for_extra_inputs(inputs, all_params) == True
+
+    def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
+        inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
+        assert base_llm._check_for_extra_inputs(inputs, all_params) == False
\ No newline at end of file
-- 
GitLab