From 037b59f9978671991fcb0814fe242d2dd47818d2 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Tue, 20 Feb 2024 01:24:11 +0400
Subject: [PATCH] cleanup and remove pil from initial imports

---
 semantic_router/encoders/__init__.py |  2 +
 semantic_router/encoders/clip.py     | 29 +++++++++---
 semantic_router/encoders/vit.py      | 25 ++++----
 tests/unit/encoders/test_clip.py     | 68 ++++++++++++++++++++++++++++
 4 files changed, 110 insertions(+), 14 deletions(-)
 create mode 100644 tests/unit/encoders/test_clip.py

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index 9e74bb1a..d351028e 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -1,5 +1,6 @@
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.bm25 import BM25Encoder
+from semantic_router.encoders.clip import CLIPEncoder
 from semantic_router.encoders.cohere import CohereEncoder
 from semantic_router.encoders.fastembed import FastEmbedEncoder
 from semantic_router.encoders.huggingface import HuggingFaceEncoder
@@ -20,4 +21,5 @@ __all__ = [
     "HuggingFaceEncoder",
     "MistralEncoder",
     "VitEncoder",
+    "CLIPEncoder",
 ]
diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py
index 029ca614..c1b87779 100644
--- a/semantic_router/encoders/clip.py
+++ b/semantic_router/encoders/clip.py
@@ -18,6 +18,7 @@ class CLIPEncoder(BaseEncoder):
     _processor: Any = PrivateAttr()
     _model: Any = PrivateAttr()
     _torch: Any = PrivateAttr()
+    _Image: Any = PrivateAttr()
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -37,9 +38,9 @@ class CLIPEncoder(BaseEncoder):
         for i in range(0, len(docs), batch_size):
             batch_docs = docs[i : i + batch_size]
             if text:
-                embeddings = self._encode_text(batch_docs)
+                embeddings = self._encode_text(docs=batch_docs)
             else:
-                embeddings = self._encode_image(batch_docs)
+                embeddings = self._encode_image(images=batch_docs)
 
             if normalize_embeddings:
                 embeddings = embeddings / np.linalg.norm(embeddings, axis=0)
@@ -54,7 +55,7 @@ class CLIPEncoder(BaseEncoder):
             raise ImportError(
                 "Please install transformers to use CLIPEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
             )
 
         try:
@@ -63,10 +64,20 @@ class CLIPEncoder(BaseEncoder):
             raise ImportError(
                 "Please install Pytorch to use CLIPEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
+            )
+
+        try:
+            from PIL import Image
+        except ImportError:
+            raise ImportError(
+                "Please install PIL to use CLIPEncoder. "
+                "You can install it with: "
+                "`pip install semantic-router[vision]`"
             )
 
         self._torch = torch
