From a6b3914f12f0449a78655363310e7a301016adfd Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Wed, 12 Jun 2024 13:16:58 +0800
Subject: [PATCH] fix: adjust test params

---
 tests/unit/encoders/test_clip.py        | 26 ++++++++++++++++++++++++-
 tests/unit/encoders/test_huggingface.py | 14 ++++++++++++-
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py
index de997e36..47456ff0 100644
--- a/tests/unit/encoders/test_clip.py
+++ b/tests/unit/encoders/test_clip.py
@@ -1,3 +1,4 @@
+import os
 import numpy as np
 import pytest
 import torch
@@ -6,7 +7,6 @@ from PIL import Image
 from semantic_router.encoders import CLIPEncoder
 
 test_model_name = "aurelio-ai/sr-test-clip"
-clip_encoder = CLIPEncoder(name=test_model_name)
 embed_dim = 64
 
 if torch.cuda.is_available():
@@ -43,35 +43,59 @@ class TestVitEncoder:
         with pytest.raises(ImportError):
             CLIPEncoder()
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_clip_encoder_initialization(self):
+        clip_encoder = CLIPEncoder(name=test_model_name)
         assert clip_encoder.name == test_model_name
         assert clip_encoder.type == "huggingface"
         assert clip_encoder.score_threshold == 0.2
         assert clip_encoder.device == device
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_clip_encoder_call_text(self):
+        clip_encoder = CLIPEncoder(name=test_model_name)
         embeddings = clip_encoder(["hello", "world"])
 
         assert len(embeddings) == 2
         assert len(embeddings[0]) == embed_dim
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_clip_encoder_call_image(self, dummy_pil_image):
+        clip_encoder = CLIPEncoder(name=test_model_name)
         encoded_images = clip_encoder([dummy_pil_image] * 3)
 
         assert len(encoded_images) == 3
         assert set(map(len, encoded_images)) == {embed_dim}
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
+        clip_encoder = CLIPEncoder(name=test_model_name)
         encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image])
 
         assert len(encoded_images) == 2
         assert set(map(len, encoded_images)) == {embed_dim}
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_clip_device(self):
+        clip_encoder = CLIPEncoder(name=test_model_name)
         device = clip_encoder._model.device.type
         assert device == device
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img):
+        clip_encoder = CLIPEncoder(name=test_model_name)
         rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img)
 
         assert rgb_image.mode == "RGB"
diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py
index f14c7a68..b615a87d 100644
--- a/tests/unit/encoders/test_huggingface.py
+++ b/tests/unit/encoders/test_huggingface.py
@@ -1,12 +1,12 @@
 from unittest.mock import patch
 
+import os
 import numpy as np
 import pytest
 
 from semantic_router.encoders.huggingface import HuggingFaceEncoder
 
 test_model_name = "aurelio-ai/sr-test-huggingface"
-encoder = HuggingFaceEncoder(name=test_model_name)
 
 
 class TestHuggingFaceEncoder:
@@ -26,7 +26,11 @@ class TestHuggingFaceEncoder:
 
         assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value)
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_huggingface_encoder_mean_pooling(self):
+        encoder = HuggingFaceEncoder(name=test_model_name)
         test_docs = ["This is a test", "This is another test"]
         embeddings = encoder(test_docs, pooling_strategy="mean")
         assert isinstance(embeddings, list)
@@ -34,7 +38,11 @@ class TestHuggingFaceEncoder:
         assert all(isinstance(embedding, list) for embedding in embeddings)
         assert all(len(embedding) > 0 for embedding in embeddings)
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_huggingface_encoder_max_pooling(self):
+        encoder = HuggingFaceEncoder(name=test_model_name)
         test_docs = ["This is a test", "This is another test"]
         embeddings = encoder(test_docs, pooling_strategy="max")
         assert isinstance(embeddings, list)
@@ -42,7 +50,11 @@ class TestHuggingFaceEncoder:
         assert all(isinstance(embedding, list) for embedding in embeddings)
         assert all(len(embedding) > 0 for embedding in embeddings)
 
+    @pytest.mark.skipif(
+        os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
+    )
     def test_huggingface_encoder_normalized_embeddings(self):
+        encoder = HuggingFaceEncoder(name=test_model_name)
         docs = ["This is a test document.", "Another test document."]
         unnormalized_embeddings = encoder(docs, normalize_embeddings=False)
         normalized_embeddings = encoder(docs, normalize_embeddings=True)
-- 
GitLab