From 89faa38ebc237c5efe7e357c3333b09cba6f7fd8 Mon Sep 17 00:00:00 2001
From: dwmorris11 <dustin.morris@outlook.com>
Date: Sat, 10 Feb 2024 16:49:11 -0800
Subject: [PATCH] Added support for the MistralAI API.

This includes a new encoder and a new LLM. The encoder is a thin wrapper
around the MistralAI embeddings endpoint, and the LLM is a thin wrapper
around the MistralAI chat endpoint. Both are covered by unit tests.
---
 semantic_router/encoders/__init__.py |   2 +
 semantic_router/encoders/mistral.py  |  55 +++++++++++++
 semantic_router/llms/__init__.py     |   3 +-
 semantic_router/llms/mistral.py      |  56 +++++++++++++
 semantic_router/schema.py            |  10 ++-
 tests/unit/encoders/test_mistral.py  | 113 +++++++++++++++++++++++++++
 tests/unit/llms/test_llm_mistral.py  |  56 +++++++++++++
 7 files changed, 292 insertions(+), 3 deletions(-)
 create mode 100644 semantic_router/encoders/mistral.py
 create mode 100644 semantic_router/llms/mistral.py
 create mode 100644 tests/unit/encoders/test_mistral.py
 create mode 100644 tests/unit/llms/test_llm_mistral.py

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index b4668154..a43d0cd6 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -3,6 +3,7 @@ from semantic_router.encoders.bm25 import BM25Encoder
 from semantic_router.encoders.cohere import CohereEncoder
 from semantic_router.encoders.fastembed import FastEmbedEncoder
 from semantic_router.encoders.huggingface import HuggingFaceEncoder
+from semantic_router.encoders.mistral import MistralEncoder
 from semantic_router.encoders.openai import OpenAIEncoder
 from semantic_router.encoders.tfidf import TfidfEncoder
 from semantic_router.encoders.zure import AzureOpenAIEncoder
@@ -16,4 +17,5 @@ __all__ = [
     "TfidfEncoder",
     "FastEmbedEncoder",
     "HuggingFaceEncoder",
+    "MistralEncoder",
 ]
diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py
new file mode 100644
index 00000000..ca28badc
--- /dev/null
+++ b/semantic_router/encoders/mistral.py
@@ -0,0 +1,55 @@
+"""This module contains the MistralEncoder class, which encodes text using the MistralAI embeddings API."""
+import os
+from time import sleep
+from typing import List, Optional
+
+from mistralai.client import MistralClient
+from mistralai.exceptions import MistralException
+from mistralai.models.embeddings import EmbeddingResponse
+
+from semantic_router.encoders import BaseEncoder
+
+
+class MistralEncoder(BaseEncoder):
+    """Class to encode text using MistralAI."""
+
+    client: Optional[MistralClient]
+    type: str = "mistral"
+
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        mistral_api_key: Optional[str] = None,
+        score_threshold: Optional[float] = 0.82,
+    ):
+        if name is None:
+            name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed")
+        super().__init__(name=name, score_threshold=score_threshold)
+        api_key = mistral_api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("Mistral API key not provided")
+        try:
+            self.client = MistralClient(api_key=api_key)
+        except Exception as e:
+            raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
+
+    def __call__(self, docs: List[str]) -> List[List[float]]:
+        if self.client is None:
+            raise ValueError("Mistral client not initialized")
+        embeds = None
+        error_message = ""
+
+        # Exponential backoff: retry up to three times, doubling the wait each time.
+        for attempt in range(3):
+            try:
+                embeds = self.client.embeddings(model=self.name, input=docs)
+                if embeds.data:
+                    break
+            except MistralException as e:
+                sleep(2**attempt)
+                error_message = str(e)
+            except Exception as e:
+                raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
+
+        if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
+            raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
\ No newline at end of file
diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py
index b216f0b0..94631ad6 100644
--- a/semantic_router/llms/__init__.py
+++ b/semantic_router/llms/__init__.py
@@ -3,5 +3,6 @@ from semantic_router.llms.cohere import CohereLLM
 from semantic_router.llms.openai import OpenAILLM
 from semantic_router.llms.openrouter import OpenRouterLLM
 from semantic_router.llms.zure import AzureOpenAILLM
