Code owners
Assign users and groups as approvers for specific file changes. Learn more.
vit.py 2.91 KiB
from typing import Any, List, Optional, Union, TypeAlias
from pydantic.v1 import PrivateAttr
from semantic_router.encoders import BaseEncoder
PILImage: TypeAlias = Union[Any, "Image"]
try:
from PIL.Image import Image
except ImportError:
pass
class VitEncoder(BaseEncoder):
    """Image encoder backed by a Hugging Face ViT model.

    Calling the encoder turns a list of images into one embedding vector per
    image. Each embedding is taken from position 0 of the sequence dimension
    of the model's ``last_hidden_state``. For HF ``ViTModel`` this is
    presumably the [CLS] token.

    The heavy dependencies (``transformers``, ``torch``, ``torchvision`` and
    Pillow) are imported inside methods, so this module can be imported
    without them.
    """

    name: str = "google/vit-base-patch16-224"
    type: str = "huggingface"
    score_threshold: float = 0.5
    # Extra keyword arguments forwarded to ``ViTImageProcessor.from_pretrained``.
    processor_kwargs: dict = {}
    # Extra keyword arguments forwarded to ``ViTModel.from_pretrained``.
    model_kwargs: dict = {}
    # Target device. When unset (or empty), it is resolved during
    # initialisation to "cuda" if available, else "cpu".
    device: Optional[str] = None
    _processor: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    # The imported ``torch`` module (used for ``no_grad`` in ``__call__``).
    _torch: Any = PrivateAttr()
    # The imported ``torchvision.transforms`` module; not used in this class.
    _T: Any = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._processor, self._model = self._initialize_hf_model()

    def _initialize_hf_model(self):
        """Import the optional dependencies and load the processor and model.

        Side effects:
            Sets ``self._torch`` and ``self._T``. Sets ``self.device`` when
            it was not provided.

        Returns:
            A ``(processor, model)`` tuple. The model has already been moved
            to ``self.device``.

        Raises:
            ImportError: if ``transformers``, ``torch`` or ``torchvision``
                is not installed.
        """
        try:
            from transformers import ViTImageProcessor, ViTModel
        except ImportError as err:
            raise ImportError(
                "Please install transformers to use VitEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            ) from err

        try:
            import torch
            import torchvision.transforms as T
        except ImportError as err:
            raise ImportError(
                "Please install Pytorch to use VitEncoder. "
                "You can install it with: "
                "`pip install semantic-router[local]`"
            ) from err

        self._torch = torch
        self._T = T

        processor = ViTImageProcessor.from_pretrained(
            self.name, **self.processor_kwargs
        )
        model = ViTModel.from_pretrained(self.name, **self.model_kwargs)

        # Resolve the device once and record it. _process_images moves the
        # inputs to self.device, so inputs and model end up on the same device.
        if not self.device:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(self.device)

        return processor, model

    def _process_images(self, images: List[PILImage]):
        """Convert ``images`` to RGB and run them through the image processor.

        Returns:
            The processor output, requested as PyTorch tensors
            (``return_tensors="pt"``) and moved to ``self.device``.
        """
        rgb_images = [self._ensure_rgb(img) for img in images]
        processed_images = self._processor(images=rgb_images, return_tensors="pt")
        return processed_images.to(self.device)

    def _ensure_rgb(self, img: PILImage):
        """Return an RGB copy of ``img`` by pasting it onto a new RGB canvas.

        ``new`` is a function of the ``PIL.Image`` *module*. The
        ``PIL.Image.Image`` class has no ``new`` attribute, so the module is
        imported here and used instead.

        Raises:
            ImportError: if Pillow is not installed.
        """
        try:
            from PIL import Image as pil_image
        except ImportError as err:
            raise ImportError(
                "Please install Pillow to use VitEncoder. "
                "You can install it with: "
                "`pip install pillow`"
            ) from err
        rgbimg = pil_image.new("RGB", img.size)
        # paste() converts img to the canvas mode (RGB) when the modes differ.
        rgbimg.paste(img)
        return rgbimg

    def __call__(
        self,
        imgs: List[PILImage],
        batch_size: int = 32,
    ) -> List[List[float]]:
        """Embed ``imgs`` in batches of at most ``batch_size`` images.

        Args:
            imgs: Images to embed.
            batch_size: Maximum number of images per forward pass; must be >= 1.

        Returns:
            One embedding (a list of floats) per input image, in input order.
            An empty ``imgs`` yields an empty list.

        Raises:
            ValueError: if ``batch_size`` is less than 1. A negative value
                would otherwise silently produce no embeddings.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_embeddings: List[List[float]] = []
        for start in range(0, len(imgs), batch_size):
            batch_inputs = self._process_images(imgs[start : start + batch_size])
            with self._torch.no_grad():
                outputs = self._model(**batch_inputs)
                # Position 0 of the sequence dimension for every image in the
                # batch (presumably the [CLS] token for HF ViT models).
                embeddings = outputs.last_hidden_state[:, 0].cpu().tolist()
            all_embeddings.extend(embeddings)
        return all_embeddings