import os
import numpy as np
import pytest
import torch
from PIL import Image
from unittest.mock import patch

from semantic_router.encoders import CLIPEncoder

# Small CLIP test checkpoint used throughout these tests; it emits 64-dimensional embeddings.
test_model_name = "aurelio-ai/sr-test-clip"
embed_dim = 64

# Prefer CUDA, then Apple MPS, otherwise fall back to CPU; the encoder is
# expected to end up on the same device.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


@pytest.fixture()
def dummy_pil_image():
    # Non-square RGB image; rand() is scaled to 0-255 so the pixels are not all zero.
    return Image.fromarray((np.random.rand(512, 224, 3) * 255).astype(np.uint8))


@pytest.fixture()
def dummy_black_and_white_img():
    # Two-channel array becomes a PIL "LA" (luminance + alpha) image, i.e. not RGB.
    return Image.fromarray((np.random.rand(224, 224, 2) * 255).astype(np.uint8))


@pytest.fixture()
def misshaped_pil_image():
    # RGB image smaller than the encoder's expected input resolution.
    return Image.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))


class TestClipEncoder:
    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder__import_errors_transformers(self):
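        # Setting the module to None in sys.modules makes the encoder's import of
        # transformers fail, so the constructor should raise a helpful ImportError.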
        with patch.dict("sys.modules", {"transformers": None}):
            with pytest.raises(ImportError) as error:
                CLIPEncoder()

        assert "install transformers" in str(error.value)

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder__import_errors_torch(self):
        with patch.dict("sys.modules", {"torch": None}):
            with pytest.raises(ImportError) as error:
                CLIPEncoder()

        assert "install Pytorch" in str(error.value)

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder_initialization(self):
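        # The encoder should expose its defaults and pick the same device as the
        # module-level selection above.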
        clip_encoder = CLIPEncoder(name=test_model_name)
        assert clip_encoder.name == test_model_name
        assert clip_encoder.type == "huggingface"
        assert clip_encoder.score_threshold == 0.2
        assert clip_encoder.device == device

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder_call_text(self):
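        # Text inputs: one embedding of length embed_dim per input string.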
        clip_encoder = CLIPEncoder(name=test_model_name)
        embeddings = clip_encoder(["hello", "world"])

        assert len(embeddings) == 2
        assert len(embeddings[0]) == embed_dim

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder_call_image(self, dummy_pil_image):
        clip_encoder = CLIPEncoder(name=test_model_name)
        encoded_images = clip_encoder([dummy_pil_image] * 3)

        assert len(encoded_images) == 3
        assert set(map(len, encoded_images)) == {embed_dim}

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
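        # Mixed image sizes should still work; the underlying processor is expected
        # to resize inputs, so every embedding keeps the same length.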
        clip_encoder = CLIPEncoder(name=test_model_name)
        encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image])

        assert len(encoded_images) == 2
        assert set(map(len, encoded_images)) == {embed_dim}

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_device(self):
        clip_encoder = CLIPEncoder(name=test_model_name)
        # The underlying model should live on the device selected at module level.
        model_device = clip_encoder._model.device.type
        assert model_device == device

    @pytest.mark.skipif(
        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
    )
    def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img):
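        # _ensure_rgb should convert the two-channel (LA) image into a true RGB image.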
        clip_encoder = CLIPEncoder(name=test_model_name)
        rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img)

        assert rgb_image.mode == "RGB"
        assert np.array(rgb_image).shape == (224, 224, 3)