From 037b59f9978671991fcb0814fe242d2dd47818d2 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Tue, 20 Feb 2024 01:24:11 +0400
Subject: [PATCH] cleanup and remove pil from initial imports

---
 semantic_router/encoders/__init__.py |  2 +
 semantic_router/encoders/clip.py     | 29 +++++++++---
 semantic_router/encoders/vit.py      | 25 ++++++----
 tests/unit/encoders/test_clip.py     | 68 ++++++++++++++++++++++++++++
 4 files changed, 110 insertions(+), 14 deletions(-)
 create mode 100644 tests/unit/encoders/test_clip.py

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index 9e74bb1a..d351028e 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -1,5 +1,6 @@
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.bm25 import BM25Encoder
+from semantic_router.encoders.clip import CLIPEncoder
 from semantic_router.encoders.cohere import CohereEncoder
 from semantic_router.encoders.fastembed import FastEmbedEncoder
 from semantic_router.encoders.huggingface import HuggingFaceEncoder
@@ -20,4 +21,5 @@ __all__ = [
     "HuggingFaceEncoder",
     "MistralEncoder",
     "VitEncoder",
+    "CLIPEncoder",
 ]
diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py
index 029ca614..c1b87779 100644
--- a/semantic_router/encoders/clip.py
+++ b/semantic_router/encoders/clip.py
@@ -18,6 +18,7 @@ class CLIPEncoder(BaseEncoder):
     _processor: Any = PrivateAttr()
     _model: Any = PrivateAttr()
     _torch: Any = PrivateAttr()
+    _Image: Any = PrivateAttr()
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -37,9 +38,9 @@ class CLIPEncoder(BaseEncoder):
         for i in range(0, len(docs), batch_size):
             batch_docs = docs[i : i + batch_size]
             if text:
-                embeddings = self._encode_text(batch_docs)
+                embeddings = self._encode_text(docs=batch_docs)
             else:
-                embeddings = self._encode_image(batch_docs)
+                embeddings = self._encode_image(images=batch_docs)
             if normalize_embeddings:
                 embeddings = embeddings / np.linalg.norm(embeddings, axis=0)
 
@@ -54,7 +55,7 @@ class CLIPEncoder(BaseEncoder):
             raise ImportError(
                 "Please install transformers to use CLIPEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
             )
 
         try:
@@ -63,10 +64,20 @@ class CLIPEncoder(BaseEncoder):
             raise ImportError(
                 "Please install Pytorch to use CLIPEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
+            )
+
+        try:
+            from PIL import Image
+        except ImportError:
+            raise ImportError(
+                "Please install PIL to use HuggingFaceEncoder. "
+                "You can install it with: "
+                "`pip install semantic-router[vision]`"
             )
 
         self._torch = torch
+        self._Image = Image
 
         tokenizer = CLIPTokenizerFast.from_pretrained(
             self.name,
@@ -94,11 +105,17 @@ class CLIPEncoder(BaseEncoder):
             embeds = embeds.squeeze(0).cpu().detach().numpy()
         return embeds
 
-    def _encode_image(self, docs: List[Any]) -> Any:
-        inputs = self._processor(text=None, images=docs, return_tensors="pt")[
+    def _encode_image(self, images: List[Any]) -> Any:
+        rgb_images = [self._ensure_rgb(img) for img in images]
+        inputs = self._processor(text=None, images=rgb_images, return_tensors="pt")[
             "pixel_values"
         ].to(self.device)
         with self._torch.no_grad():
             embeds = self._model.get_image_features(pixel_values=inputs)
             embeds = embeds.squeeze(0).cpu().detach().numpy()
         return embeds
+
+    def _ensure_rgb(self, img: Any):
+        rgbimg = self._Image.new("RGB", img.size)
+        rgbimg.paste(img)
+        return rgbimg
diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py
index d2f28dda..9aba6257 100644
--- a/semantic_router/encoders/vit.py
+++ b/semantic_router/encoders/vit.py
@@ -1,7 +1,5 @@
 from typing import Any, List, Optional
 
-from PIL import Image
-from PIL.Image import Image as _Image
 from pydantic.v1 import PrivateAttr
 
 from semantic_router.encoders import BaseEncoder
@@ -18,6 +16,7 @@ class VitEncoder(BaseEncoder):
     _model: Any = PrivateAttr()
     _torch: Any = PrivateAttr()
     _T: Any = PrivateAttr()
+    _Image: Any = PrivateAttr()
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -30,7 +29,7 @@ class VitEncoder(BaseEncoder):
             raise ImportError(
                 "Please install transformers to use HuggingFaceEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
             )
 
         try:
@@ -40,10 +39,20 @@ class VitEncoder(BaseEncoder):
             raise ImportError(
                 "Please install Pytorch to use HuggingFaceEncoder. "
                 "You can install it with: "
-                "`pip install semantic-router[local]`"
+                "`pip install semantic-router[vision]`"
+            )
+
+        try:
+            from PIL import Image
+        except ImportError:
+            raise ImportError(
+                "Please install PIL to use HuggingFaceEncoder. "
+                "You can install it with: "
+                "`pip install semantic-router[vision]`"
             )
 
         self._torch = torch
+        self._Image = Image
         self._T = T
 
         processor = ViTImageProcessor.from_pretrained(
@@ -62,20 +71,20 @@ class VitEncoder(BaseEncoder):
 
         return processor, model
 
-    def _process_images(self, images: List[_Image]):
+    def _process_images(self, images: List[Any]):
         rgb_images = [self._ensure_rgb(img) for img in images]
         processed_images = self._processor(images=rgb_images, return_tensors="pt")
         processed_images = processed_images.to(self.device)
         return processed_images
 
-    def _ensure_rgb(self, img: _Image):
-        rgbimg = Image.new("RGB", img.size)
+    def _ensure_rgb(self, img: Any):
+        rgbimg = self._Image.new("RGB", img.size)
         rgbimg.paste(img)
         return rgbimg
 
     def __call__(
         self,
-        imgs: List[_Image],
+        imgs: List[Any],
         batch_size: int = 32,
     ) -> List[List[float]]:
         all_embeddings = []
diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py
new file mode 100644
index 00000000..f861f1ee
--- /dev/null
+++ b/tests/unit/encoders/test_clip.py
@@ -0,0 +1,68 @@
+import numpy as np
+import pytest
+from PIL import Image
+
+from semantic_router.encoders import CLIPEncoder
+
+clip_encoder = CLIPEncoder()
+
+
+@pytest.fixture()
+def dummy_pil_image():
+    return Image.fromarray(np.random.rand(1024, 512, 3).astype(np.uint8))
+
+
+@pytest.fixture()
+def dummy_black_and_white_img():
+    return Image.fromarray(np.random.rand(224, 224, 2).astype(np.uint8))
+
+
+@pytest.fixture()
+def misshaped_pil_image():
+    return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8))
+
+
+class TestClipEncoder:
+    def test_clip_encoder__import_errors_transformers(self, mocker):
+        mocker.patch.dict("sys.modules", {"transformers": None})
+        with pytest.raises(ImportError):
+            CLIPEncoder()
+
+    def test_clip_encoder__import_errors_torch(self, mocker):
+        mocker.patch.dict("sys.modules", {"torch": None})
+        with pytest.raises(ImportError):
+            CLIPEncoder()
+
+    def test_clip_encoder_initialization(self):
+        assert clip_encoder.name == "openai/clip-vit-base-patch32"
+        assert clip_encoder.type == "huggingface"
+        assert clip_encoder.score_threshold == 0.2
+        assert clip_encoder.device == "cpu"
+
+    def test_clip_encoder_call_text(self):
+        embeddings = clip_encoder(["hello", "world"])
+
+        assert len(embeddings) == 2
+        assert len(embeddings[0]) == 512
+
+    def test_clip_encoder_call_image(self, dummy_pil_image):
+        encoded_images = clip_encoder([dummy_pil_image] * 3)
+
+        assert len(encoded_images) == 3
+        assert set(map(len, encoded_images)) == {512}
+
+    def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
+        encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image])
+
+        assert len(encoded_images) == 2
+        assert set(map(len, encoded_images)) == {512}
+
+    def test_clip_device(self):
+        device = clip_encoder._model.device.type
+        assert device == "cpu"
+
+    def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img):
+        rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img)
+
+        assert rgb_image.mode == "RGB"
+        assert np.array(rgb_image).shape == (224, 224, 3)
-- 
GitLab
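
A minimal usage sketch of the CLIPEncoder added by this patch, mirroring the calls exercised in tests/unit/encoders/test_clip.py; it assumes the `vision` extra is installed and relies on the defaults asserted in those tests (openai/clip-vit-base-patch32 on CPU, 512-dimensional output).

import numpy as np
from PIL import Image

from semantic_router.encoders import CLIPEncoder

# Defaults per the tests: "openai/clip-vit-base-patch32" running on CPU.
encoder = CLIPEncoder()

# Text documents are embedded directly.
text_embeddings = encoder(["hello", "world"])
assert len(text_embeddings) == 2 and len(text_embeddings[0]) == 512

# PIL images are routed through _encode_image; non-RGB inputs are first
# converted to RGB by _ensure_rgb.
img = Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8))
image_embeddings = encoder([img, img])
assert len(image_embeddings) == 2 and len(image_embeddings[0]) == 512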