Code owners
Assign users and groups as approvers for specific file changes. Learn more.
fastembed.py 1.46 KiB
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel, PrivateAttr
class FastEmbedEncoder(BaseModel):
    """Encode text documents into dense vectors using the ``fastembed`` library.

    The fastembed embedding object is created eagerly in ``__init__``, so
    constructing an encoder requires ``fastembed`` to be importable.

    Attributes:
        type: Encoder identifier; defaults to ``"fastembed"``.
        model_name: Passed as ``model_name`` to the fastembed embedding
            constructor.
        max_length: Passed as ``max_length`` to the embedding constructor.
        cache_dir: Passed as ``cache_dir`` to the embedding constructor;
            omitted entirely when ``None``.
        threads: Passed as ``threads`` to the embedding constructor;
            omitted entirely when ``None``.
    """

    type: str = "fastembed"
    model_name: str = "BAAI/bge-small-en-v1.5"
    max_length: int = 512
    cache_dir: Optional[str] = None
    threads: Optional[int] = None
    # Pydantic private attribute (not a model field): holds the fastembed client.
    _client: Any = PrivateAttr()

    def __init__(self, **data):
        """Validate fields via pydantic, then build the fastembed client."""
        super().__init__(**data)
        self._client = self._initialize_client()

    def _initialize_client(self):
        """Construct the fastembed embedding object from this encoder's fields.

        Returns:
            The ``fastembed.embedding.FlagEmbedding`` instance.

        Raises:
            ImportError: If ``fastembed`` cannot be imported; the original
                import error is chained as ``__cause__``.
        """
        # Imported inside the method so fastembed is only required when an
        # encoder is actually constructed.
        try:
            from fastembed.embedding import FlagEmbedding as Embedding
        except ImportError as err:
            raise ImportError(
                "Please install fastembed to use FastEmbedEncoder. "
                "You can install it with: `pip install fastembed`"
            ) from err

        embedding_args = {
            "model_name": self.model_name,
            "max_length": self.max_length,
            "cache_dir": self.cache_dir,
            "threads": self.threads,
        }
        # Drop unset options so the constructor's own defaults apply.
        embedding_args = {k: v for k, v in embedding_args.items() if v is not None}
        return Embedding(**embedding_args)

    def __call__(self, docs: list[str]) -> list[list[float]]:
        """Embed ``docs`` and return the vectors as plain Python float lists.

        Args:
            docs: Texts to embed.

        Returns:
            One ``list[float]`` per vector yielded by the client, in order.

        Raises:
            ValueError: If embedding or conversion fails for any reason; the
                original exception is chained as ``__cause__``.
        """
        try:
            # The client's ``embed`` is assumed to yield numpy arrays, each
            # converted to a nested Python list via ``tolist()``.
            embeddings: List[List[float]] = [
                vector.tolist() for vector in self._client.embed(docs)
            ]
        except Exception as e:
            raise ValueError(f"FastEmbed embed failed. Error: {e}") from e
        return embeddings