diff --git a/pyproject.toml b/pyproject.toml index 024b9e3de66bd003879e1708cd3ec80d6405efec..ab173ba40c8225749659ddd08326984aac1fa6ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ python = ">=3.9,<3.13" pydantic = "^2.5.3" openai = "^1.10.0" cohere = "^4.32" -mistralai= "^0.0.12" +mistralai= {version = "^0.0.12", optional = true} numpy = "^1.25.2" colorlog = "^6.8.0" pyyaml = "^6.0.1" diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index c76f1e74a1e86ede34f87ca6857c17dd71487ccd..5beabb71d3a966625abc15946b1bd9877c19bf63 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -1,20 +1,22 @@ """This file contains the MistralEncoder class which is used to encode text using MistralAI""" import os from time import sleep -from typing import List, Optional +from typing import List, Optional, Any + -from mistralai.client import MistralClient -from mistralai.exceptions import MistralException -from mistralai.models.embeddings import EmbeddingResponse from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault +from pydantic.v1 import PrivateAttr + class MistralEncoder(BaseEncoder): """Class to encode text using MistralAI""" - client: Optional[MistralClient] + client: Any = PrivateAttr() + embedding_response: Any = PrivateAttr() + mistral_exception: Any = PrivateAttr() type: str = "mistral" def __init__( @@ -29,12 +31,39 @@ class MistralEncoder(BaseEncoder): api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") if api_key is None: raise ValueError("Mistral API key not provided") + self._client = self._initialize_client(mistralai_api_key) + + def _initialize_client(self, api_key): + try: + from mistralai.client import MistralClient + except ImportError: + raise ImportError( + "Please install MistralAI to use MistralEncoder. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" + ) + try: + from mistralai.exceptions import MistralException + from mistralai.models.embeddings import EmbeddingResponse + except ImportError: + raise ImportError( + "Please install MistralAI to use MistralEncoder. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" + ) + + try: self.client = MistralClient(api_key=api_key) + self.embedding_response = EmbeddingResponse + self.mistral_exception = MistralException except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e + def __call__(self, docs: List[str]) -> List[List[float]]: + + if self.client is None: raise ValueError("Mistral client not initialized") embeds = None @@ -46,13 +75,13 @@ class MistralEncoder(BaseEncoder): embeds = self.client.embeddings(model=self.name, input=docs) if embeds.data: break - except MistralException as e: + except self.mistral_exception as e: sleep(2**_) error_message = str(e) except Exception as e: raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e - if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data: + if not embeds or not isinstance(embeds, self.embedding_response) or not embeds.data: raise ValueError(f"No embeddings returned from MistralAI: {error_message}") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] - return embeddings + return embeddings \ No newline at end of file diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index e17ba8bab050d6dbdccf3f22376729fd921c692e..42cbe46d026991f4112fac07b675f3fb152beeb7 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -1,16 +1,18 @@ import os -from typing import List, Optional +from typing import List, Optional, Any + -from mistralai.client import MistralClient from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger +from pydantic.v1 import PrivateAttr + class MistralAILLM(BaseLLM): - client: Optional[MistralClient] + client: Any = PrivateAttr() temperature: Optional[float] max_tokens: Optional[int] @@ -27,15 +29,26 @@ class MistralAILLM(BaseLLM): api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY") if api_key is None: raise ValueError("MistralAI API key cannot be 'None'.") + self._initialize_client(api_key) + self.temperature = temperature + self.max_tokens = max_tokens + + def _initialize_client(self, api_key): + try: + from mistralai.client import MistralClient + except ImportError: + raise ImportError( + "Please install MistralAI to use MistralEncoder. " + "You can install it with: " + "`pip install 'semantic-router[mistralai]'`" + ) try: self.client = MistralClient(api_key=api_key) except Exception as e: raise ValueError( f"MistralAI API client failed to initialize. Error: {e}" ) from e - self.temperature = temperature - self.max_tokens = max_tokens - + def __call__(self, messages: List[Message]) -> str: if self.client is None: raise ValueError("MistralAI client is not initialized.")