+        self._Image = Image
 
         tokenizer = CLIPTokenizerFast.from_pretrained(
             self.name,
@@ -94,11 +105,17 @@ class CLIPEncoder(BaseEncoder):
         embeds = embeds.squeeze(0).cpu().detach().numpy()
         return embeds
 
-    def _encode_image(self, docs: List[Any]) -> Any:
-        inputs = self._processor(text=None, images=docs, return_tensors="pt")[
+    def _encode_image(self, images: List[Any]) -> Any:
+        rgb_images = [self._ensure_rgb(img) for img in images]
+        inputs = self._processor(text=None, images=rgb_images, return_tensors="pt")[
             "pixel_values"
         ].to(self.device)
         with self._torch.no_grad():
             embeds = self._model.get_image_features(pixel_values=inputs)
         embeds = embeds.squeeze(0).cpu().detach().numpy()
         return embeds
+
+    def _ensure_rgb(self, img: Any):
+        rgbimg = self._Image.new("RGB", img.size)
+        rgbimg.paste(img)
+        return rgbimg
diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py
index d2f28dda..9aba6257 100644
--- a/semantic_router/encoders/vit.py
+++ b/semantic_router/encoders/vit.py
@@ -1,7 +1,5 @@
 from typing import Any, List, Optional
 
-from PIL import Image
-from PIL.Image import Image as _Image
 from pydantic.v1 import PrivateAttr
 
 from semantic_router.encoders import BaseEncoder
@@ -18,6 +16,7 @@ class VitEncoder(BaseEncoder):
     _model: Any = PrivateAttr()
     _torch: Any = PrivateAttr()
     _T: Any = PrivateAttr()
+    _Image: Any = PrivateAttr()
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -30,7 +29,7 @@ class VitEncoder(BaseEncoder):
             raise ImportError(
                 "Please install transformers to use HuggingFaceEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
             )
 
         try:
@@ -40,10 +39,20 @@ class VitEncoder(BaseEncoder):
             raise ImportError(
                 "Please install Pytorch to use HuggingFaceEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
+            )
+
+        try:
+            from PIL import Image
+        except ImportError:
+            raise ImportError(
+                "Please install PIL to use VitEncoder. "
+                "You can install it with: "
+                "`pip install semantic-router[vision]`"
             )
 
         self._torch = torch
+        self._Image = Image
         self._T = T
 
         processor = ViTImageProcessor.from_pretrained(
@@ -62,20 +71,20 @@ class VitEncoder(BaseEncoder):
 
         return processor, model
 
-    def _process_images(self, images: List[_Image]):
+    def _process_images(self, images: List[Any]):
         rgb_images = [self._ensure_rgb(img) for img in images]
         processed_images = self._processor(images=rgb_images, return_tensors="pt")
         processed_images = processed_images.to(self.device)
         return processed_images
 
-    def _ensure_rgb(self, img: _Image):
-        rgbimg = Image.new("RGB", img.size)
+    def _ensure_rgb(self, img: Any):
+        rgbimg = self._Image.new("RGB", img.size)
         rgbimg.paste(img)
         return rgbimg
 
     def __call__(
         self,
-        imgs: List[_Image],
+        imgs: List[Any],
         batch_size: int = 32,
     ) -> List[List[float]]:
         all_embeddings = []
diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py
new file mode 100644
index 00000000..f861f1ee
--- /dev/null
+++ b/tests/unit/encoders/test_clip.py
@@ -0,0 +1,68 @@
+import numpy as np
+import pytest
+from PIL import Image
+
+from semantic_router.encoders import CLIPEncoder
+
+clip_encoder = CLIPEncoder()
+
+
+@pytest.fixture()
+def dummy_pil_image():
+    return Image.fromarray(np.random.rand(1024, 512, 3).astype(np.uint8))
+
+
+@pytest.fixture()
+def dummy_black_and_white_img():
+    return Image.fromarray(np.random.rand(224, 224, 2).astype(np.uint8))
+
+
+@pytest.fixture()
+def misshaped_pil_image():
+    return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8))
+
+
+class TestClipEncoder:
+    def test_clip_encoder__import_errors_transformers(self, mocker):
+        mocker.patch.dict("sys.modules", {"transformers": None})
+        with pytest.raises(ImportError):
+            CLIPEncoder()
+
+    def test_clip_encoder__import_errors_torch(self, mocker):
+        mocker.patch.dict("sys.modules", {"torch": None})
+        with pytest.raises(ImportError):
+            CLIPEncoder()
+
+    def test_clip_encoder_initialization(self):
+        assert clip_encoder.name == "openai/clip-vit-base-patch32"
+        assert clip_encoder.type == "huggingface"
+        assert clip_encoder.score_threshold == 0.2
+        assert clip_encoder.device == "cpu"
+
+    def test_clip_encoder_call_text(self):
+        embeddings = clip_encoder(["hello", "world"])
+
+        assert len(embeddings) == 2
+        assert len(embeddings[0]) == 512
+
+    def test_clip_encoder_call_image(self, dummy_pil_image):
+        encoded_images = clip_encoder([dummy_pil_image] * 3)
+
+        assert len(encoded_images) == 3
+        assert set(map(len, encoded_images)) == {512}
+
+    def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
+        encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image])
+
+        assert len(encoded_images) == 2
+        assert set(map(len, encoded_images)) == {512}
+
+    def test_clip_device(self):
+        device = clip_encoder._model.device.type
+        assert device == "cpu"
+
+    def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img):
+        rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img)
+
+        assert rgb_image.mode == "RGB"
+        assert np.array(rgb_image).shape == (224, 224, 3)
-- 
GitLab