diff --git a/pyproject.toml b/pyproject.toml index b396601d8ca30dd802c5116f5493b0571a585f17..c997d0075ca209bb1b2104cda227050fc2a9d0dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ python = "^3.9" pydantic = "^2.5.3" openai = "^1.10.0" cohere = "^4.32" +mistralai= "^0.0.12" numpy = "^1.25.2" colorlog = "^6.8.0" pyyaml = "^6.0.1" diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index a43d0cd65282ef8200d354c645f60b0f017adffa..4f27e96334dd34bf215f0f416db84178bcdde6d9 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -17,5 +17,5 @@ __all__ = [ "TfidfEncoder", "FastEmbedEncoder", "HuggingFaceEncoder", - "MistralEncoder" + "MistralEncoder", ] diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index ca28badc2a50a5aba22380fa52115e4d2a36f0d8..b0314284f64582f811a7668a79c66200f79f0b77 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -1,21 +1,27 @@ -'''This file contains the MistralEncoder class which is used to encode text using MistralAI''' +"""This file contains the MistralEncoder class which is used to encode text using MistralAI""" import os from time import sleep from typing import List, Optional -from semantic_router.encoders import BaseEncoder + from mistralai.client import MistralClient from mistralai.exceptions import MistralException from mistralai.models.embeddings import EmbeddingResponse +from semantic_router.encoders import BaseEncoder + + class MistralEncoder(BaseEncoder): - '''Class to encode text using MistralAI''' + """Class to encode text using MistralAI""" + client: Optional[MistralClient] type: str = "mistral" - def __init__(self, - name: Optional[str] = None, - mistral_api_key: Optional[str] = None, - score_threshold: Optional[float] = 0.82): + def __init__( + self, + name: Optional[str] = None, + mistral_api_key: Optional[str] = None, + score_threshold: Optional[float] = 0.82, + ): if name is None: name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed") super().__init__(name=name, score_threshold=score_threshold) @@ -45,11 +51,7 @@ class MistralEncoder(BaseEncoder): except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e - if( - not embeds - or not isinstance(embeds, EmbeddingResponse) - or not embeds.data - ): + if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data: raise ValueError(f"No embeddings returned from MistralAI: {error_message}") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] - return embeddings \ No newline at end of file + return embeddings diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py index 94631ad6e2da41691cadd2f48e49e6be1a25f7c2..4e2eef16f35d5ef726ed121c6159ab27446ada6a 100644 --- a/semantic_router/llms/__init__.py +++ b/semantic_router/llms/__init__.py @@ -1,8 +1,15 @@ from semantic_router.llms.base import BaseLLM from semantic_router.llms.cohere import CohereLLM +from semantic_router.llms.mistral import MistralAILLM from semantic_router.llms.openai import OpenAILLM from semantic_router.llms.openrouter import OpenRouterLLM from semantic_router.llms.zure import AzureOpenAILLM -from semantic_router.llms.mistral import MistralAILLM -__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "AzureOpenAILLM", "MistralAILLM"] +__all__ = [ + "BaseLLM", + "OpenAILLM", + "OpenRouterLLM", + "CohereLLM", + "AzureOpenAILLM", + "MistralAILLM", +] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index b8df1047865d9d353bfbdf4964c44c8d7e8ebc9d..6a5a0637a1190652b39f8d779eb23138e53ebfe3 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -8,10 +8,11 @@ from semantic_router.encoders import ( BaseEncoder, CohereEncoder, FastEmbedEncoder, - OpenAIEncoder, MistralEncoder, + OpenAIEncoder, ) + class EncoderType(Enum): HUGGINGFACE = "huggingface" FASTEMBED = "fastembed" @@ -78,4 +79,4 @@ class Message(BaseModel): class DocumentSplit(BaseModel): docs: List[str] is_triggered: bool = False - triggered_score: Optional[float] = None \ No newline at end of file + triggered_score: Optional[float] = None diff --git a/semantic_router/splitters/base.py b/semantic_router/splitters/base.py index c7ca37c2ed2660448b4549181e354aa0e5b8727a..edeba73b7e98d366c5e2b5991cbe827853660367 100644 --- a/semantic_router/splitters/base.py +++ b/semantic_router/splitters/base.py @@ -1,6 +1,7 @@ from typing import List from pydantic.v1 import BaseModel + from semantic_router.encoders import BaseEncoder diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py index 6bd08845e0d3a45706bc99dd0740ab1355ae3840..55a29a5c62b479a6cd046ddb233ece4760b4f92f 100644 --- a/semantic_router/splitters/consecutive_sim.py +++ b/semantic_router/splitters/consecutive_sim.py @@ -1,8 +1,10 @@ from typing import List -from semantic_router.splitters.base import BaseSplitter -from semantic_router.encoders import BaseEncoder + import numpy as np + +from semantic_router.encoders import BaseEncoder from semantic_router.schema import DocumentSplit +from semantic_router.splitters.base import BaseSplitter class ConsecutiveSimSplitter(BaseSplitter): diff --git a/semantic_router/splitters/cumulative_sim.py b/semantic_router/splitters/cumulative_sim.py index ba8f4bd31faa15d268acb729aaedb4147d9c97e0..f7a6475ad809a8b1eb877f592cf9ca0799941ba2 100644 --- a/semantic_router/splitters/cumulative_sim.py +++ b/semantic_router/splitters/cumulative_sim.py @@ -1,8 +1,10 @@ from typing import List -from semantic_router.splitters.base import BaseSplitter + import numpy as np -from semantic_router.schema import DocumentSplit + from semantic_router.encoders import BaseEncoder +from semantic_router.schema import DocumentSplit +from semantic_router.splitters.base import BaseSplitter class CumulativeSimSplitter(BaseSplitter): diff --git a/semantic_router/text.py b/semantic_router/text.py index 978da135eb71f3912c8d49318c743f3c91fc2b47..6038888d1926c59f256dd8b5df8c15851f84e784 100644 --- a/semantic_router/text.py +++ b/semantic_router/text.py @@ -1,13 +1,12 @@ -from colorama import Fore -from colorama import Style +from typing import List, Literal, Tuple, Union +from colorama import Fore, Style from pydantic.v1 import BaseModel, Field -from typing import Union, List, Literal, Tuple + +from semantic_router.encoders import BaseEncoder +from semantic_router.schema import DocumentSplit, Message from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter -from semantic_router.encoders import BaseEncoder -from semantic_router.schema import Message -from semantic_router.schema import DocumentSplit # Define a type alias for the splitter to simplify the annotation SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None] diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py index 1d118101d32a505a11a8bec7b750969f7f251d1b..8c9e4f778c8f2825433f42f67f58e3969dcea56f 100644 --- a/tests/unit/encoders/test_mistral.py +++ b/tests/unit/encoders/test_mistral.py @@ -1,6 +1,7 @@ import pytest from mistralai.exceptions import MistralException -from mistralai.models.embeddings import EmbeddingResponse, EmbeddingObject, UsageInfo +from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo + from semantic_router.encoders import MistralEncoder @@ -46,7 +47,9 @@ class TestMistralEncoder: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch("time.sleep", return_value=None) # To speed up the test - mock_embedding = EmbeddingObject(index=0, object="embedding", embedding=[0.1, 0.2]) + mock_embedding = EmbeddingObject( + index=0, object="embedding", embedding=[0.1, 0.2] + ) # Mock the CreateEmbeddingResponse object mock_response = EmbeddingResponse( model="mistral-embed", @@ -74,7 +77,9 @@ class TestMistralEncoder: mistralai_encoder(["test document"]) assert "No embeddings returned. Error" in str(e.value) - def test_mistralai_encoder_call_failure_non_mistralai_error(self, mistralai_encoder, mocker): + def test_mistralai_encoder_call_failure_non_mistralai_error( + self, mistralai_encoder, mocker + ): mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch("time.sleep", return_value=None) # To speed up the test mocker.patch.object( @@ -96,7 +101,9 @@ class TestMistralEncoder: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch("time.sleep", return_value=None) # To speed up the test - mock_embedding = EmbeddingObject(index=0, object="embedding", embedding=[0.1, 0.2]) + mock_embedding = EmbeddingObject( + index=0, object="embedding", embedding=[0.1, 0.2] + ) # Mock the CreateEmbeddingResponse object mock_response = EmbeddingResponse( model="mistral-embed", diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py index a7fc1d5f2fa3043971c9799dbfbdbae39ba202cf..c86baea4f9609ada676a70f9391692be0c40eb69 100644 --- a/tests/unit/llms/test_llm_mistral.py +++ b/tests/unit/llms/test_llm_mistral.py @@ -35,7 +35,9 @@ class TestMistralAILLM: def test_mistralai_llm_init_exception(self, mocker): mocker.patch("os.getenv", return_value="fake-api-key") - mocker.patch("mistralai.mistralai", side_effect=Exception("Initialization error")) + mocker.patch( + "mistralai.mistralai", side_effect=Exception("Initialization error") + ) with pytest.raises(ValueError) as e: MistralAILLM() assert ( @@ -49,7 +51,9 @@ class TestMistralAILLM: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch.object( - mistralai_llm.client.chat.completions, "create", return_value=mock_completion + mistralai_llm.client.chat.completions, + "create", + return_value=mock_completion, ) llm_input = [Message(role="user", content="test")] output = mistralai_llm(llm_input) diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index 165fc2dd4327320f12ed3ea3a68c440b650af1dd..5ee28504dad622cd8cb52c62553f651a5739e990 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -1,15 +1,15 @@ from unittest.mock import Mock, create_autospec -import pytest import numpy as np +import pytest -from semantic_router.text import Conversation -from semantic_router.schema import Message -from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter -from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.cohere import CohereEncoder +from semantic_router.schema import Message from semantic_router.splitters.base import BaseSplitter +from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter +from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter +from semantic_router.text import Conversation def test_consecutive_sim_splitter(): diff --git a/tests/unit/test_text.py b/tests/unit/test_text.py index e7490ec14d43593efa608261e4e6ef1d76915347..51328b8b0035b0e977268c063352165c6b71d902 100644 --- a/tests/unit/test_text.py +++ b/tests/unit/test_text.py @@ -1,12 +1,16 @@ -import pytest from unittest.mock import Mock -from semantic_router.text import Conversation, Message -from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter -from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter + +import pytest + from semantic_router.encoders.cohere import ( CohereEncoder, -) # Adjust this import based on your project structure +) + +# Adjust this import based on your project structure from semantic_router.schema import DocumentSplit +from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter +from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter +from semantic_router.text import Conversation, Message @pytest.fixture