From fc6ee50dccd8fb8063b62c549e94c6af61d4ff73 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Tue, 27 Feb 2024 19:37:53 +0200 Subject: [PATCH] fix: Vit. Image type --- semantic_router/encoders/vit.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index d2f28dda..3c50cff9 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -1,11 +1,13 @@ -from typing import Any, List, Optional - -from PIL import Image -from PIL.Image import Image as _Image +from typing import Any, List, Optional, Union, TypeAlias from pydantic.v1 import PrivateAttr - from semantic_router.encoders import BaseEncoder +PILImage: TypeAlias = Union[Any, "Image"] +try: + from PIL.Image import Image +except ImportError: + pass + class VitEncoder(BaseEncoder): name: str = "google/vit-base-patch16-224" @@ -62,20 +64,20 @@ class VitEncoder(BaseEncoder): return processor, model - def _process_images(self, images: List[_Image]): + def _process_images(self, images: List[PILImage]): rgb_images = [self._ensure_rgb(img) for img in images] processed_images = self._processor(images=rgb_images, return_tensors="pt") processed_images = processed_images.to(self.device) return processed_images - def _ensure_rgb(self, img: _Image): + def _ensure_rgb(self, img: PILImage): rgbimg = Image.new("RGB", img.size) rgbimg.paste(img) return rgbimg def __call__( self, - imgs: List[_Image], + imgs: List[PILImage], batch_size: int = 32, ) -> List[List[float]]: all_embeddings = [] -- GitLab