From 2ab6116bdcd613244b19a7077b4bcab3eccc85d5 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Sat, 2 Mar 2024 16:16:12 +0800
Subject: [PATCH] use tiny random models in ViT and CLIP encoder tests

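Point the ViT and CLIP encoder unit tests at the tiny random
checkpoints published under hf-internal-testing instead of the full
openai/clip-vit-base-patch16 and google/vit-base-patch16-224 models.
The tiny checkpoints keep the tests from downloading full pretrained
weights, and the assertions now compare embedding sizes against an
embed_dim constant (64 for CLIP, 32 for ViT) instead of the
hard-coded 512/768 of the full models. coverage.xml is regenerated as
a side effect (timestamp only).

The expected sizes can be read off the model configs. A minimal
sketch, assuming the transformers AutoConfig attribute names
(projection_dim for CLIP, hidden_size for ViT):

    from transformers import AutoConfig

    # CLIP text/image features are projected to `projection_dim`.
    clip_cfg = AutoConfig.from_pretrained(
        "hf-internal-testing/tiny-random-CLIPModel"
    )
    assert clip_cfg.projection_dim == 64

    # ViT features come out at the encoder's `hidden_size`.
    vit_cfg = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-vit")
    assert vit_cfg.hidden_size == 32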
---
 coverage.xml                     |  2 +-
 tests/unit/encoders/test_clip.py | 13 ++++++++-----
 tests/unit/encoders/test_vit.py  | 11 +++++++----
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/coverage.xml b/coverage.xml
index eac3fdcb..fefa412d 100644
--- a/coverage.xml
+++ b/coverage.xml
@@ -1,5 +1,5 @@
 <?xml version="1.0" ?>
-<coverage version="7.4.3" timestamp="1709363338273" lines-valid="2072" lines-covered="1619" line-rate="0.7814" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
+<coverage version="7.4.3" timestamp="1709366989240" lines-valid="2072" lines-covered="1619" line-rate="0.7814" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
 	<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.4.3 -->
 	<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
 	<sources>
diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py
index fc057e85..ebbb80e5 100644
--- a/tests/unit/encoders/test_clip.py
+++ b/tests/unit/encoders/test_clip.py
@@ -5,7 +5,10 @@ from PIL import Image
 
 from semantic_router.encoders import CLIPEncoder
 
-clip_encoder = CLIPEncoder()
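+# Tiny random checkpoint: avoids downloading the full CLIP model in tests.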
+test_model_name = "hf-internal-testing/tiny-random-CLIPModel"
+clip_encoder = CLIPEncoder(name=test_model_name)
+embed_dim = 64
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -42,7 +45,7 @@ class TestVitEncoder:
             CLIPEncoder()
 
     def test_clip_encoder_initialization(self):
-        assert clip_encoder.name == "openai/clip-vit-base-patch16"
+        assert clip_encoder.name == test_model_name
         assert clip_encoder.type == "huggingface"
         assert clip_encoder.score_threshold == 0.2
         assert clip_encoder.device == device
@@ -51,19 +54,19 @@ class TestVitEncoder:
         embeddings = clip_encoder(["hello", "world"])
 
         assert len(embeddings) == 2
-        assert len(embeddings[0]) == 512
+        assert len(embeddings[0]) == embed_dim
 
     def test_clip_encoder_call_image(self, dummy_pil_image):
         encoded_images = clip_encoder([dummy_pil_image] * 3)
 
         assert len(encoded_images) == 3
-        assert set(map(len, encoded_images)) == {512}
+        assert set(map(len, encoded_images)) == {embed_dim}
 
     def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
         encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image])
 
         assert len(encoded_images) == 2
-        assert set(map(len, encoded_images)) == {512}
+        assert set(map(len, encoded_images)) == {embed_dim}
 
     def test_clip_device(self):
         device = clip_encoder._model.device.type
diff --git a/tests/unit/encoders/test_vit.py b/tests/unit/encoders/test_vit.py
index 0df7f462..72d01b34 100644
--- a/tests/unit/encoders/test_vit.py
+++ b/tests/unit/encoders/test_vit.py
@@ -5,7 +5,10 @@ from PIL import Image
 
 from semantic_router.encoders import VitEncoder
 
-vit_encoder = VitEncoder()
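+# Tiny random checkpoint: avoids downloading the full ViT model in tests.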
+test_model_name = "hf-internal-testing/tiny-random-vit"
+vit_encoder = VitEncoder(name=test_model_name)
+embed_dim = 32
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -47,7 +50,7 @@ class TestVitEncoder:
             VitEncoder()
 
     def test_vit_encoder_initialization(self):
-        assert vit_encoder.name == "google/vit-base-patch16-224"
+        assert vit_encoder.name == test_model_name
         assert vit_encoder.type == "huggingface"
         assert vit_encoder.score_threshold == 0.5
         assert vit_encoder.device == device
@@ -56,13 +59,13 @@ class TestVitEncoder:
         encoded_images = vit_encoder([dummy_pil_image] * 3)
 
         assert len(encoded_images) == 3
-        assert set(map(len, encoded_images)) == {768}
+        assert set(map(len, encoded_images)) == {embed_dim}
 
     def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
         encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])
 
         assert len(encoded_images) == 2
-        assert set(map(len, encoded_images)) == {768}
+        assert set(map(len, encoded_images)) == {embed_dim}
 
     def test_vit_encoder_process_images_device(self, dummy_pil_image):
         imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]
-- 
GitLab