from typing import Any, List, Optional, Union, TypeAlias
from pydantic.v1 import PrivateAttr
from semantic_router.encoders import BaseEncoder

# Annotation alias for image inputs. "Image" is a string forward reference
# because the PIL class is only imported below (and may not be importable);
# the `Any` member means static type checkers accept any object here.
PILImage: TypeAlias = Union[Any, "Image"]
# Pillow is optional at import time: when installed this binds the
# `PIL.Image.Image` class (not the `PIL.Image` module); otherwise the name
# `Image` is simply left undefined.
try:
    from PIL.Image import Image
except ImportError:
    pass


class VitEncoder(BaseEncoder):
    """Image encoder backed by a Hugging Face ViT model.

    Each image is converted to RGB, preprocessed with ``ViTImageProcessor``
    and run through ``ViTModel``; the embedding of an image is the first
    token of the model's ``last_hidden_state``.
    """

    name: str = "google/vit-base-patch16-224"
    type: str = "huggingface"
    score_threshold: float = 0.5
    # Extra keyword arguments forwarded to ViTImageProcessor.from_pretrained.
    processor_kwargs: dict = {}
    # Extra keyword arguments forwarded to ViTModel.from_pretrained.
    model_kwargs: dict = {}
    # Target device; when unset it is resolved to "cuda" if available, else "cpu".
    device: Optional[str] = None
    _processor: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _torch: Any = PrivateAttr()
    _T: Any = PrivateAttr()
    # The `PIL.Image` *module*, needed for `Image.new`. The module-level name
    # `Image` is the `PIL.Image.Image` class, which has no `new` attribute.
    _Image: Any = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._processor, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
        """Import optional dependencies and load the ViT processor and model.

        Also resolves ``self.device`` when it was not provided and moves the
        model onto that device.

        Returns:
            A ``(processor, model)`` tuple.

        Raises:
            ImportError: If transformers, torch/torchvision or Pillow is not
                installed.
        """
        try:
            from transformers import ViTImageProcessor, ViTModel
        except ImportError as err:
            raise ImportError(
                "Please install transformers to use VitEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            ) from err

        try:
            import torch
            import torchvision.transforms as T
        except ImportError as err:
            raise ImportError(
                "Please install Pytorch to use VitEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            ) from err

        try:
            from PIL import Image as pil_image_module
        except ImportError as err:
            raise ImportError(
                "Please install Pillow to use VitEncoder. "
                "You can install it with: "
                "`pip install pillow`"
            ) from err

        self._torch = torch
        self._T = T  # stored but not used within this class
        self._Image = pil_image_module

        processor = ViTImageProcessor.from_pretrained(
            self.name, **self.processor_kwargs
        )
        model = ViTModel.from_pretrained(self.name, **self.model_kwargs)

        # A falsy device (None or "") means "pick automatically".
        if not self.device:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(self.device)

        return processor, model

    def _process_images(self, images: List[PILImage]):
        """Convert ``images`` to RGB and preprocess them into tensors on ``self.device``."""
        rgb_images = [self._ensure_rgb(img) for img in images]
        processed_images = self._processor(images=rgb_images, return_tensors="pt")
        return processed_images.to(self.device)

    def _ensure_rgb(self, img: PILImage):
        """Return a new RGB image of the same size with ``img`` pasted into it.

        No mask is passed to ``paste``, so an alpha channel in ``img`` is not
        used for blending.
        """
        rgb_img = self._Image.new("RGB", img.size)
        rgb_img.paste(img)
        return rgb_img

    def __call__(
        self,
        imgs: List[PILImage],
        batch_size: int = 32,
    ) -> List[List[float]]:
        """Embed ``imgs`` in batches of at most ``batch_size`` images.

        Args:
            imgs: Images to encode.
            batch_size: Maximum number of images per forward pass; must be >= 1.

        Returns:
            One embedding (a list of floats) per input image, in input order.
            An empty ``imgs`` yields an empty list.

        Raises:
            ValueError: If ``batch_size`` is less than 1 (a negative value
                would otherwise silently produce no embeddings).
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        all_embeddings: List[List[float]] = []
        for start in range(0, len(imgs), batch_size):
            batch_inputs = self._process_images(imgs[start : start + batch_size])
            with self._torch.no_grad():
                outputs = self._model(**batch_inputs)
                # First token of each sequence (the [CLS] token prepended by
                # ViTModel) is used as the image embedding.
                embeddings = outputs.last_hidden_state[:, 0].cpu().tolist()
            all_embeddings.extend(embeddings)
        return all_embeddings