From a6b3914f12f0449a78655363310e7a301016adfd Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Wed, 12 Jun 2024 13:16:58 +0800 Subject: [PATCH] fix: adjust test params --- tests/unit/encoders/test_clip.py | 26 ++++++++++++++++++++++++- tests/unit/encoders/test_huggingface.py | 14 ++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py index de997e36..47456ff0 100644 --- a/tests/unit/encoders/test_clip.py +++ b/tests/unit/encoders/test_clip.py @@ -1,3 +1,4 @@ +import os import numpy as np import pytest import torch @@ -6,7 +7,6 @@ from PIL import Image from semantic_router.encoders import CLIPEncoder test_model_name = "aurelio-ai/sr-test-clip" -clip_encoder = CLIPEncoder(name=test_model_name) embed_dim = 64 if torch.cuda.is_available(): @@ -43,35 +43,59 @@ class TestVitEncoder: with pytest.raises(ImportError): CLIPEncoder() + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_clip_encoder_initialization(self): + clip_encoder = CLIPEncoder(name=test_model_name) assert clip_encoder.name == test_model_name assert clip_encoder.type == "huggingface" assert clip_encoder.score_threshold == 0.2 assert clip_encoder.device == device + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_clip_encoder_call_text(self): + clip_encoder = CLIPEncoder(name=test_model_name) embeddings = clip_encoder(["hello", "world"]) assert len(embeddings) == 2 assert len(embeddings[0]) == embed_dim + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_clip_encoder_call_image(self, dummy_pil_image): + clip_encoder = CLIPEncoder(name=test_model_name) encoded_images = clip_encoder([dummy_pil_image] * 3) assert len(encoded_images) == 3 assert set(map(len, encoded_images)) == {embed_dim} + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image): + clip_encoder = CLIPEncoder(name=test_model_name) encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image]) assert len(encoded_images) == 2 assert set(map(len, encoded_images)) == {embed_dim} + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_clip_device(self): + clip_encoder = CLIPEncoder(name=test_model_name) device = clip_encoder._model.device.type assert device == device + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img): + clip_encoder = CLIPEncoder(name=test_model_name) rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img) assert rgb_image.mode == "RGB" diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py index f14c7a68..b615a87d 100644 --- a/tests/unit/encoders/test_huggingface.py +++ b/tests/unit/encoders/test_huggingface.py @@ -1,12 +1,12 @@ from unittest.mock import patch +import os import numpy as np import pytest from semantic_router.encoders.huggingface import HuggingFaceEncoder test_model_name = "aurelio-ai/sr-test-huggingface" -encoder = HuggingFaceEncoder(name=test_model_name) class TestHuggingFaceEncoder: @@ -26,7 +26,11 @@ class TestHuggingFaceEncoder: assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value) + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_huggingface_encoder_mean_pooling(self): + encoder = HuggingFaceEncoder(name=test_model_name) test_docs = ["This is a test", "This is another test"] embeddings = encoder(test_docs, pooling_strategy="mean") assert isinstance(embeddings, list) @@ -34,7 +38,11 @@ class TestHuggingFaceEncoder: assert all(isinstance(embedding, list) for embedding in embeddings) assert all(len(embedding) > 0 for embedding in embeddings) + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_huggingface_encoder_max_pooling(self): + encoder = HuggingFaceEncoder(name=test_model_name) test_docs = ["This is a test", "This is another test"] embeddings = encoder(test_docs, pooling_strategy="max") assert isinstance(embeddings, list) @@ -42,7 +50,11 @@ class TestHuggingFaceEncoder: assert all(isinstance(embedding, list) for embedding in embeddings) assert all(len(embedding) > 0 for embedding in embeddings) + @pytest.mark.skipif( + os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" + ) def test_huggingface_encoder_normalized_embeddings(self): + encoder = HuggingFaceEncoder(name=test_model_name) docs = ["This is a test document.", "Another test document."] unnormalized_embeddings = encoder(docs, normalize_embeddings=False) normalized_embeddings = encoder(docs, normalize_embeddings=True) -- GitLab