From f16529f884e028ef21db5b11e2c50d749fbff8b3 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sat, 13 Jan 2024 18:17:14 +0000
Subject: [PATCH] added more tests

---
 semantic_router/llms/llamacpp.py     |  8 --------
 tests/unit/llms/test_llm_base.py     | 12 ++++++++++++
 tests/unit/llms/test_llm_llamacpp.py | 17 +++++++++++++++++
 3 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 24f5fc91..2586d2e4 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -9,14 +9,6 @@ from semantic_router.schema import Message
 from semantic_router.utils.logger import logger
 
 
-class LlamaCppBaseLLM(BaseLLM):
-    def __init__(self, name: str, llm: Llama, temperature: float, max_tokens: int):
-        super().__init__(name)
-        self.llm = llm
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-
-
 class LlamaCppLLM(BaseLLM):
     llm: Llama
     temperature: float
diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py
index f88307ab..076b2fc5 100644
--- a/tests/unit/llms/test_llm_base.py
+++ b/tests/unit/llms/test_llm_base.py
@@ -57,3 +57,15 @@ class TestBaseLLM:
             }
             test_query = "What time is it in America/New_York?"
             base_llm.extract_function_inputs(test_schema, test_query)
+
+    def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker):
+        with pytest.raises(Exception):
+            base_llm.output = mocker.Mock(return_value=None)
+            test_schema = {
+                "name": "get_time",
+                "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n    be a valid timezone from the IANA Time Zone Database like\n    "America/New_York" or "Europe/London". Do NOT put the place\n    name itself like "rome", or "new york", you must provide\n    the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.',
+                "signature": "(timezone: str) -> str",
+                "output": "<class 'str'>",
+            }
+            test_query = "What time is it in America/New_York?"
+            base_llm.extract_function_inputs(test_schema, test_query)
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index e2bad635..1344dda0 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -46,3 +46,20 @@ class TestLlamaCppLLM:
         llamacpp_llm.extract_function_inputs(
             query=test_query, function_schema=test_schema
         )
+
+    def test_llamacpp_extract_function_inputs_invalid(self, llamacpp_llm, mocker):
+        with pytest.raises(ValueError):
+            llamacpp_llm.llm.create_chat_completion = mocker.Mock(
+                return_value={"choices": [{"message": {"content": "{'time': 'America/New_York'}"}}]}
+            )
+            test_schema = {
+                "name": "get_time",
+                "description": 'Finds the current time in a specific timezone.\n\n:param timezone: The timezone to find the current time in, should\n    be a valid timezone from the IANA Time Zone Database like\n    "America/New_York" or "Europe/London". Do NOT put the place\n    name itself like "rome", or "new york", you must provide\n    the IANA format.\n:type timezone: str\n:return: The current time in the specified timezone.',
+                "signature": "(timezone: str) -> str",
+                "output": "<class 'str'>",
+            }
+            test_query = "What time is it in America/New_York?"
+
+            llamacpp_llm.extract_function_inputs(
+                query=test_query, function_schema=test_schema
+            )
\ No newline at end of file
-- 
GitLab