from typing import Any, List, Optional

import numpy as np
from pydantic import BaseModel, PrivateAttr


class FastEmbedEncoder(BaseModel):
    """Encode text documents into dense vectors using the ``fastembed`` library.

    The underlying fastembed embedding object is created eagerly in
    ``__init__`` from the fields below. Constructing an instance therefore
    requires ``fastembed`` to be importable, and whatever setup fastembed's
    constructor performs happens at construction time.

    Calling the instance with a list of strings returns a list of embeddings,
    each a plain ``list`` of floats.
    """

    # Identifier for this encoder kind; presumably consumed by code that
    # selects or serializes encoders — confirm against callers.
    type: str = "fastembed"
    # Model identifier forwarded to fastembed as ``model_name``.
    model_name: str = "BAAI/bge-small-en-v1.5"
    # Forwarded to fastembed as ``max_length`` (presumably the maximum input
    # length in tokens — confirm against the fastembed documentation).
    max_length: int = 512
    # Forwarded as ``cache_dir`` only when set; when None the argument is
    # omitted so fastembed's own default applies.
    cache_dir: Optional[str] = None
    # Forwarded as ``threads`` only when set; when None the argument is
    # omitted so fastembed's own default applies.
    threads: Optional[int] = None
    # fastembed embedding object built by ``_initialize_client``. It is a
    # pydantic private attribute, so it is not a validated/serialized field.
    _client: Any = PrivateAttr()

    def __init__(self, **data: Any) -> None:
        """Validate fields via pydantic, then build the fastembed client.

        Raises:
            ImportError: If ``fastembed`` cannot be imported.
        """
        super().__init__(**data)
        self._client = self._initialize_client()

    def _initialize_client(self) -> Any:
        """Construct the fastembed embedding object from this encoder's fields.

        Only options whose value is not None are passed, so unset optional
        fields fall back to fastembed's defaults.

        Returns:
            The ``fastembed.embedding.FlagEmbedding`` instance.

        Raises:
            ImportError: If ``fastembed`` cannot be imported; the original
                import error is chained as ``__cause__``.
        """
        # Imported lazily so this module can be imported without fastembed.
        # NOTE(review): newer fastembed releases may expose a different primary
        # class (e.g. ``fastembed.TextEmbedding``); confirm the installed
        # version still provides ``fastembed.embedding.FlagEmbedding``.
        try:
            from fastembed.embedding import FlagEmbedding as Embedding
        except ImportError as err:
            raise ImportError(
                "Please install fastembed to use FastEmbedEncoder. "
                "You can install it with: `pip install fastembed`"
            ) from err

        embedding_args = {
            "model_name": self.model_name,
            "max_length": self.max_length,
            "cache_dir": self.cache_dir,
            "threads": self.threads,
        }
        # Drop unset options rather than passing None explicitly.
        embedding_args = {k: v for k, v in embedding_args.items() if v is not None}

        return Embedding(**embedding_args)

    def __call__(self, docs: list[str]) -> list[list[float]]:
        """Embed ``docs`` and return each embedding as a list of floats.

        Args:
            docs: Texts to embed.

        Returns:
            One list of floats per item yielded by the client's ``embed``.
            The result is presumably in the same order as ``docs``; that
            depends on fastembed's ``embed`` — confirm.

        Raises:
            ValueError: If anything raises while embedding or converting. The
                original exception is chained as ``__cause__``.
        """
        try:
            # ``embed`` is consumed fully here; each item is converted with
            # ``tolist()`` (assumed to be a numpy array, per the annotation).
            embeds: List[np.ndarray] = list(self._client.embed(docs))
            return [e.tolist() for e in embeds]
        except Exception as e:
            raise ValueError(f"FastEmbed embed failed. Error: {e}") from e