Skip to content
Snippets Groups Projects
  • dwmorris11's avatar
    89faa38e
    Added support for the MistralAI API. This includes a new encoder and a new
    LLM. The encoder is a simple wrapper around the MistralAI API, and the LLM
    is a simple wrapper around the encoder. The encoder and the LLM are each
    covered by a simple unit test.
    dwmorris11 authored · 89faa38e
    89faa38e
    History
    Added support for the MistralAI API. This includes a new encoder and a new
    LLM. The encoder is a simple wrapper around the MistralAI API, and the LLM
    is a simple wrapper around the encoder. The encoder and the LLM are each
    covered by a simple unit test.
    dwmorris11 authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_mistral.py 4.58 KiB
import pytest
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse, EmbeddingObject, UsageInfo
from semantic_router.encoders import MistralEncoder


@pytest.fixture
def mistralai_encoder(mocker):
    """Return a ``MistralEncoder`` built with a dummy API key and a mocked client.

    ``mocker.patch`` requires a fully qualified ``"package.module.attr"``
    target; the bare name ``"MistralClient"`` is rejected with ``TypeError``
    before the fixture can return.
    """
    # NOTE(review): this only intercepts construction if the encoder resolves
    # ``mistralai.client.MistralClient`` at call time; if the encoder module
    # binds the name at import time, patch that module's attribute instead.
    mocker.patch("mistralai.client.MistralClient")
    return MistralEncoder(mistralai_api_key="test_api_key")


class TestMistralEncoder:
    """Unit tests for ``MistralEncoder`` construction and embedding calls.

    No network traffic is made: construction patches
    ``mistralai.client.MistralClient`` and embedding calls patch the
    client's ``embeddings`` attribute.

    NOTE(review): assumes the encoder requests vectors via
    ``self.client.embeddings(model=..., input=...)`` (the mistralai client
    API) rather than an OpenAI-style ``embeddings.create`` — confirm against
    the encoder implementation.
    """

    @staticmethod
    def _embedding_response():
        """Build a one-document ``EmbeddingResponse`` whose vector is ``[0.1, 0.2]``."""
        mock_embedding = EmbeddingObject(
            index=0, object="embedding", embedding=[0.1, 0.2]
        )
        return EmbeddingResponse(
            model="mistral-embed",
            object="list",
            usage=UsageInfo(prompt_tokens=0, total_tokens=20),
            data=[mock_embedding],
        )

    def test_mistralai_encoder_init_success(self, mocker):
        # Pass the key explicitly: without one, construction raises ValueError
        # (see test_mistralai_encoder_init_no_api_key), which would make this
        # test depend on the developer's environment.
        mocker.patch("mistralai.client.MistralClient")
        encoder = MistralEncoder(mistralai_api_key="test_api_key")
        assert encoder.client is not None

    def test_mistralai_encoder_init_no_api_key(self, mocker):
        mocker.patch("os.getenv", return_value=None)
        with pytest.raises(ValueError):
            MistralEncoder()

    def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder):
        # Set the client to None to simulate an uninitialized client
        mistralai_encoder.client = None
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        assert "MistralAI client is not initialized." in str(e.value)

    def test_mistralai_encoder_init_exception(self, mocker):
        mocker.patch("os.getenv", return_value="fake-api-key")
        # A fully qualified target is required; a bare "MistralClient" raises
        # TypeError inside mocker.patch instead of exercising the encoder.
        mocker.patch(
            "mistralai.client.MistralClient",
            side_effect=Exception("Initialization error"),
        )
        with pytest.raises(ValueError) as e:
            MistralEncoder()
        assert (
            "mistralai API client failed to initialize. Error: Initialization error"
            in str(e.value)
        )

    def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker):
        # Happy path: the first request succeeds, so no retry should happen.
        mocker.patch("os.getenv", return_value="fake-api-key")
        # NOTE(review): only effective if the encoder calls ``time.sleep``
        # through the ``time`` module rather than a ``from time import sleep``.
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        embeddings_mock = mocker.patch.object(
            mistralai_encoder.client,
            "embeddings",
            return_value=self._embedding_response(),
        )
        embeddings = mistralai_encoder(["test document"])
        assert embeddings == [[0.1, 0.2]]
        embeddings_mock.assert_called_once()

    def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker):
        # Every attempt fails with a MistralException, so retries are exhausted.
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        mocker.patch.object(
            mistralai_encoder.client,
            "embeddings",
            side_effect=MistralException("Test error"),
        )
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        assert "No embeddings returned. Error" in str(e.value)

    def test_mistralai_encoder_call_failure_non_mistralai_error(
        self, mistralai_encoder, mocker
    ):
        # A non-MistralException must surface as ValueError.
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        mocker.patch.object(
            mistralai_encoder.client,
            "embeddings",
            side_effect=Exception("Non-MistralException"),
        )
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])

        assert "mistralai API call failed. Error: Non-MistralException" in str(e.value)

    def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker):
        # First attempt raises MistralException; the retry returns embeddings.
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("time.sleep", return_value=None)  # To speed up the test
        responses = [MistralException("mistralai error"), self._embedding_response()]
        embeddings_mock = mocker.patch.object(
            mistralai_encoder.client, "embeddings", side_effect=responses
        )
        embeddings = mistralai_encoder(["test document"])
        assert embeddings == [[0.1, 0.2]]
        # One failed attempt followed by exactly one successful retry.
        assert embeddings_mock.call_count == 2