From 62c961967b77d3593c99215ab68a70860bdaf8dd Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sat, 13 Jan 2024 16:20:51 +0000
Subject: [PATCH] fix for tests

---
 semantic_router/llms/llamacpp.py     | 8 ++++----
 tests/unit/llms/test_llm_llamacpp.py | 5 ++++-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 73122878..c0bc0a2a 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -10,9 +10,9 @@ from semantic_router.utils.logger import logger
 
 
 class LlamaCppLLM(BaseLLM):
-    llm: Optional[Llama] = None
-    temperature: Optional[float] = None
-    max_tokens: Optional[int] = None
+    llm: Llama
+    temperature: float
+    max_tokens: int
     grammar: Optional[LlamaGrammar] = None
 
     def __init__(
@@ -24,7 +24,7 @@ class LlamaCppLLM(BaseLLM):
     ):
         if not llm:
             raise ValueError("`llama_cpp.Llama` llm is required")
-        super().__init__(name=name)
+        super().__init__(name=name, llm=llm, temperature=temperature, max_tokens=max_tokens)
         self.llm = llm
         self.temperature = temperature
         self.max_tokens = max_tokens
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index 223fdaf1..db47b154 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -3,10 +3,13 @@ import pytest
 from semantic_router.llms import LlamaCppLLM
 from semantic_router.schema import Message
 
+from llama_cpp import Llama
+
 
 @pytest.fixture
 def llamacpp_llm(mocker):
-    llm = mocker.Mock()
+    mock_llama = mocker.patch("llama_cpp.Llama", spec=Llama)
+    llm = mock_llama.return_value
     return LlamaCppLLM(llm=llm)
 
 
-- 
GitLab