Skip to content
Snippets Groups Projects
Commit 89faa38e authored by dwmorris11's avatar dwmorris11
Browse files

Added support for the MistralAI API. This includes a

new encoder and a new LLM. The encoder is a
simple wrapper around the MistralAI embeddings API,
and the LLM is a simple wrapper around the
MistralAI chat API. Both the encoder and the LLM
are covered by unit tests.
parent d7421545
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ from semantic_router.encoders.bm25 import BM25Encoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.encoders.fastembed import FastEmbedEncoder
from semantic_router.encoders.huggingface import HuggingFaceEncoder
from semantic_router.encoders.mistral import MistralEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.tfidf import TfidfEncoder
from semantic_router.encoders.zure import AzureOpenAIEncoder
......@@ -16,4 +17,5 @@ __all__ = [
"TfidfEncoder",
"FastEmbedEncoder",
"HuggingFaceEncoder",
"MistralEncoder"
]
'''This file contains the MistralEncoder class which is used to encode text using MistralAI'''
import os
from time import sleep
from typing import List, Optional
from semantic_router.encoders import BaseEncoder
from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse
class MistralEncoder(BaseEncoder):
    """Encode text into embedding vectors using the MistralAI embeddings API."""

    # Client handle; None only before __init__ completes or if a caller clears it.
    client: Optional[MistralClient]
    type: str = "mistral"

    def __init__(
        self,
        name: Optional[str] = None,
        mistral_api_key: Optional[str] = None,
        score_threshold: Optional[float] = 0.82,
    ):
        """Create a MistralAI-backed encoder.

        :param name: embedding model name; defaults to the MISTRAL_MODEL_NAME
            environment variable, falling back to "mistral-embed".
        :param mistral_api_key: API key; defaults to the MISTRALAI_API_KEY
            environment variable.
        :param score_threshold: similarity threshold forwarded to BaseEncoder.
        :raises ValueError: if no API key is available or the client cannot
            be constructed.
        """
        if name is None:
            name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed")
        super().__init__(name=name, score_threshold=score_threshold)
        api_key = mistral_api_key or os.getenv("MISTRALAI_API_KEY")
        if api_key is None:
            raise ValueError("Mistral API key not provided")
        try:
            self.client = MistralClient(api_key=api_key)
        except Exception as e:
            raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Embed each document in ``docs``, retrying transient API errors.

        :param docs: texts to embed.
        :return: one embedding vector per document.
        :raises ValueError: if the client is not initialized, the API keeps
            failing, or the response contains no embeddings.
        """
        if self.client is None:
            raise ValueError("Mistral client not initialized")
        embeds = None
        error_message = ""
        # Up to 3 attempts with exponential backoff (1s, 2s, 4s).
        for attempt in range(3):
            try:
                embeds = self.client.embeddings(model=self.name, input=docs)
                if embeds.data:
                    break
                # Empty response: record it and back off before retrying.
                # (Previously this retried immediately and reported "".)
                error_message = "empty embedding response"
            except MistralException as e:
                error_message = str(e)
            except Exception as e:
                raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
            sleep(2**attempt)
        if (
            not embeds
            or not isinstance(embeds, EmbeddingResponse)
            or not embeds.data
        ):
            raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
        return [item.embedding for item in embeds.data]
\ No newline at end of file
......@@ -3,5 +3,6 @@ from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
from semantic_router.llms.zure import AzureOpenAILLM
from semantic_router.llms.mistral import MistralAILLM
__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM"]
__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM", "MistralAILLM"]
import os
from typing import List, Optional
from mistralai.client import MistralClient
from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger
class MistralAILLM(BaseLLM):
    """Chat LLM backed by the MistralAI chat-completion API."""

    # Client handle; None only before __init__ completes or if a caller clears it.
    client: Optional[MistralClient]
    temperature: Optional[float]
    max_tokens: Optional[int]

    def __init__(
        self,
        name: Optional[str] = None,
        mistralai_api_key: Optional[str] = None,
        temperature: float = 0.01,
        max_tokens: int = 200,
    ):
        """Create a MistralAI-backed LLM.

        :param name: chat model name; defaults to the MISTRALAI_CHAT_MODEL_NAME
            environment variable, falling back to "mistral-tiny".
        :param mistralai_api_key: API key; defaults to the MISTRALAI_API_KEY
            environment variable.
        :param temperature: sampling temperature passed to the chat endpoint.
        :param max_tokens: completion length limit passed to the chat endpoint.
        :raises ValueError: if no API key is available or the client cannot
            be constructed.
        """
        resolved_name = (
            name
            if name is not None
            else os.getenv("MISTRALAI_CHAT_MODEL_NAME", "mistral-tiny")
        )
        super().__init__(name=resolved_name)
        api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
        if api_key is None:
            raise ValueError("MistralAI API key cannot be 'None'.")
        try:
            self.client = MistralClient(api_key=api_key)
        except Exception as e:
            raise ValueError(
                f"MistralAI API client failed to initialize. Error: {e}"
            ) from e
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(self, messages: List[Message]) -> str:
        """Run one chat completion and return the first choice's content.

        :param messages: conversation history to send.
        :return: the generated text.
        :raises ValueError: if the client is not initialized.
        :raises Exception: wrapping any API error or an empty completion.
        """
        if self.client is None:
            raise ValueError("MistralAI client is not initialized.")
        try:
            chat_messages = [m.to_mistral() for m in messages]
            response = self.client.chat(
                model=self.name,
                messages=chat_messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            content = response.choices[0].message.content
            if not content:
                raise Exception("No output generated")
            return content
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise Exception(f"LLM error: {e}") from e
......@@ -9,14 +9,15 @@ from semantic_router.encoders import (
CohereEncoder,
FastEmbedEncoder,
OpenAIEncoder,
MistralEncoder,
)
class EncoderType(Enum):
    """Identifiers for the embedding-encoder backends selectable by name."""

    HUGGINGFACE = "huggingface"
    FASTEMBED = "fastembed"
    OPENAI = "openai"
    COHERE = "cohere"
    MISTRAL = "mistral"
class RouteChoice(BaseModel):
......@@ -43,6 +44,8 @@ class Encoder:
self.model = OpenAIEncoder(name=name)
elif self.type == EncoderType.COHERE:
self.model = CohereEncoder(name=name)
elif self.type == EncoderType.MISTRAL:
self.model = MistralEncoder(name=name)
else:
raise ValueError
......@@ -65,6 +68,9 @@ class Message(BaseModel):
def to_llamacpp(self):
    """Render this message as a llama.cpp-style role/content dict."""
    return dict(role=self.role, content=self.content)
def to_mistral(self):
    """Render this message as a MistralAI chat role/content dict."""
    return dict(role=self.role, content=self.content)
def __str__(self):
    """Human-readable "role: content" rendering of the message."""
    role, content = self.role, self.content
    return f"{role}: {content}"
......@@ -72,4 +78,4 @@ class Message(BaseModel):
class DocumentSplit(BaseModel):
    """A contiguous group of documents produced by a splitter."""

    # Documents belonging to this split.
    docs: List[str]
    # Whether a similarity trigger caused this split boundary.
    is_triggered: bool = False
    # Score at the trigger point, if any (duplicate declaration removed).
    triggered_score: Optional[float] = None
\ No newline at end of file
import pytest
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse, EmbeddingObject, UsageInfo
from semantic_router.encoders import MistralEncoder
@pytest.fixture
def mistralai_encoder(mocker):
    """Encoder with a mocked MistralClient, patched where the module imports it."""
    # A bare "MistralClient" is not a valid mocker.patch target (no module path);
    # patch the name in the namespace that uses it.
    mocker.patch("semantic_router.encoders.mistral.MistralClient")
    # The constructor parameter is mistral_api_key, not mistralai_api_key.
    return MistralEncoder(mistral_api_key="test_api_key")
class TestMistralEncoder:
    """Unit tests for MistralEncoder.

    Patch targets and expected messages are aligned with the implementation:
    the encoder calls ``self.client.embeddings(...)`` directly (there is no
    ``.create`` attribute) and imports ``sleep`` via ``from time import sleep``,
    so both must be patched in ``semantic_router.encoders.mistral``.
    """

    def test_mistralai_encoder_init_success(self, mocker):
        mocker.patch("semantic_router.encoders.mistral.MistralClient")
        mocker.patch("os.getenv", return_value="fake-api-key")
        encoder = MistralEncoder()
        assert encoder.client is not None

    def test_mistralai_encoder_init_no_api_key(self, mocker):
        mocker.patch("os.getenv", return_value=None)
        with pytest.raises(ValueError) as _:
            MistralEncoder()

    def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder):
        # Set the client to None to simulate an uninitialized client.
        mistralai_encoder.client = None
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        # Must match the message raised in MistralEncoder.__call__.
        assert "Mistral client not initialized" in str(e.value)

    def test_mistralai_encoder_init_exception(self, mocker):
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch(
            "semantic_router.encoders.mistral.MistralClient",
            side_effect=Exception("Initialization error"),
        )
        with pytest.raises(ValueError) as e:
            MistralEncoder()
        # Must match the wrapping message in MistralEncoder.__init__.
        assert "Unable to connect to MistralAI" in str(e.value)

    def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker):
        mock_embedding = EmbeddingObject(
            index=0, object="embedding", embedding=[0.1, 0.2]
        )
        mock_response = EmbeddingResponse(
            model="mistral-embed",
            object="list",
            usage=UsageInfo(prompt_tokens=0, total_tokens=20),
            data=[mock_embedding],
        )
        mocker.patch.object(
            mistralai_encoder.client, "embeddings", return_value=mock_response
        )
        embeddings = mistralai_encoder(["test document"])
        assert embeddings == [[0.1, 0.2]]

    def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker):
        # Patch sleep where the encoder imported it, to speed up the test.
        mocker.patch("semantic_router.encoders.mistral.sleep", return_value=None)
        mocker.patch.object(
            mistralai_encoder.client,
            "embeddings",
            side_effect=MistralException("Test error"),
        )
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        assert "No embeddings returned from MistralAI" in str(e.value)

    def test_mistralai_encoder_call_failure_non_mistralai_error(
        self, mistralai_encoder, mocker
    ):
        mocker.patch.object(
            mistralai_encoder.client,
            "embeddings",
            side_effect=Exception("Non-MistralException"),
        )
        with pytest.raises(ValueError) as e:
            mistralai_encoder(["test document"])
        assert "Unable to connect to MistralAI" in str(e.value)

    def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker):
        mocker.patch("semantic_router.encoders.mistral.sleep", return_value=None)
        mock_embedding = EmbeddingObject(
            index=0, object="embedding", embedding=[0.1, 0.2]
        )
        mock_response = EmbeddingResponse(
            model="mistral-embed",
            object="list",
            usage=UsageInfo(prompt_tokens=0, total_tokens=20),
            data=[mock_embedding],
        )
        # First call fails, second succeeds — exercises the backoff path.
        responses = [MistralException("mistralai error"), mock_response]
        mocker.patch.object(
            mistralai_encoder.client, "embeddings", side_effect=responses
        )
        embeddings = mistralai_encoder(["test document"])
        assert embeddings == [[0.1, 0.2]]
import pytest
from semantic_router.llms import MistralAILLM
from semantic_router.schema import Message
@pytest.fixture
def mistralai_llm(mocker):
    """LLM with a mocked MistralClient, patched where the module imports it."""
    # "mistralai.Client" does not exist; the class is MistralClient and the
    # LLM module binds it locally, so patch it there.
    mocker.patch("semantic_router.llms.mistral.MistralClient")
    return MistralAILLM(mistralai_api_key="test_api_key")
class TestMistralAILLM:
    """Unit tests for MistralAILLM.

    Patch targets and expected messages are aligned with the implementation:
    the LLM calls ``self.client.chat(...)`` directly (not
    ``chat.completions.create``), and the client class must be patched as
    ``semantic_router.llms.mistral.MistralClient``.
    """

    def test_mistralai_llm_init_with_api_key(self, mistralai_llm):
        assert mistralai_llm.client is not None, "Client should be initialized"
        assert mistralai_llm.name == "mistral-tiny", "Default name not set correctly"

    def test_mistralai_llm_init_success(self, mocker):
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch("semantic_router.llms.mistral.MistralClient")
        llm = MistralAILLM()
        assert llm.client is not None

    def test_mistralai_llm_init_without_api_key(self, mocker):
        mocker.patch("os.getenv", return_value=None)
        with pytest.raises(ValueError) as _:
            MistralAILLM()

    def test_mistralai_llm_call_uninitialized_client(self, mistralai_llm):
        # Set the client to None to simulate an uninitialized client.
        mistralai_llm.client = None
        with pytest.raises(ValueError) as e:
            llm_input = [Message(role="user", content="test")]
            mistralai_llm(llm_input)
        # Must match the message raised in MistralAILLM.__call__.
        assert "MistralAI client is not initialized." in str(e.value)

    def test_mistralai_llm_init_exception(self, mocker):
        mocker.patch("os.getenv", return_value="fake-api-key")
        mocker.patch(
            "semantic_router.llms.mistral.MistralClient",
            side_effect=Exception("Initialization error"),
        )
        with pytest.raises(ValueError) as e:
            MistralAILLM()
        # Must match the wrapping message in MistralAILLM.__init__.
        assert (
            "MistralAI API client failed to initialize. Error: Initialization error"
            in str(e.value)
        )

    def test_mistralai_llm_call_success(self, mistralai_llm, mocker):
        mock_completion = mocker.MagicMock()
        mock_completion.choices[0].message.content = "test"
        mocker.patch.object(
            mistralai_llm.client, "chat", return_value=mock_completion
        )
        llm_input = [Message(role="user", content="test")]
        output = mistralai_llm(llm_input)
        assert output == "test"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment