From 61d22a75a3108f9ca3b13a94d4ae6e182d5d46b8 Mon Sep 17 00:00:00 2001
From: zahid-syed <zahid.s2618@gmail.com>
Date: Tue, 12 Mar 2024 16:02:15 -0400
Subject: [PATCH] attempt to fix pytest coverage

---
 semantic_router/encoders/mistral.py  |  6 +++---
 semantic_router/llms/mistral.py      | 13 +++++++------
 tests/unit/encoders/test_mistral.py  | 13 +++++++++++++
 tests/unit/llms/test_llm_llamacpp.py | 14 ++++++++++++++
 tests/unit/llms/test_llm_mistral.py  | 13 +++++++++++++
 5 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py
index 0d6bd2e7..6acadbb3 100644
--- a/semantic_router/encoders/mistral.py
+++ b/semantic_router/encoders/mistral.py
@@ -26,9 +26,6 @@ class MistralEncoder(BaseEncoder):
         if name is None:
             name = EncoderDefault.MISTRAL.value["embedding_model"]
         super().__init__(name=name, score_threshold=score_threshold)
-        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
-        if api_key is None:
-            raise ValueError("Mistral API key not provided")
         (
             self._client,
             self._embedding_response,
@@ -47,6 +44,9 @@ class MistralEncoder(BaseEncoder):
                 "`pip install 'semantic-router[mistralai]'`"
             )
 
+        api_key = api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("Mistral API key not provided")
         try:
             client = MistralClient(api_key=api_key)
             embedding_response = EmbeddingResponse
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
index adecd22c..afaa5aa2 100644
--- a/semantic_router/llms/mistral.py
+++ b/semantic_router/llms/mistral.py
@@ -25,10 +25,7 @@ class MistralAILLM(BaseLLM):
         if name is None:
             name = EncoderDefault.MISTRAL.value["language_model"]
         super().__init__(name=name)
-        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
-        if api_key is None:
-            raise ValueError("MistralAI API key cannot be 'None'.")
-        self._initialize_client(api_key)
+        self._client = self._initialize_client(mistralai_api_key)
         self.temperature = temperature
         self.max_tokens = max_tokens
 
@@ -37,16 +34,20 @@ class MistralAILLM(BaseLLM):
             from mistralai.client import MistralClient
         except ImportError:
             raise ImportError(
-                "Please install MistralAI to use MistralEncoder. "
+                "Please install MistralAI to use MistralAI LLM. "
                 "You can install it with: "
                 "`pip install 'semantic-router[mistralai]'`"
             )
+        api_key = api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("MistralAI API key cannot be 'None'.")
         try:
-            self._client = MistralClient(api_key=api_key)
+            client = MistralClient(api_key=api_key)
         except Exception as e:
             raise ValueError(
                 f"MistralAI API client failed to initialize. Error: {e}"
             ) from e
+        return client
 
     def __call__(self, messages: List[Message]) -> str:
         if self._client is None:
diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py
index cd806ddf..f8b52338 100644
--- a/tests/unit/encoders/test_mistral.py
+++ b/tests/unit/encoders/test_mistral.py
@@ -4,6 +4,8 @@ from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, Usag
 
 from semantic_router.encoders import MistralEncoder
 
+from unittest.mock import patch
+
 
 @pytest.fixture
 def mistralai_encoder(mocker):
@@ -12,6 +14,17 @@ def mistralai_encoder(mocker):
 
 
 class TestMistralEncoder:
+    def test_mistral_encoder_import_errors(self):
+        with patch.dict("sys.modules", {"mistralai": None}):
+            with pytest.raises(ImportError) as error:
+                MistralEncoder()
+
+        assert (
+            "Please install MistralAI to use MistralEncoder. "
+            "You can install it with: "
+            "`pip install 'semantic-router[mistralai]'`" in str(error.value)
+        )
+
     def test_mistralai_encoder_init_success(self, mocker):
         encoder = MistralEncoder(mistralai_api_key="test_api_key")
         assert encoder._client is not None
diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py
index f0a5253f..7042d386 100644
--- a/tests/unit/llms/test_llm_llamacpp.py
+++ b/tests/unit/llms/test_llm_llamacpp.py
@@ -4,6 +4,8 @@ from llama_cpp import Llama
 from semantic_router.llms.llamacpp import LlamaCppLLM
 from semantic_router.schema import Message
 
+from unittest.mock import patch
+
 
 @pytest.fixture
 def llamacpp_llm(mocker):
@@ -13,6 +15,18 @@ def llamacpp_llm(mocker):
 
 
 class TestLlamaCppLLM:
+
+    def test_llama_cpp_import_errors(self, llamacpp_llm):
+        with patch.dict("sys.modules", {"llama_cpp": None}):
+            with pytest.raises(ImportError) as error:
+                LlamaCppLLM(llamacpp_llm.llm)
+
+        assert (
+            "Please install LlamaCPP to use Llama CPP llm. "
+            "You can install it with: "
+            "`pip install 'semantic-router[llama-cpp-python]'`" in str(error.value)
+        )
+
     def test_llamacpp_llm_init_success(self, llamacpp_llm):
         assert llamacpp_llm.name == "llama.cpp"
         assert llamacpp_llm.temperature == 0.2
diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py
index e011a71c..cabf6521 100644
--- a/tests/unit/llms/test_llm_mistral.py
+++ b/tests/unit/llms/test_llm_mistral.py
@@ -3,6 +3,8 @@ import pytest
 from semantic_router.llms import MistralAILLM
 from semantic_router.schema import Message
 
+from unittest.mock import patch
+
 
 @pytest.fixture
 def mistralai_llm(mocker):
@@ -11,6 +13,17 @@ def mistralai_llm(mocker):
 
 
 class TestMistralAILLM:
+    def test_mistral_llm_import_errors(self):
+        with patch.dict("sys.modules", {"mistralai": None}):
+            with pytest.raises(ImportError) as error:
+                MistralAILLM(mistralai_api_key="random")
+
+        assert (
+            "Please install MistralAI to use MistralAI LLM. "
+            "You can install it with: "
+            "`pip install 'semantic-router[mistralai]'`" in str(error.value)
+        )
+
     def test_mistralai_llm_init_with_api_key(self, mistralai_llm):
         assert mistralai_llm._client is not None, "Client should be initialized"
         assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly"
-- 
GitLab