From d23817225c496ecb305cc9ccea34d5e57098515d Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Wed, 8 May 2024 04:33:03 +0400
Subject: [PATCH] More pytests.

---
 tests/unit/llms/test_llm_base.py | 100 ++++++++++++++++++++++++++++++-
 1 file changed, 99 insertions(+), 1 deletion(-)

diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index 3892afe8..680b5d2d 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -1,6 +1,6 @@
 import pytest
-
 from semantic_router.llms import BaseLLM
+from unittest.mock import patch
 
 
 class TestBaseLLM:
@@ -71,3 +71,101 @@ class TestBaseLLM:
             }
             test_query = "What time is it in America/New_York?"
             base_llm.extract_function_inputs(test_schema, test_query)
+
+    def test_is_valid_inputs_multiple_inputs(self, base_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Prepare multiple sets of inputs
+        test_inputs = [{"timezone": "America/New_York"}, {"timezone": "Europe/London"}]
+        test_schemas = [
+            {
+                "name": "get_time",
+                "description": "Finds the current time in a specific timezone.",
+                "signature": "(timezone: str) -> str",
+                "output": "<class 'str'>",
+            }
+        ]
+
+        # Call the method with multiple inputs
+        result = base_llm._is_valid_inputs(test_inputs, test_schemas)
+
+        # Assert that the method returns False
+        assert (
+            not result
+        ), "Method should return False when multiple inputs are provided"
+
+        # Check that the appropriate error message was logged
+        mocked_logger.assert_called_once_with(
+            "Only one set of function inputs is allowed."
+        )
+
+    def test_is_valid_inputs_exception_handling(self, base_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Use patch on the method's full path
+        with patch(
+            "semantic_router.llms.base.BaseLLM._validate_single_function_inputs",
+            side_effect=Exception("Test Exception"),
+        ):
+            test_inputs = [{"timezone": "America/New_York"}]
+            test_schemas = [
+                {
+                    "name": "get_time",
+                    "description": "Finds the current time in a specific timezone.",
+                    "signature": "(timezone: str) -> str",
+                    "output": "<class 'str'>",
+                }
+            ]
+
+            # Call the method and expect it to return False due to the exception
+            result = base_llm._is_valid_inputs(test_inputs, test_schemas)
+
+            # Assert that the method returns False
+            assert not result, "Method should return False when an exception occurs"
+
+            # Check that the appropriate error message was logged
+            mocked_logger.assert_called_once_with(
+                "Input validation error: Test Exception"
+            )
+
+    def test_validate_single_function_inputs_exception_handling(self, base_llm, mocker):
+        # Mock the logger to capture the error messages
+        mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
+
+        # Prepare inputs and a malformed function schema
+        test_inputs = {"timezone": "America/New_York"}
+        malformed_function_schema = {
+            "name": "get_time",
+            "description": "Finds the current time in a specific timezone.",
+            "signature": "(timezone str)",  # Malformed signature missing colon
+            "output": "<class 'str'>",
+        }
+
+        # Call the method and expect it to return False due to the exception
+        result = base_llm._validate_single_function_inputs(
+            test_inputs, malformed_function_schema
+        )
+
+        # Assert that the method returns False
+        assert not result, "Method should return False when an exception occurs"
+
+        # Check that the appropriate error message was logged
+        expected_error_message = "Single input validation error: list index out of range"  # Adjust based on the actual exception message
+        mocked_logger.assert_called_once_with(expected_error_message)
+
+    def test_extract_parameter_info_valid(self, base_llm):
+        # Test with a valid signature
+        signature = "(param1: int, param2: str = 'default')"
+        expected_names = ["param1", "param2"]
+        expected_types = ["int", "str"]
+        param_names, param_types = base_llm._extract_parameter_info(signature)
+        assert param_names == expected_names, "Parameter names did not match expected"
+        assert param_types == expected_types, "Parameter types did not match expected"
+
+    def test_extract_parameter_info_malformed(self, base_llm):
+        # Test with a malformed signature
+        signature = "(param1 int, param2: str = 'default')"
+        with pytest.raises(IndexError):
+            base_llm._extract_parameter_info(signature)
-- 
GitLab