Skip to content
Snippets Groups Projects
Commit 61d22a75 authored by zahid-syed's avatar zahid-syed
Browse files

attempt to fix pytest coverage

parent fcf70777
No related branches found
No related tags found
No related merge requests found
...@@ -26,9 +26,6 @@ class MistralEncoder(BaseEncoder): ...@@ -26,9 +26,6 @@ class MistralEncoder(BaseEncoder):
if name is None: if name is None:
name = EncoderDefault.MISTRAL.value["embedding_model"] name = EncoderDefault.MISTRAL.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold) super().__init__(name=name, score_threshold=score_threshold)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
( (
self._client, self._client,
self._embedding_response, self._embedding_response,
...@@ -47,6 +44,9 @@ class MistralEncoder(BaseEncoder): ...@@ -47,6 +44,9 @@ class MistralEncoder(BaseEncoder):
"`pip install 'semantic-router[mistralai]'`" "`pip install 'semantic-router[mistralai]'`"
) )
api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
try: try:
client = MistralClient(api_key=api_key) client = MistralClient(api_key=api_key)
embedding_response = EmbeddingResponse embedding_response = EmbeddingResponse
......
...@@ -25,10 +25,7 @@ class MistralAILLM(BaseLLM): ...@@ -25,10 +25,7 @@ class MistralAILLM(BaseLLM):
if name is None: if name is None:
name = EncoderDefault.MISTRAL.value["language_model"] name = EncoderDefault.MISTRAL.value["language_model"]
super().__init__(name=name) super().__init__(name=name)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") self._client = self._initialize_client(mistralai_api_key)
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
self._initialize_client(api_key)
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
...@@ -37,16 +34,20 @@ class MistralAILLM(BaseLLM): ...@@ -37,16 +34,20 @@ class MistralAILLM(BaseLLM):
from mistralai.client import MistralClient from mistralai.client import MistralClient
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Please install MistralAI to use MistralEncoder. " "Please install MistralAI to use MistralAI LLM. "
"You can install it with: " "You can install it with: "
"`pip install 'semantic-router[mistralai]'`" "`pip install 'semantic-router[mistralai]'`"
) )
api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
try: try:
self._client = MistralClient(api_key=api_key) client = MistralClient(api_key=api_key)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}" f"MistralAI API client failed to initialize. Error: {e}"
) from e ) from e
return client
def __call__(self, messages: List[Message]) -> str: def __call__(self, messages: List[Message]) -> str:
if self._client is None: if self._client is None:
......
...@@ -4,6 +4,8 @@ from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, Usag ...@@ -4,6 +4,8 @@ from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, Usag
from semantic_router.encoders import MistralEncoder from semantic_router.encoders import MistralEncoder
from unittest.mock import patch
@pytest.fixture @pytest.fixture
def mistralai_encoder(mocker): def mistralai_encoder(mocker):
...@@ -12,6 +14,17 @@ def mistralai_encoder(mocker): ...@@ -12,6 +14,17 @@ def mistralai_encoder(mocker):
class TestMistralEncoder: class TestMistralEncoder:
def test_mistral_encoder_import_errors(self):
with patch.dict("sys.modules", {"mistralai": None}):
with pytest.raises(ImportError) as error:
MistralEncoder()
assert (
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`" in str(error.value)
)
def test_mistralai_encoder_init_success(self, mocker): def test_mistralai_encoder_init_success(self, mocker):
encoder = MistralEncoder(mistralai_api_key="test_api_key") encoder = MistralEncoder(mistralai_api_key="test_api_key")
assert encoder._client is not None assert encoder._client is not None
......
...@@ -4,6 +4,8 @@ from llama_cpp import Llama ...@@ -4,6 +4,8 @@ from llama_cpp import Llama
from semantic_router.llms.llamacpp import LlamaCppLLM from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.schema import Message from semantic_router.schema import Message
from unittest.mock import patch
@pytest.fixture @pytest.fixture
def llamacpp_llm(mocker): def llamacpp_llm(mocker):
...@@ -13,6 +15,18 @@ def llamacpp_llm(mocker): ...@@ -13,6 +15,18 @@ def llamacpp_llm(mocker):
class TestLlamaCppLLM: class TestLlamaCppLLM:
def test_llama_cpp_import_errors(self, llamacpp_llm):
with patch.dict("sys.modules", {"llama_cpp": None}):
with pytest.raises(ImportError) as error:
LlamaCppLLM(llamacpp_llm.llm)
assert (
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[llama-cpp-python]'`" in str(error.value)
)
def test_llamacpp_llm_init_success(self, llamacpp_llm): def test_llamacpp_llm_init_success(self, llamacpp_llm):
assert llamacpp_llm.name == "llama.cpp" assert llamacpp_llm.name == "llama.cpp"
assert llamacpp_llm.temperature == 0.2 assert llamacpp_llm.temperature == 0.2
......
...@@ -3,6 +3,8 @@ import pytest ...@@ -3,6 +3,8 @@ import pytest
from semantic_router.llms import MistralAILLM from semantic_router.llms import MistralAILLM
from semantic_router.schema import Message from semantic_router.schema import Message
from unittest.mock import patch
@pytest.fixture @pytest.fixture
def mistralai_llm(mocker): def mistralai_llm(mocker):
...@@ -11,6 +13,17 @@ def mistralai_llm(mocker): ...@@ -11,6 +13,17 @@ def mistralai_llm(mocker):
class TestMistralAILLM: class TestMistralAILLM:
def test_mistral_llm_import_errors(self):
with patch.dict("sys.modules", {"mistralai": None}):
with pytest.raises(ImportError) as error:
MistralAILLM(mistralai_api_key="random")
assert (
"Please install MistralAI to use MistralAI LLM. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`" in str(error.value)
)
def test_mistralai_llm_init_with_api_key(self, mistralai_llm): def test_mistralai_llm_init_with_api_key(self, mistralai_llm):
assert mistralai_llm._client is not None, "Client should be initialized" assert mistralai_llm._client is not None, "Client should be initialized"
assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly" assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment