Skip to content
Snippets Groups Projects
Commit fc6ee50d authored by Simonas's avatar Simonas
Browse files

fix: Vit. Image type

parent cd6c924a
No related branches found
No related tags found
No related merge requests found
from typing import Any, List, Optional from typing import Any, List, Optional, Union, TypeAlias
from PIL import Image
from PIL.Image import Image as _Image
from pydantic.v1 import PrivateAttr from pydantic.v1 import PrivateAttr
from semantic_router.encoders import BaseEncoder from semantic_router.encoders import BaseEncoder
PILImage: TypeAlias = Union[Any, "Image"]
try:
from PIL.Image import Image
except ImportError:
pass
class VitEncoder(BaseEncoder): class VitEncoder(BaseEncoder):
name: str = "google/vit-base-patch16-224" name: str = "google/vit-base-patch16-224"
...@@ -62,20 +64,20 @@ class VitEncoder(BaseEncoder): ...@@ -62,20 +64,20 @@ class VitEncoder(BaseEncoder):
return processor, model return processor, model
def _process_images(self, images: List[_Image]): def _process_images(self, images: List[PILImage]):
rgb_images = [self._ensure_rgb(img) for img in images] rgb_images = [self._ensure_rgb(img) for img in images]
processed_images = self._processor(images=rgb_images, return_tensors="pt") processed_images = self._processor(images=rgb_images, return_tensors="pt")
processed_images = processed_images.to(self.device) processed_images = processed_images.to(self.device)
return processed_images return processed_images
def _ensure_rgb(self, img: _Image): def _ensure_rgb(self, img: PILImage):
rgbimg = Image.new("RGB", img.size) rgbimg = Image.new("RGB", img.size)
rgbimg.paste(img) rgbimg.paste(img)
return rgbimg return rgbimg
def __call__( def __call__(
self, self,
imgs: List[_Image], imgs: List[PILImage],
batch_size: int = 32, batch_size: int = 32,
) -> List[List[float]]: ) -> List[List[float]]:
all_embeddings = [] all_embeddings = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment