From 89faa38ebc237c5efe7e357c3333b09cba6f7fd8 Mon Sep 17 00:00:00 2001
From: dwmorris11 <dustin.morris@outlook.com>
Date: Sat, 10 Feb 2024 16:49:11 -0800
Subject: [PATCH] Added support for the MistralAI API.

This includes a new encoder and a new LLM. The encoder is a thin wrapper
around the MistralAI embeddings endpoint, and the LLM is a thin wrapper
around the MistralAI chat endpoint. Both are covered by unit tests.
---
 semantic_router/encoders/__init__.py |   2 +
 semantic_router/encoders/mistral.py  |  55 +++++++++++++
 semantic_router/llms/__init__.py     |   3 +-
 semantic_router/llms/mistral.py      |  56 +++++++++++++
 semantic_router/schema.py            |  10 ++-
 tests/unit/encoders/test_mistral.py  | 113 +++++++++++++++++++++++++++
 tests/unit/llms/test_llm_mistral.py  |  56 +++++++++++++
 7 files changed, 292 insertions(+), 3 deletions(-)
 create mode 100644 semantic_router/encoders/mistral.py
 create mode 100644 semantic_router/llms/mistral.py
 create mode 100644 tests/unit/encoders/test_mistral.py
 create mode 100644 tests/unit/llms/test_llm_mistral.py

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index b4668154..a43d0cd6 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -3,6 +3,7 @@ from semantic_router.encoders.bm25 import BM25Encoder
 from semantic_router.encoders.cohere import CohereEncoder
 from semantic_router.encoders.fastembed import FastEmbedEncoder
 from semantic_router.encoders.huggingface import HuggingFaceEncoder
+from semantic_router.encoders.mistral import MistralEncoder
 from semantic_router.encoders.openai import OpenAIEncoder
 from semantic_router.encoders.tfidf import TfidfEncoder
 from semantic_router.encoders.zure import AzureOpenAIEncoder
@@ -16,4 +17,5 @@ __all__ = [
     "TfidfEncoder",
     "FastEmbedEncoder",
     "HuggingFaceEncoder",
+    "MistralEncoder",
 ]
diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py
new file mode 100644
index 00000000..ca28badc
--- /dev/null
+++ b/semantic_router/encoders/mistral.py
@@ -0,0 +1,55 @@
+"""This module contains the MistralEncoder class, which encodes text using the MistralAI embeddings API."""
+import os
+from time import sleep
+from typing import List, Optional
+
+from mistralai.client import MistralClient
+from mistralai.exceptions import MistralException
+from mistralai.models.embeddings import EmbeddingResponse
+
+from semantic_router.encoders import BaseEncoder
+
+
+class MistralEncoder(BaseEncoder):
+    """Class to encode text using MistralAI."""
+
+    client: Optional[MistralClient]
+    type: str = "mistral"
+
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        mistral_api_key: Optional[str] = None,
+        score_threshold: Optional[float] = 0.82,
+    ):
+        if name is None:
+            name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed")
+        super().__init__(name=name, score_threshold=score_threshold)
+        api_key = mistral_api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("Mistral API key not provided")
+        try:
+            self.client = MistralClient(api_key=api_key)
+        except Exception as e:
+            raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
+
+    def __call__(self, docs: List[str]) -> List[List[float]]:
+        if self.client is None:
+            raise ValueError("Mistral client not initialized")
+        embeds = None
+        error_message = ""
+
+        # Exponential backoff: retry up to three times, doubling the wait each time.
+        for attempt in range(3):
+            try:
+                embeds = self.client.embeddings(model=self.name, input=docs)
+                if embeds.data:
+                    break
+            except MistralException as e:
+                sleep(2**attempt)
+                error_message = str(e)
+            except Exception as e:
+                raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
+
+        if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
+            raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
\ No newline at end of file
diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py
index b216f0b0..94631ad6 100644
--- a/semantic_router/llms/__init__.py
+++ b/semantic_router/llms/__init__.py
@@ -3,5 +3,6 @@ from semantic_router.llms.cohere import CohereLLM
 from semantic_router.llms.openai import OpenAILLM
 from semantic_router.llms.openrouter import OpenRouterLLM
 from semantic_router.llms.zure import AzureOpenAILLM
