From 75ee1b7ddb2d80aba0b57a83bb7564cad78fc5eb Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Mon, 13 May 2024 18:07:14 +0400
Subject: [PATCH] Linting.

---
 semantic_router/llms/base.py     | 19 ++++++++++++-------
 tests/unit/llms/test_llm_base.py | 24 +++++++++++++++++-------
 2 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py
index 639d3c19..604d8ad2 100644
--- a/semantic_router/llms/base.py
+++ b/semantic_router/llms/base.py
@@ -18,26 +18,31 @@ class BaseLLM(BaseModel):
 
     def __call__(self, messages: List[Message]) -> Optional[str]:
         raise NotImplementedError("Subclasses must implement this method")
-    
 
-    def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool:
+    def _check_for_mandatory_inputs(
+        self, inputs: dict[str, Any], mandatory_params: List[str]
+    ) -> bool:
         """Check for mandatory parameters in inputs"""
         for name in mandatory_params:
             if name not in inputs:
                 logger.error(f"Mandatory input {name} missing from query")
                 return False
         return True
-    
-    def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool:
+
+    def _check_for_extra_inputs(
+        self, inputs: dict[str, Any], all_params: List[str]
+    ) -> bool:
         """Check for extra parameters not defined in the signature"""
         input_keys = set(inputs.keys())
         param_keys = set(all_params)
         if not input_keys.issubset(param_keys):
             extra_keys = input_keys - param_keys
-            logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}")
+            logger.error(
+                f"Extra inputs provided that are not in the signature: {extra_keys}"
+            )
             return False
         return True
-        
+
     def _is_valid_inputs(
         self, inputs: dict[str, Any], function_schema: dict[str, Any]
     ) -> bool:
@@ -52,7 +57,7 @@ class BaseLLM(BaseModel):
             for info in param_info:
                 parts = info.split("=")
                 name_type_pair = parts[0].strip()
-                if ':' in name_type_pair:
+                if ":" in name_type_pair:
                     name, _ = name_type_pair.split(":")
                 else:
                     name = name_type_pair
diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index 5e9bbe9f..7c8dbf37 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -8,15 +8,15 @@ class TestBaseLLM:
     @pytest.fixture
     def base_llm(self):
         return BaseLLM(name="TestLLM")
-    
+
     @pytest.fixture
     def mixed_function_schema(self):
         return {
             "name": "test_function",
             "description": "A test function with mixed mandatory and optional parameters.",
-            "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')"
+            "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
         }
-    
+
     @pytest.fixture
     def mandatory_params(self):
         return ["param1", "param2"]
@@ -87,13 +87,17 @@ class TestBaseLLM:
             test_query = "What time is it in America/New_York?"
             base_llm.extract_function_inputs(test_schema, test_query)
 
-
     def test_mandatory_args_only(self, base_llm, mixed_function_schema):
         inputs = {"mandatory1": "value1", "mandatory2": 42}
         assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
 
     def test_all_args_provided(self, base_llm, mixed_function_schema):
-        inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"}
+        inputs = {
+            "mandatory1": "value1",
+            "mandatory2": 42,
+            "optional1": "opt1",
+            "optional2": "opt2",
+        }
         assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True
 
     def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
@@ -101,7 +105,13 @@ class TestBaseLLM:
         assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
 
     def test_extra_arg_provided(self, base_llm, mixed_function_schema):
-        inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"}
+        inputs = {
+            "mandatory1": "value1",
+            "mandatory2": 42,
+            "optional1": "opt1",
+            "optional2": "opt2",
+            "extra": "value",
+        }
         assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False
 
     def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
@@ -118,4 +128,4 @@ class TestBaseLLM:
 
     def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
         inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
-        assert base_llm._check_for_extra_inputs(inputs, all_params) == False
\ No newline at end of file
+        assert base_llm._check_for_extra_inputs(inputs, all_params) == False
-- 
GitLab