+from semantic_router.llms.mistral import MistralAILLM
 
-__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM"]
+__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM", "MistralAILLM"]
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
new file mode 100644
index 00000000..0f74bcbd
--- /dev/null
+++ b/semantic_router/llms/mistral.py
@@ -0,0 +1,56 @@
+import os
+from typing import List, Optional
+
+from mistralai.client import MistralClient
+
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+from semantic_router.utils.logger import logger
+
+
+class MistralAILLM(BaseLLM):
+    client: Optional[MistralClient]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        mistralai_api_key: Optional[str] = None,
+        temperature: float = 0.01,
+        max_tokens: int = 200,
+    ):
+        if name is None:
+            name = os.getenv("MISTRALAI_CHAT_MODEL_NAME", "mistral-tiny")
+        super().__init__(name=name)
+        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("MistralAI API key cannot be 'None'.")
+        try:
+            self.client = MistralClient(api_key=api_key)
+        except Exception as e:
+            raise ValueError(
+                f"MistralAI API client failed to initialize. Error: {e}"
Error: {e}" + ) from e + self.temperature = temperature + self.max_tokens = max_tokens + + def __call__(self, messages: List[Message]) -> str: + if self.client is None: + raise ValueError("MistralAI client is not initialized.") + try: + completion = self.client.chat( + model=self.name, + messages=[m.to_mistral() for m in messages], + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + output = completion.choices[0].message.content + + if not output: + raise Exception("No output generated") + return output + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") from e diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 46ee7f59..b8df1047 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -9,14 +9,15 @@ from semantic_router.encoders import ( CohereEncoder, FastEmbedEncoder, OpenAIEncoder, + MistralEncoder, ) - class EncoderType(Enum): HUGGINGFACE = "huggingface" FASTEMBED = "fastembed" OPENAI = "openai" COHERE = "cohere" + MISTRAL = "mistral" class RouteChoice(BaseModel): @@ -43,6 +44,8 @@ class Encoder: self.model = OpenAIEncoder(name=name) elif self.type == EncoderType.COHERE: self.model = CohereEncoder(name=name) + elif self.type == EncoderType.MISTRAL: + self.model = MistralEncoder(name=name) else: raise ValueError @@ -65,6 +68,9 @@ class Message(BaseModel): def to_llamacpp(self): return {"role": self.role, "content": self.content} + def to_mistral(self): + return {"role": self.role, "content": self.content} + def __str__(self): return f"{self.role}: {self.content}" @@ -72,4 +78,4 @@ class Message(BaseModel): class DocumentSplit(BaseModel): docs: List[str] is_triggered: bool = False - triggered_score: Optional[float] = None + triggered_score: Optional[float] = None \ No newline at end of file diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py new file mode 100644 index 00000000..1d118101 --- /dev/null +++ b/tests/unit/encoders/test_mistral.py @@ -0,0 +1,113 @@ +import pytest +from mistralai.exceptions import MistralException +from mistralai.models.embeddings import EmbeddingResponse, EmbeddingObject, UsageInfo +from semantic_router.encoders import MistralEncoder + + +@pytest.fixture +def mistralai_encoder(mocker): + mocker.patch("MistralClient") + return MistralEncoder(mistralai_api_key="test_api_key") + + +class TestMistralEncoder: + def test_mistralai_encoder_init_success(self, mocker): + encoder = MistralEncoder() + assert encoder.client is not None + + def test_mistralai_encoder_init_no_api_key(self, mocker): + mocker.patch("os.getenv", return_value=None) + with pytest.raises(ValueError) as _: + MistralEncoder() + + def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder): + # Set the client to None to simulate an uninitialized client + mistralai_encoder.client = None + with pytest.raises(ValueError) as e: + mistralai_encoder(["test document"]) + assert "MistralAI client is not initialized." in str(e.value) + + def test_mistralai_encoder_init_exception(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("MistralClient", side_effect=Exception("Initialization error")) + with pytest.raises(ValueError) as e: + MistralEncoder() + assert ( + "mistralai API client failed to initialize. 
Error: Initialization error" + in str(e.value) + ) + + def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker): + mock_embeddings = mocker.Mock() + mock_embeddings.data = [ + EmbeddingObject(embedding=[0.1, 0.2], index=0, object="embedding") + ] + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + + mock_embedding = EmbeddingObject(index=0, object="embedding", embedding=[0.1, 0.2]) + # Mock the CreateEmbeddingResponse object + mock_response = EmbeddingResponse( + model="mistral-embed", + object="list", + usage=UsageInfo(prompt_tokens=0, total_tokens=20), + data=[mock_embedding], + ) + + responses = [MistralException("mistralai error"), mock_response] + mocker.patch.object( + mistralai_encoder.client.embeddings, "create", side_effect=responses + ) + embeddings = mistralai_encoder(["test document"]) + assert embeddings == [[0.1, 0.2]] + + def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + mocker.patch.object( + mistralai_encoder.client.embeddings, + "create", + side_effect=MistralException("Test error"), + ) + with pytest.raises(ValueError) as e: + mistralai_encoder(["test document"]) + assert "No embeddings returned. Error" in str(e.value) + + def test_mistralai_encoder_call_failure_non_mistralai_error(self, mistralai_encoder, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + mocker.patch.object( + mistralai_encoder.client.embeddings, + "create", + side_effect=Exception("Non-MistralException"), + ) + with pytest.raises(ValueError) as e: + mistralai_encoder(["test document"]) + + assert "mistralai API call failed. 
Error: Non-MistralException" in str(e.value) + + def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker): + mock_embeddings = mocker.Mock() + mock_embeddings.data = [ + EmbeddingObject(embedding=[0.1, 0.2], index=0, object="embedding") + ] + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + + mock_embedding = EmbeddingObject(index=0, object="embedding", embedding=[0.1, 0.2]) + # Mock the CreateEmbeddingResponse object + mock_response = EmbeddingResponse( + model="mistral-embed", + object="list", + usage=UsageInfo(prompt_tokens=0, total_tokens=20), + data=[mock_embedding], + ) + + responses = [MistralException("mistralai error"), mock_response] + mocker.patch.object( + mistralai_encoder.client.embeddings, "create", side_effect=responses + ) + embeddings = mistralai_encoder(["test document"]) + assert embeddings == [[0.1, 0.2]] diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py new file mode 100644 index 00000000..a7fc1d5f --- /dev/null +++ b/tests/unit/llms/test_llm_mistral.py @@ -0,0 +1,56 @@ +import pytest + +from semantic_router.llms import MistralAILLM +from semantic_router.schema import Message + + +@pytest.fixture +def mistralai_llm(mocker): + mocker.patch("mistralai.Client") + return MistralAILLM(mistralai_api_key="test_api_key") + + +class TestMistralAILLM: + def test_mistralai_llm_init_with_api_key(self, mistralai_llm): + assert mistralai_llm.client is not None, "Client should be initialized" + assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly" + + def test_mistralai_llm_init_success(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + llm = MistralAILLM() + assert llm.client is not None + + def test_mistralai_llm_init_without_api_key(self, mocker): + mocker.patch("os.getenv", return_value=None) + with pytest.raises(ValueError) as _: + MistralAILLM() + + def test_mistralai_llm_call_uninitialized_client(self, mistralai_llm): + # Set the client to None to simulate an uninitialized client + mistralai_llm.client = None + with pytest.raises(ValueError) as e: + llm_input = [Message(role="user", content="test")] + mistralai_llm(llm_input) + assert "mistralai client is not initialized." in str(e.value) + + def test_mistralai_llm_init_exception(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("mistralai.mistralai", side_effect=Exception("Initialization error")) + with pytest.raises(ValueError) as e: + MistralAILLM() + assert ( + "mistralai API client failed to initialize. Error: Initialization error" + in str(e.value) + ) + + def test_mistralai_llm_call_success(self, mistralai_llm, mocker): + mock_completion = mocker.MagicMock() + mock_completion.choices[0].message.content = "test" + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch.object( + mistralai_llm.client.chat.completions, "create", return_value=mock_completion + ) + llm_input = [Message(role="user", content="test")] + output = mistralai_llm(llm_input) + assert output == "test" -- GitLab