+from semantic_router.llms.mistral import MistralAILLM
 
-__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM"]
+__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM", "MistralAILLM"]
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
new file mode 100644
index 00000000..0f74bcbd
--- /dev/null
+++ b/semantic_router/llms/mistral.py
@@ -0,0 +1,56 @@
+import os
+from typing import List, Optional
+
+from mistralai.client import MistralClient
+
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+from semantic_router.utils.logger import logger
+
+
+class MistralAILLM(BaseLLM):
+    client: Optional[MistralClient]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        mistralai_api_key: Optional[str] = None,
+        temperature: float = 0.01,
+        max_tokens: int = 200,
+    ):
+        if name is None:
+            name = os.getenv("MISTRALAI_CHAT_MODEL_NAME", "mistral-tiny")
+        super().__init__(name=name)
+        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
+        if api_key is None:
+            raise ValueError("MistralAI API key cannot be 'None'.")
+        try:
+            self.client = MistralClient(api_key=api_key)
+        except Exception as e:
+            raise ValueError(
+                f"MistralAI API client failed to initialize. Error: {e}"
Error: {e}" + ) from e + self.temperature = temperature + self.max_tokens = max_tokens + + def __call__(self, messages: List[Message]) -> str: + if self.client is None: + raise ValueError("MistralAI client is not initialized.") + try: + completion = self.client.chat( + model=self.name, + messages=[m.to_mistral() for m in messages], + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + output = completion.choices[0].message.content + + if not output: + raise Exception("No output generated") + return output + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") from e diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 46ee7f59..b8df1047 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -9,14 +9,15 @@ from semantic_router.encoders import ( CohereEncoder, FastEmbedEncoder, OpenAIEncoder, + MistralEncoder, ) - class EncoderType(Enum): HUGGINGFACE = "huggingface" FASTEMBED = "fastembed" OPENAI = "openai" COHERE = "cohere" + MISTRAL = "mistral" class RouteChoice(BaseModel): @@ -43,6 +44,8 @@ class Encoder: self.model = OpenAIEncoder(name=name) elif self.type == EncoderType.COHERE: self.model = CohereEncoder(name=name) + elif self.type == EncoderType.MISTRAL: + self.model = MistralEncoder(name=name) else: raise ValueError @@ -65,6 +68,9 @@ class Message(BaseModel): def to_llamacpp(self): return {"role": self.role, "content": self.content} + def to_mistral(self): + return {"role": self.role, "content": self.content} + def __str__(self): return f"{self.role}: {self.content}" @@ -72,4 +78,4 @@ class Message(BaseModel): class DocumentSplit(BaseModel): docs: List[str] is_triggered: bool = False - triggered_score: Optional[float] = None + triggered_score: Optional[float] = None \ No newline at end of file diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py new file mode 100644 index 00000000..1d118101 --- /dev/null +++ b/tests/unit/encoders/test_mistral.py @@ -0,0 +1,113 @@ +import pytest +from mistralai.exceptions import MistralException +from mistralai.models.embeddings import EmbeddingResponse, EmbeddingObject, UsageInfo +from semantic_router.encoders import MistralEncoder + + +@pytest.fixture +def mistralai_encoder(mocker): + mocker.patch("MistralClient") + return MistralEncoder(mistralai_api_key="test_api_key") + + +class TestMistralEncoder: + def test_mistralai_encoder_init_success(self, mocker): + encoder = MistralEncoder() + assert encoder.client is not None + + def test_mistralai_encoder_init_no_api_key(self, mocker): + mocker.patch("os.getenv", return_value=None) + with pytest.raises(ValueError) as _: + MistralEncoder() + + def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder): + # Set the client to None to simulate an uninitialized client + mistralai_encoder.client = None + with pytest.raises(ValueError) as e: + mistralai_encoder(["test document"]) + assert "MistralAI client is not initialized." in str(e.value) + + def test_mistralai_encoder_init_exception(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("MistralClient", side_effect=Exception("Initialization error")) + with pytest.raises(ValueError) as e: + MistralEncoder() + assert ( + "mistralai API client failed to initialize. 
Error: Initialization error" + in str(e.value) + ) + + def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker): + mock_embeddings = mocker.Mock() + mock_embeddings.data = [ + EmbeddingObject(embedding=[0.1, 0.2], index=0, object="embedding") + ] + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + + mock_embedding = EmbeddingObject(index=0, object="embedding", embedding=[0.1, 0.2]) + # Mock the CreateEmbeddingResponse object + mock_response = EmbeddingResponse( + model="mistral-embed", + object="list", + usage=UsageInfo(prompt_tokens=0, total_tokens=20), + data=[mock_embedding], + ) + + responses = [MistralException("mistralai error"), mock_response] + mocker.patch.object( + mistralai_encoder.client.embeddings, "create", side_effect=responses + ) + embeddings = mistralai_encoder(["test document"]) + assert embeddings == [[0.1, 0.2]] + + def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + mocker.patch.object( + mistralai_encoder.client.embeddings, + "create", + side_effect=MistralException("Test error"), + ) + with pytest.raises(ValueError) as e: + mistralai_encoder(["test document"]) + assert "No embeddings returned. Error" in str(e.value) + + def test_mistralai_encoder_call_failure_non_mistralai_error(self, mistralai_encoder, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + mocker.patch.object( + mistralai_encoder.client.embeddings, + "create", + side_effect=Exception("Non-MistralException"), + ) + with pytest.raises(ValueError) as e: + mistralai_encoder(["test document"]) + + assert "mistralai API call failed. 
Error: Non-MistralException" in str(e.value) + + def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker): + mock_embeddings = mocker.Mock() + mock_embeddings.data = [ + EmbeddingObject(embedding=[0.1, 0.2], index=0, object="embedding") + ] + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + + mock_embedding = EmbeddingObject(index=0, object="embedding", embedding=[0.1, 0.2]) + # Mock the CreateEmbeddingResponse object + mock_response = EmbeddingResponse( + model="mistral-embed", + object="list", + usage=UsageInfo(prompt_tokens=0, total_tokens=20), + data=[mock_embedding], + ) + + responses = [MistralException("mistralai error"), mock_response] + mocker.patch.object( + mistralai_encoder.client.embeddings, "create", side_effect=responses + ) + embeddings = mistralai_encoder(["test document"]) + assert embeddings == [[0.1, 0.2]] diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py new file mode 100644 index 00000000..a7fc1d5f --- /dev/null +++ b/tests/unit/llms/test_llm_mistral.py @@ -0,0 +1,56 @@ +import pytest + +from semantic_router.llms import MistralAILLM +from semantic_router.schema import Message + + +@pytest.fixture +def mistralai_llm(mocker): + mocker.patch("mistralai.Client") + return MistralAILLM(mistralai_api_key="test_api_key") + + +class TestMistralAILLM: + def test_mistralai_llm_init_with_api_key(self, mistralai_llm): + assert mistralai_llm.client is not None, "Client should be initialized" + assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly" + + def test_mistralai_llm_init_success(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + llm = MistralAILLM() + assert llm.client is not None + + def test_mistralai_llm_init_without_api_key(self, mocker): + mocker.patch("os.getenv", return_value=None) + with pytest.raises(ValueError) as _: + MistralAILLM() + + def test_mistralai_llm_call_uninitialized_client(self, mistralai_llm): + # Set the client to None to simulate an uninitialized client + mistralai_llm.client = None + with pytest.raises(ValueError) as e: + llm_input = [Message(role="user", content="test")] + mistralai_llm(llm_input) + assert "mistralai client is not initialized." in str(e.value) + + def test_mistralai_llm_init_exception(self, mocker): + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("mistralai.mistralai", side_effect=Exception("Initialization error")) + with pytest.raises(ValueError) as e: + MistralAILLM() + assert ( + "mistralai API client failed to initialize. Error: Initialization error" + in str(e.value) + ) + + def test_mistralai_llm_call_success(self, mistralai_llm, mocker): + mock_completion = mocker.MagicMock() + mock_completion.choices[0].message.content = "test" + + mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch.object( + mistralai_llm.client.chat.completions, "create", return_value=mock_completion + ) + llm_input = [Message(role="user", content="test")] + output = mistralai_llm(llm_input) + assert output == "test" -- GitLab