diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 658ed08da6a1015df839f0c349524ad955605108..639d3c194cb68bfd9200b6ac993a32902e70eeb5 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -52,7 +52,10 @@ class BaseLLM(BaseModel): for info in param_info: parts = info.split("=") name_type_pair = parts[0].strip() - name = name_type_pair.split(":")[0].strip() + if ':' in name_type_pair: + name, _ = name_type_pair.split(":") + else: + name = name_type_pair all_params.append(name) # If there is no default value, it's a mandatory parameter diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 2208928a107575e8e5bc5306fc92b534713365f8..d5f14a109ab35b5eb2089bf21d387ffe07b9c2f5 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -4,9 +4,18 @@ from semantic_router.llms import BaseLLM class TestBaseLLM: + @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") + + @pytest.fixture + def mixed_function_schema(self): + return { + "name": "test_function", + "description": "A test function with mixed mandatory and optional parameters.", + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" + } def test_base_llm_initialization(self, base_llm): assert base_llm.name == "TestLLM", "Initialization of name failed" @@ -69,3 +78,21 @@ class TestBaseLLM: } test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) + + + def test_mandatory_args_only(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + + def test_all_args_provided(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + + def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + + def test_extra_arg_provided(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False +