diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py
index bd6f5045822cdacbabb8ddbfed787b6ef3791f3c..fbe2640cbfbe9e4aad40f3cfb9e24effdfdcc8d0 100644
--- a/semantic_router/encoders/mistral.py
+++ b/semantic_router/encoders/mistral.py
@@ -25,7 +25,7 @@ class MistralEncoder(BaseEncoder):
         if name is None:
             name = EncoderDefault.MISTRAL.value["embedding_model"]
         super().__init__(name=name, score_threshold=score_threshold)
-        self._client, self._mistralai = self._initialize_client(mistralai_api_key)
+        self._client, self._mistralai = self._initialize_client(mistralai_api_key)
 
     def _initialize_client(self, api_key):
         try:
@@ -67,7 +67,9 @@ class MistralEncoder(BaseEncoder):
 
         if (
             not embeds
-            or not isinstance(embeds, self._mistralai.models.embeddings.EmbeddingResponse)
+            or not isinstance(
+                embeds, self._mistralai.models.embeddings.EmbeddingResponse
+            )
             or not embeds.data
         ):
             raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py
index afaa5aa2f800cfe6ca3956c04be490793da7d032..647d4073e5c7b50591dd7cd686536c0bebba0714 100644
--- a/semantic_router/llms/mistral.py
+++ b/semantic_router/llms/mistral.py
@@ -14,6 +14,7 @@ class MistralAILLM(BaseLLM):
     _client: Any = PrivateAttr()
     temperature: Optional[float]
     max_tokens: Optional[int]
+    _mistralai: Any = PrivateAttr()
 
     def __init__(
         self,
@@ -25,12 +26,13 @@ class MistralAILLM(BaseLLM):
         if name is None:
             name = EncoderDefault.MISTRAL.value["language_model"]
         super().__init__(name=name)
-        self._client = self._initialize_client(mistralai_api_key)
+        self._client, self._mistralai = self._initialize_client(mistralai_api_key)
         self.temperature = temperature
         self.max_tokens = max_tokens
 
     def _initialize_client(self, api_key):
         try:
+            import mistralai
             from mistralai.client import MistralClient
         except ImportError:
             raise ImportError(
@@ -47,15 +49,22 @@ class MistralAILLM(BaseLLM):
             raise ValueError(
                 f"MistralAI API client failed to initialize. Error: {e}"
             ) from e
-        return client
+        return client, mistralai
 
     def __call__(self, messages: List[Message]) -> str:
         if self._client is None:
             raise ValueError("MistralAI client is not initialized.")
+
+        chat_messages = [
+            self._mistralai.models.chat_completion.ChatMessage(
+                role=m.role, content=m.content
+            )
+            for m in messages
+        ]
         try:
             completion = self._client.chat(
                 model=self.name,
-                messages=[m.to_mistral() for m in messages],
+                messages=chat_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
             )
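For context, a minimal usage sketch of the patched MistralAILLM (not part of the diff). It assumes the constructor keywords implied by the hunk context (mistralai_api_key, temperature, max_tokens), that Message lives in semantic_router.schema and exposes role and content as the new list comprehension requires, and that MISTRALAI_API_KEY is the environment variable holding the key.

```python
# Sketch under the assumptions stated above; import paths and the env var
# name are inferred, not taken from the diff.
import os

from semantic_router.llms.mistral import MistralAILLM
from semantic_router.schema import Message

llm = MistralAILLM(
    mistralai_api_key=os.environ["MISTRALAI_API_KEY"],  # assumed env var name
    temperature=0.0,
    max_tokens=200,
)

# Plain Message objects suffice: __call__ now converts them to mistralai
# ChatMessage instances internally, so no to_mistral() helper is needed
# on the Message model.
reply = llm([Message(role="user", content="What is semantic routing?")])
print(reply)
```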