diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py
index 73122878c8c5e3862a0fcf7451273c9cfa084a29..c0bc0a2a7c662bc4395d6a1b6d5c2f4c95068c2f 100644
--- a/semantic_router/llms/llamacpp.py
+++ b/semantic_router/llms/llamacpp.py
@@ -10,9 +10,9 @@ from semantic_router.utils.logger import logger
 
 
 class LlamaCppLLM(BaseLLM):
-    llm: Optional[Llama] = None
-    temperature: Optional[float] = None
-    max_tokens: Optional[int] = None
+    llm: Llama
+    temperature: float
+    max_tokens: int
     grammar: Optional[LlamaGrammar] = None
 
     def __init__(
@@ -24,7 +24,7 @@ class LlamaCppLLM(BaseLLM):
     ):
         if not llm:
             raise ValueError("`llama_cpp.Llama` llm is required")
-        super().__init__(name=name)
+        super().__init__(name=name, llm=llm, temperature=temperature, max_tokens=max_tokens)
         self.llm = llm
         self.temperature = temperature
         self.max_tokens = max_tokens
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index 223fdaf1eeefed21015f1c3dd94adc1fa5e1ec71..db47b1540c1bf3046ad7359f130befa6440911e0 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -3,10 +3,13 @@ import pytest
 from semantic_router.llms import LlamaCppLLM
 from semantic_router.schema import Message
 
+from llama_cpp import Llama
+
 
 @pytest.fixture
 def llamacpp_llm(mocker):
-    llm = mocker.Mock()
+    mock_llama = mocker.patch("llama_cpp.Llama", spec=Llama)
+    llm = mock_llama.return_value
     return LlamaCppLLM(llm=llm)
 
 