Skip to content
Snippets Groups Projects
Commit fc6ee50d authored by Simonas's avatar Simonas
Browse files

fix: Vit. Image type

parent cd6c924a
No related branches found
No related tags found
No related merge requests found
from typing import Any, List, Optional
from PIL import Image
from PIL.Image import Image as _Image
from typing import Any, List, Optional, Union, TypeAlias
from pydantic.v1 import PrivateAttr
from semantic_router.encoders import BaseEncoder
PILImage: TypeAlias = Union[Any, "Image"]
try:
from PIL.Image import Image
except ImportError:
pass
class VitEncoder(BaseEncoder):
name: str = "google/vit-base-patch16-224"
......@@ -62,20 +64,20 @@ class VitEncoder(BaseEncoder):
return processor, model
def _process_images(self, images: List[_Image]):
def _process_images(self, images: List[PILImage]):
rgb_images = [self._ensure_rgb(img) for img in images]
processed_images = self._processor(images=rgb_images, return_tensors="pt")
processed_images = processed_images.to(self.device)
return processed_images
def _ensure_rgb(self, img: _Image):
def _ensure_rgb(self, img: PILImage):
rgbimg = Image.new("RGB", img.size)
rgbimg.paste(img)
return rgbimg
def __call__(
self,
imgs: List[_Image],
imgs: List[PILImage],
batch_size: int = 32,
) -> List[List[float]]:
all_embeddings = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment