diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py
index 0d6bd2e7f4413df31868fe7617d817fcd32adad4..6acadbb345eadfbd422d135009a29db1b3c06e13 100644
--- a/semantic_router/encoders/mistral.py
+++ b/semantic_router/encoders/mistral.py
@@ -26,9 +26,6 @@ class MistralEncoder(BaseEncoder):
         if name is None:
             name = EncoderDefault.MISTRAL.value["embedding_model"]
         super().__init__(name=name, score_threshold=score_threshold)
-        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
-        if api_key is None:
-            raise ValueError("Mistral API key not provided")
         (
             self._client,
             self._embedding_response,
@@ -47,6 +44,9 @@ class MistralEncoder(BaseEncoder):
                 "`pip install 'semantic-router[mistralai]'`"
             )
 
+        api_key = api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("Mistral API key not provided")
         try:
             client = MistralClient(api_key=api_key)
             embedding_response = EmbeddingResponse
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
index adecd22c86cccebaade5aaba9815524ea6d7846f..afaa5aa2f800cfe6ca3956c04be490793da7d032 100644
--- a/semantic_router/llms/mistral.py
+++ b/semantic_router/llms/mistral.py
@@ -25,10 +25,7 @@ class MistralAILLM(BaseLLM):
         if name is None:
             name = EncoderDefault.MISTRAL.value["language_model"]
         super().__init__(name=name)
-        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
-        if api_key is None:
-            raise ValueError("MistralAI API key cannot be 'None'.")
-        self._initialize_client(api_key)
+        self._client = self._initialize_client(mistralai_api_key)
         self.temperature = temperature
         self.max_tokens = max_tokens
 
@@ -37,16 +34,20 @@
             from mistralai.client import MistralClient
         except ImportError:
             raise ImportError(
-                "Please install MistralAI to use MistralEncoder. "
+                "Please install MistralAI to use MistralAI LLM. "
                 "You can install it with: "
                 "`pip install 'semantic-router[mistralai]'`"
             )
+        api_key = api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("MistralAI API key cannot be 'None'.")
         try:
-            self._client = MistralClient(api_key=api_key)
+            client = MistralClient(api_key=api_key)
         except Exception as e:
             raise ValueError(
                 f"MistralAI API client failed to initialize. Error: {e}"
             ) from e
+        return client
 
     def __call__(self, messages: List[Message]) -> str:
         if self._client is None:
diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py
index cd806ddf1e15805f19ab859236e67e26416aa196..f8b523381a06101c922855b313aace309188f8cd 100644
--- a/tests/unit/encoders/test_mistral.py
+++ b/tests/unit/encoders/test_mistral.py
@@ -4,6 +4,8 @@ from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, Usag
 
 from semantic_router.encoders import MistralEncoder
 
+from unittest.mock import patch
+
 
 @pytest.fixture
 def mistralai_encoder(mocker):
@@ -12,6 +14,17 @@
 
 
 class TestMistralEncoder:
+    def test_mistral_encoder_import_errors(self):
+        with patch.dict("sys.modules", {"mistralai": None}):
+            with pytest.raises(ImportError) as error:
+                MistralEncoder()
+
+            assert (
+                "Please install MistralAI to use MistralEncoder. "
+                "You can install it with: "
+                "`pip install 'semantic-router[mistralai]'`" in str(error.value)
+            )
+
     def test_mistralai_encoder_init_success(self, mocker):
         encoder = MistralEncoder(mistralai_api_key="test_api_key")
         assert encoder._client is not None
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index f0a5253f909ecce92769b50ccf7b6578720c3f63..7042d386eb6cc5c644b07d0b5236956357dd3987 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -4,6 +4,8 @@ from llama_cpp import Llama
 from semantic_router.llms.llamacpp import LlamaCppLLM
 from semantic_router.schema import Message
 
+from unittest.mock import patch
+
 
 @pytest.fixture
 def llamacpp_llm(mocker):
@@ -13,6 +15,18 @@
 
 
 class TestLlamaCppLLM:
+
+    def test_llama_cpp_import_errors(self, llamacpp_llm):
+        with patch.dict("sys.modules", {"llama_cpp": None}):
+            with pytest.raises(ImportError) as error:
+                LlamaCppLLM(llamacpp_llm.llm)
+
+            assert (
+                "Please install LlamaCPP to use Llama CPP llm. "
+                "You can install it with: "
+                "`pip install 'semantic-router[llama-cpp-python]'`" in str(error.value)
+            )
+
     def test_llamacpp_llm_init_success(self, llamacpp_llm):
         assert llamacpp_llm.name == "llama.cpp"
         assert llamacpp_llm.temperature == 0.2
diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py
index e011a71cfa7a77fbf782a518ba23ce2923bc23be..cabf6521dd304c4b7005badeb68d3454806c66d1 100644
--- a/tests/unit/llms/test_llm_mistral.py
+++ b/tests/unit/llms/test_llm_mistral.py
@@ -3,6 +3,8 @@ import pytest
 from semantic_router.llms import MistralAILLM
 from semantic_router.schema import Message
 
+from unittest.mock import patch
+
 
 @pytest.fixture
 def mistralai_llm(mocker):
@@ -11,6 +13,17 @@
 
 
 class TestMistralAILLM:
+    def test_mistral_llm_import_errors(self):
+        with patch.dict("sys.modules", {"mistralai": None}):
+            with pytest.raises(ImportError) as error:
+                MistralAILLM(mistralai_api_key="random")
+
+            assert (
+                "Please install MistralAI to use MistralAI LLM. "
+                "You can install it with: "
+                "`pip install 'semantic-router[mistralai]'`" in str(error.value)
+            )
+
     def test_mistralai_llm_init_with_api_key(self, mistralai_llm):
         assert mistralai_llm._client is not None, "Client should be initialized"
         assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly"