From d756c8694c433eb0f514f15e28b8685ff38a0485 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 13 May 2024 17:49:16 +0400
Subject: [PATCH] PyTests and bug fix.

---
 semantic_router/llms/base.py     |  5 ++++-
 tests/unit/llms/test_llm_base.py | 27 +++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 658ed08d..639d3c19 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -52,7 +52,10 @@ class BaseLLM(BaseModel):
             for info in param_info:
                 parts = info.split("=")
                 name_type_pair = parts[0].strip()
-                name = name_type_pair.split(":")[0].strip()
+                if ':' in name_type_pair:
+                    name, _ = name_type_pair.split(":")
+                else:
+                    name = name_type_pair
                 all_params.append(name)
 
                 # If there is no default value, it's a mandatory parameter
diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index 2208928a..d5f14a10 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -4,9 +4,18 @@ from semantic_router.llms import BaseLLM
 
 
 class TestBaseLLM:
+    
     @pytest.fixture
     def base_llm(self):
         return BaseLLM(name="TestLLM")
+    
+    @pytest.fixture
+    def mixed_function_schema(self):
+        return {
+            "name": "test_function",
+            "description": "A test function with mixed mandatory and optional parameters.",
+            "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')"
+        }
 
     def test_base_llm_initialization(self, base_llm):
         assert base_llm.name == "TestLLM", "Initialization of name failed"
@@ -69,3 +78,21 @@ class TestBaseLLM:
             }
             test_query = "What time is it in America/New_York?"
             base_llm.extract_function_inputs(test_schema, test_query)
+
+
+    def test_mandatory_args_only(self, base_llm, mixed_function_schema):
+        inputs = {"mandatory1": "value1", "mandatory2": 42}
+        assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
+
+    def test_all_args_provided(self, base_llm, mixed_function_schema):
+        inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"}
+        assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
+
+    def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
+        inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}
+        assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
+
+    def test_extra_arg_provided(self, base_llm, mixed_function_schema):
+        inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"}
+        assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
+
-- 
GitLab