'''This file contains the MistralEncoder class which is used to encode text using MistralAI'''

import os
from time import sleep
from typing import List, Optional

from semantic_router.encoders import BaseEncoder
from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse


class MistralEncoder(BaseEncoder):
    '''Class to encode text using MistralAI.

    Wraps a ``MistralClient`` and exposes it as a callable that turns a list
    of documents into a list of embedding vectors.
    '''

    client: Optional[MistralClient]
    type: str = "mistral"

    def __init__(
        self,
        name: Optional[str] = None,
        mistral_api_key: Optional[str] = None,
        score_threshold: Optional[float] = 0.82,
    ):
        '''Create the encoder and its underlying Mistral client.

        Args:
            name: Embedding model name. When ``None``, the
                ``MISTRAL_MODEL_NAME`` environment variable is used, falling
                back to ``"mistral-embed"``.
            mistral_api_key: API key. When falsy, the ``MISTRALAI_API_KEY``
                environment variable is used instead.
            score_threshold: Forwarded unchanged to ``BaseEncoder.__init__``.

        Raises:
            ValueError: If no API key can be found, or if constructing the
                ``MistralClient`` fails for any reason.
        '''
        if name is None:
            name = os.getenv("MISTRAL_MODEL_NAME", "mistral-embed")
        super().__init__(name=name, score_threshold=score_threshold)

        api_key = mistral_api_key or os.getenv("MISTRALAI_API_KEY")
        if api_key is None:
            raise ValueError("Mistral API key not provided")

        try:
            self.client = MistralClient(api_key=api_key)
        except Exception as e:
            raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        '''Embed ``docs`` and return one vector per returned embedding object.

        The request is attempted up to three times. A ``MistralException``
        triggers an exponentially growing pause (1s, then 2s) before the next
        attempt; any other exception aborts immediately.

        Args:
            docs: The texts to embed.

        Returns:
            The ``embedding`` attribute of every item in the response's
            ``data`` list, in response order.

        Raises:
            ValueError: If the client is not initialised, if a
                non-``MistralException`` error occurs during the request, or
                if no usable embeddings were obtained after all attempts.
        '''
        if self.client is None:
            raise ValueError("Mistral client not initialized")

        embeds = None
        error_message = ""
        max_attempts = 3

        # Exponential backoff
        for attempt in range(max_attempts):
            try:
                embeds = self.client.embeddings(model=self.name, input=docs)
                if embeds.data:
                    break
            except MistralException as e:
                # Record the failure first so the message survives even if
                # the wait below is interrupted.
                error_message = str(e)
                # Only wait when another attempt will follow; sleeping after
                # the final failure would just delay the error.
                if attempt < max_attempts - 1:
                    sleep(2**attempt)
            except Exception as e:
                raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

        if (
            not embeds
            or not isinstance(embeds, EmbeddingResponse)
            or not embeds.data
        ):
            raise ValueError(f"No embeddings returned from MistralAI: {error_message}")

        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
        return embeddings