-
dwmorris11 authored
new encoder and a new LLMS. The encoder is a simple wrapper around the MistralAI API, and the LLMS is a simple wrapper around the encoder. The encoder is tested with a simple unit test, and the LLMS is tested with a simple unit test.
dwmorris11 authored new encoder and a new LLMS. The encoder is a simple wrapper around the MistralAI API, and the LLMS is a simple wrapper around the encoder. The encoder is tested with a simple unit test, and the LLMS is tested with a simple unit test.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_mistral.py 4.58 KiB
import pytest
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse, EmbeddingObject, UsageInfo
from semantic_router.encoders import MistralEncoder
@pytest.fixture
def mistralai_encoder(mocker):
    """Return a ``MistralEncoder`` built with a dummy API key.

    ``MistralClient`` is replaced by a mock for the duration of the test.
    """
    # mock.patch needs a dotted "package.module.attr" target; a bare
    # "MistralClient" raises TypeError before the encoder is even constructed.
    # NOTE(review): if the encoder module binds the name at import time
    # (``from mistralai.client import MistralClient``), patch that module's
    # attribute instead -- confirm against the encoder source.
    mocker.patch("mistralai.client.MistralClient")
    return MistralEncoder(mistralai_api_key="test_api_key")
class TestMistralEncoder:
    """Unit tests for ``MistralEncoder`` construction and embedding calls."""

    @staticmethod
    def _embedding_response():
        """Build a one-item ``EmbeddingResponse`` with embedding ``[0.1, 0.2]``."""
        embedding = EmbeddingObject(
            index=0, object="embedding", embedding=[0.1, 0.2]
        )
        return EmbeddingResponse(
            model="mistral-embed",
            object="list",
            usage=UsageInfo(prompt_tokens=0, total_tokens=20),
            data=[embedding],
        )

    def test_mistralai_encoder_init_success(self, mocker):
        """An encoder built with an explicit API key exposes a client."""
        # Pass the key explicitly so the test does not depend on whatever
        # happens to be set in the host environment.
        encoder = MistralEncoder(mistralai_api_key="test_api_key")
        assert encoder.client is not None

    def test_mistralai_encoder_init_no_api_key(self, mocker):
        """Construction raises ``ValueError`` when no API key can be found."""
        mocker.patch("os.getenv", return_value=None)
        with pytest.raises(ValueError):
            MistralEncoder()

    def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder):
        """Calling the encoder with ``client`` set to ``None`` raises."""
        # Set the client to None to simulate an uninitialized client
        mistralai_encoder.client = None
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        # Checked after the block: code following the raising call inside
        # ``pytest.raises`` would never execute.
        assert "MistralAI client is not initialized." in str(e.value)

    def test_mistralai_encoder_init_exception(self, mocker):
        """A failure while creating the client surfaces as ``ValueError``."""
        mocker.patch("os.getenv", return_value="fake-api-key")
        # A dotted import path is required; a bare "MistralClient" makes
        # mock.patch raise TypeError instead of installing the side effect.
        # NOTE(review): confirm this is where the encoder looks the name up.
        mocker.patch(
            "mistralai.client.MistralClient",
            side_effect=Exception("Initialization error"),
        )
        with pytest.raises(ValueError) as e:
            MistralEncoder()
        assert (
            "mistralai API client failed to initialize. Error: Initialization error"
            in str(e.value)
        )

    def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker):
        """A first-try successful API call returns the embedding vectors."""
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        # No error is injected here: the retry path has its own test below.
        mocker.patch.object(
            mistralai_encoder.client.embeddings,
            "create",
            return_value=self._embedding_response(),
        )
        embeddings = mistralai_encoder(["test document"])
        assert embeddings == [[0.1, 0.2]]

    def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker):
        """If every call raises ``MistralException`` a ``ValueError`` results."""
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        mocker.patch.object(
            mistralai_encoder.client.embeddings,
            "create",
            side_effect=MistralException("Test error"),
        )
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        assert "No embeddings returned. Error" in str(e.value)

    def test_mistralai_encoder_call_failure_non_mistralai_error(
        self, mistralai_encoder, mocker
    ):
        """A non-Mistral exception from the API call becomes ``ValueError``."""
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        mocker.patch.object(
            mistralai_encoder.client.embeddings,
            "create",
            side_effect=Exception("Non-MistralException"),
        )
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        assert (
            "mistralai API call failed. Error: Non-MistralException"
            in str(e.value)
        )

    def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker):
        """One ``MistralException`` followed by a good response still succeeds."""
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        # First call fails, second call returns the canned response.
        responses = [MistralException("mistralai error"), self._embedding_response()]
        mocker.patch.object(
            mistralai_encoder.client.embeddings, "create", side_effect=responses
        )
        embeddings = mistralai_encoder(["test document"])
        assert embeddings == [[0.1, 0.2]]