diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 9fd97978b1be66d5181b33d49041674758da9636..f88307abc20cad754f212ce4b197f53c27818402 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -37,3 +37,23 @@ class TestBaseLLM: test_inputs = {"timezone": None} assert base_llm._is_valid_inputs(test_inputs, test_schema) is False + + def test_base_llm_is_valid_inputs_invalid_false(self, base_llm): + test_schema = { + "name": "get_time", + "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n be a valid timezone from the IANA Time Zone Database like\n "America/New_York" or "Europe/London". Do NOT put the place\n name itself like "rome", or "new york", you must provide\n the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.' + } + test_inputs = {"timezone": "America/New_York"} + + assert base_llm._is_valid_inputs(test_inputs, test_schema) is False + + def test_base_llm_extract_function_inputs(self, base_llm): + with pytest.raises(NotImplementedError): + test_schema = { + "name": "get_time", + "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n be a valid timezone from the IANA Time Zone Database like\n "America/New_York" or "Europe/London". Do NOT put the place\n name itself like "rome", or "new york", you must provide\n the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.', + "signature": "(timezone: str) -> str", + "output": "<class 'str'>", + } + test_query = "What time is it in America/New_York?" + base_llm.extract_function_inputs(test_schema, test_query) diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py index 4bcf2a8ec3e9b2b14be8fd74b0f9c170dec8f9ed..507a52734c8dca38314b6bf8f058f029df6c75af 100644 --- a/tests/unit/llms/test_llm_llamacpp.py +++ b/tests/unit/llms/test_llm_llamacpp.py @@ -27,3 +27,6 @@ class TestLlamaCppLLM: llm_input = [Message(role="user", content="test")] output = llamacpp_llm(llm_input) assert output == "test" + + def test_llamacpp_llm_grammar(self, llamacpp_llm): + llamacpp_llm._grammar()