diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index 24f5fc9117b9d5f2181cce865699ef36964f2b3e..2586d2e4253e485445c9c5e5bc1b3b81061c8279 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -9,14 +9,6 @@ from semantic_router.schema import Message from semantic_router.utils.logger import logger -class LlamaCppBaseLLM(BaseLLM): - def __init__(self, name: str, llm: Llama, temperature: float, max_tokens: int): - super().__init__(name) - self.llm = llm - self.temperature = temperature - self.max_tokens = max_tokens - - class LlamaCppLLM(BaseLLM): llm: Llama temperature: float diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index f88307abc20cad754f212ce4b197f53c27818402..076b2fc584e8f71de586faa2b8df0a88c3db6fa3 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -57,3 +57,15 @@ class TestBaseLLM: } test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) + + def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): + with pytest.raises(Exception): + base_llm.output = mocker.Mock(return_value=None) + test_schema = { + "name": "get_time", + "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n be a valid timezone from the IANA Time Zone Database like\n "America/New_York" or "Europe/London". Do NOT put the place\n name itself like "rome", or "new york", you must provide\n the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.', + "signature": "(timezone: str) -> str", + "output": "<class 'str'>", + } + test_query = "What time is it in America/New_York?" + base_llm.extract_function_inputs(test_schema, test_query) diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py index e2bad635949f6e86afc9270cbcd57096c1636136..1344dda0cb9b2e3aea507f69c4aafaafd8df56f8 100644 --- a/tests/unit/llms/test_llm_llamacpp.py +++ b/tests/unit/llms/test_llm_llamacpp.py @@ -46,3 +46,20 @@ class TestLlamaCppLLM: llamacpp_llm.extract_function_inputs( query=test_query, function_schema=test_schema ) + + def test_llamacpp_extract_function_inputs_invalid(self, llamacpp_llm, mocker): + with pytest.raises(ValueError): + llamacpp_llm.llm.create_chat_completion = mocker.Mock( + return_value={"choices": [{"message": {"content": "{'time': 'America/New_York'}"}}]} + ) + test_schema = { + "name": "get_time", + "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n be a valid timezone from the IANA Time Zone Database like\n "America/New_York" or "Europe/London". Do NOT put the place\n name itself like "rome", or "new york", you must provide\n the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.', + "signature": "(timezone: str) -> str", + "output": "<class 'str'>", + } + test_query = "What time is it in America/New_York?" + + llamacpp_llm.extract_function_inputs( + query=test_query, function_schema=test_schema + ) \ No newline at end of file