From 75ee1b7ddb2d80aba0b57a83bb7564cad78fc5eb Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Mon, 13 May 2024 18:07:14 +0400 Subject: [PATCH] Linting. --- semantic_router/llms/base.py | 19 ++++++++++++------- tests/unit/llms/test_llm_base.py | 24 +++++++++++++++++------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 639d3c19..604d8ad2 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,26 +18,31 @@ class BaseLLM(BaseModel): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - - def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool: + def _check_for_mandatory_inputs( + self, inputs: dict[str, Any], mandatory_params: List[str] + ) -> bool: """Check for mandatory parameters in inputs""" for name in mandatory_params: if name not in inputs: logger.error(f"Mandatory input {name} missing from query") return False return True - - def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool: + + def _check_for_extra_inputs( + self, inputs: dict[str, Any], all_params: List[str] + ) -> bool: """Check for extra parameters not defined in the signature""" input_keys = set(inputs.keys()) param_keys = set(all_params) if not input_keys.issubset(param_keys): extra_keys = input_keys - param_keys - logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}") + logger.error( + f"Extra inputs provided that are not in the signature: {extra_keys}" + ) return False return True - + def _is_valid_inputs( self, inputs: dict[str, Any], function_schema: dict[str, Any] ) -> bool: @@ -52,7 +57,7 @@ class BaseLLM(BaseModel): for info in param_info: parts = info.split("=") name_type_pair = parts[0].strip() - if ':' in name_type_pair: + if ":" in name_type_pair: name, _ = name_type_pair.split(":") else: name = name_type_pair diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 5e9bbe9f..7c8dbf37 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -8,15 +8,15 @@ class TestBaseLLM: @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") - + @pytest.fixture def mixed_function_schema(self): return { "name": "test_function", "description": "A test function with mixed mandatory and optional parameters.", - "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", } - + @pytest.fixture def mandatory_params(self): return ["param1", "param2"] @@ -87,13 +87,17 @@ class TestBaseLLM: test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) - def test_mandatory_args_only(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "mandatory2": 42} assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True def test_all_args_provided(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"} + inputs = { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + } assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): @@ -101,7 +105,13 @@ class TestBaseLLM: assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False def test_extra_arg_provided(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} + inputs = { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + "extra": "value", + } assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): @@ -118,4 +128,4 @@ class TestBaseLLM: def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} - assert base_llm._check_for_extra_inputs(inputs, all_params) == False \ No newline at end of file + assert base_llm._check_for_extra_inputs(inputs, all_params) == False -- GitLab