diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index d2f28ddad44e240bd46716252b735939428967b4..3c50cff989bd3d3f1c35f6e4d044a5a31c4a4cfd 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -1,11 +1,13 @@ -from typing import Any, List, Optional - -from PIL import Image -from PIL.Image import Image as _Image +from typing import Any, List, Optional, Union, TypeAlias from pydantic.v1 import PrivateAttr - from semantic_router.encoders import BaseEncoder +PILImage: TypeAlias = Union[Any, "Image"] +try: + from PIL.Image import Image +except ImportError: + pass + class VitEncoder(BaseEncoder): name: str = "google/vit-base-patch16-224" @@ -62,20 +64,20 @@ class VitEncoder(BaseEncoder): return processor, model - def _process_images(self, images: List[_Image]): + def _process_images(self, images: List[PILImage]): rgb_images = [self._ensure_rgb(img) for img in images] processed_images = self._processor(images=rgb_images, return_tensors="pt") processed_images = processed_images.to(self.device) return processed_images - def _ensure_rgb(self, img: _Image): + def _ensure_rgb(self, img: PILImage): rgbimg = Image.new("RGB", img.size) rgbimg.paste(img) return rgbimg def __call__( self, - imgs: List[_Image], + imgs: List[PILImage], batch_size: int = 32, ) -> List[List[float]]: all_embeddings = []