From 62c961967b77d3593c99215ab68a70860bdaf8dd Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Sat, 13 Jan 2024 16:20:51 +0000
Subject: [PATCH] fix for tests

---
 semantic_router/llms/llamacpp.py     | 8 ++++----
 tests/unit/llms/test_llm_llamacpp.py | 5 ++++-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 73122878..c0bc0a2a 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -10,9 +10,9 @@ from semantic_router.utils.logger import logger
 
 
 class LlamaCppLLM(BaseLLM):
-    llm: Optional[Llama] = None
-    temperature: Optional[float] = None
-    max_tokens: Optional[int] = None
+    llm: Llama
+    temperature: float
+    max_tokens: int
     grammar: Optional[LlamaGrammar] = None
 
     def __init__(
@@ -24,7 +24,7 @@ class LlamaCppLLM(BaseLLM):
     ):
         if not llm:
             raise ValueError("`llama_cpp.Llama` llm is required")
-        super().__init__(name=name)
+        super().__init__(name=name, llm=llm, temperature=temperature, max_tokens=max_tokens)
         self.llm = llm
         self.temperature = temperature
         self.max_tokens = max_tokens
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index 223fdaf1..db47b154 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -3,10 +3,13 @@ import pytest
 from semantic_router.llms import LlamaCppLLM
 from semantic_router.schema import Message
 
+from llama_cpp import Llama
+
 
 @pytest.fixture
 def llamacpp_llm(mocker):
-    llm = mocker.Mock()
+    mock_llama = mocker.patch("llama_cpp.Llama", spec=Llama)
+    llm = mock_llama.return_value
     return LlamaCppLLM(llm=llm)
--
GitLab
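
The patch turns `llm`, `temperature`, and `max_tokens` into required, typed fields on `LlamaCppLLM`, forwards them to `super().__init__()`, and changes the test fixture to mock `llama_cpp.Llama` with `spec=Llama` instead of a bare `mocker.Mock()`. A minimal sketch of the constraint this satisfies, assuming `BaseLLM` is a pydantic model that validates field types (with arbitrary types allowed) and that `LlamaCppLLM.__init__` supplies defaults for `name`, `temperature`, and `max_tokens`:

    # Sketch only: shows why the mock must carry spec=Llama once the `llm`
    # field is annotated as `Llama` on a pydantic model. Defaults for
    # temperature/max_tokens are assumed to come from LlamaCppLLM.__init__.
    from unittest.mock import MagicMock

    from llama_cpp import Llama

    from semantic_router.llms import LlamaCppLLM

    llm = MagicMock(spec=Llama)        # isinstance(llm, Llama) is True, so the
                                       # `llm: Llama` field validates
    router_llm = LlamaCppLLM(llm=llm)  # temperature / max_tokens fall back to
                                       # the constructor defaults
    print(router_llm.temperature, router_llm.max_tokens)

A plain `mocker.Mock()` carries no spec, so it would not pass an isinstance check against `Llama`, which appears to be why the fixture was updated alongside the field changes.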