From 2ab6116bdcd613244b19a7077b4bcab3eccc85d5 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Sat, 2 Mar 2024 16:16:12 +0800 Subject: [PATCH] update vit and clip mock models --- coverage.xml | 2 +- tests/unit/encoders/test_clip.py | 12 +++++++----- tests/unit/encoders/test_vit.py | 10 ++++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/coverage.xml b/coverage.xml index eac3fdcb..fefa412d 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,5 +1,5 @@ <?xml version="1.0" ?> -<coverage version="7.4.3" timestamp="1709363338273" lines-valid="2072" lines-covered="1619" line-rate="0.7814" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> +<coverage version="7.4.3" timestamp="1709366989240" lines-valid="2072" lines-covered="1619" line-rate="0.7814" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.4.3 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py index fc057e85..ebbb80e5 100644 --- a/tests/unit/encoders/test_clip.py +++ b/tests/unit/encoders/test_clip.py @@ -5,7 +5,9 @@ from PIL import Image from semantic_router.encoders import CLIPEncoder -clip_encoder = CLIPEncoder() +test_model_name = "hf-internal-testing/tiny-random-CLIPModel" +clip_encoder = CLIPEncoder(name=test_model_name) +embed_dim = 64 if torch.cuda.is_available(): device = "cuda" @@ -42,7 +44,7 @@ class TestVitEncoder: CLIPEncoder() def test_clip_encoder_initialization(self): - assert clip_encoder.name == "openai/clip-vit-base-patch16" + assert clip_encoder.name == test_model_name assert clip_encoder.type == "huggingface" assert clip_encoder.score_threshold == 0.2 assert clip_encoder.device == device @@ -51,19 +53,19 @@ class TestVitEncoder: embeddings = clip_encoder(["hello", "world"]) assert len(embeddings) == 2 - assert len(embeddings[0]) == 512 + assert len(embeddings[0]) == embed_dim def test_clip_encoder_call_image(self, dummy_pil_image): encoded_images = clip_encoder([dummy_pil_image] * 3) assert len(encoded_images) == 3 - assert set(map(len, encoded_images)) == {512} + assert set(map(len, encoded_images)) == {embed_dim} def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image): encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image]) assert len(encoded_images) == 2 - assert set(map(len, encoded_images)) == {512} + assert set(map(len, encoded_images)) == {embed_dim} def test_clip_device(self): device = clip_encoder._model.device.type diff --git a/tests/unit/encoders/test_vit.py b/tests/unit/encoders/test_vit.py index 0df7f462..72d01b34 100644 --- a/tests/unit/encoders/test_vit.py +++ b/tests/unit/encoders/test_vit.py @@ -5,7 +5,9 @@ from PIL import Image from semantic_router.encoders import VitEncoder -vit_encoder = VitEncoder() +test_model_name = "hf-internal-testing/tiny-random-vit" +vit_encoder = VitEncoder(name=test_model_name) +embed_dim = 32 if torch.cuda.is_available(): device = "cuda" @@ -47,7 +49,7 @@ class TestVitEncoder: VitEncoder() def test_vit_encoder_initialization(self): - assert vit_encoder.name == "google/vit-base-patch16-224" + assert vit_encoder.name == test_model_name assert vit_encoder.type == "huggingface" assert vit_encoder.score_threshold == 0.5 assert vit_encoder.device == device @@ -56,13 +58,13 @@ class TestVitEncoder: encoded_images = vit_encoder([dummy_pil_image] * 3) assert len(encoded_images) == 3 - assert set(map(len, encoded_images)) == {768} + assert set(map(len, encoded_images)) == {embed_dim} def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image): encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image]) assert len(encoded_images) == 2 - assert set(map(len, encoded_images)) == {768} + assert set(map(len, encoded_images)) == {embed_dim} def test_vit_encoder_process_images_device(self, dummy_pil_image): imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"] -- GitLab