"""This file contains the MistralEncoder class which is used to encode text using MistralAI"""
import os
from time import sleep
from typing import List, Optional, Any


from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr


class MistralEncoder(BaseEncoder):
    """Class to encode text using MistralAI"""

    _client: Any = PrivateAttr()
    _mistralai: Any = PrivateAttr()
    type: str = "mistral"

    def __init__(
        self,
        name: Optional[str] = None,
        mistralai_api_key: Optional[str] = None,
        score_threshold: float = 0.82,
    ):
        if name is None:
            name = EncoderDefault.MISTRAL.value["embedding_model"]
        super().__init__(name=name, score_threshold=score_threshold)
           self._client, self._mistralai = self._initialize_client(mistralai_api_key)

    def _initialize_client(self, api_key):
        try:
            import mistralai
            from mistralai.client import MistralClient
        except ImportError:
            raise ImportError(
                "Please install MistralAI to use MistralEncoder. "
                "You can install it with: "
                "`pip install 'semantic-router[mistralai]'`"
            )

        api_key = api_key or os.getenv("MISTRALAI_API_KEY")
        if api_key is None:
            raise ValueError("Mistral API key not provided")
        try:
            client = MistralClient(api_key=api_key)
        except Exception as e:
            raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
        return client, mistralai

    def __call__(self, docs: List[str]) -> List[List[float]]:
        if self._client is None:
            raise ValueError("Mistral client not initialized")
        embeds = None
        error_message = ""

        # Exponential backoff
        for _ in range(3):
            try:
                embeds = self._client.embeddings(model=self.name, input=docs)
                if embeds.data:
                    break
            except self._mistralai.exceptions.MistralException as e:
                sleep(2**_)
                error_message = str(e)
            except Exception as e:
                raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

        if (
            not embeds
            or not isinstance(embeds, self._embedding_response)
            or not embeds.data
        ):
            raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
        return embeddings