From d756c8694c433eb0f514f15e28b8685ff38a0485 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Mon, 13 May 2024 17:49:16 +0400 Subject: [PATCH] PyTests and bug fix. --- semantic_router/llms/base.py | 5 ++++- tests/unit/llms/test_llm_base.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 658ed08d..639d3c19 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -52,7 +52,10 @@ class BaseLLM(BaseModel): for info in param_info: parts = info.split("=") name_type_pair = parts[0].strip() - name = name_type_pair.split(":")[0].strip() + if ':' in name_type_pair: + name, _ = name_type_pair.split(":") + else: + name = name_type_pair all_params.append(name) # If there is no default value, it's a mandatory parameter diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 2208928a..d5f14a10 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -4,9 +4,18 @@ from semantic_router.llms import BaseLLM class TestBaseLLM: + @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") + + @pytest.fixture + def mixed_function_schema(self): + return { + "name": "test_function", + "description": "A test function with mixed mandatory and optional parameters.", + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" + } def test_base_llm_initialization(self, base_llm): assert base_llm.name == "TestLLM", "Initialization of name failed" @@ -69,3 +78,21 @@ class TestBaseLLM: } test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) + + + def test_mandatory_args_only(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + + def test_all_args_provided(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + + def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + + def test_extra_arg_provided(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + -- GitLab