Skip to content
Snippets Groups Projects
Unverified Commit a6b3914f authored by James Briggs's avatar James Briggs
Browse files

fix: adjust test params

parent 3b0bfc96
No related branches found
No related tags found
No related merge requests found
import os
import numpy as np
import pytest
import torch
......@@ -6,7 +7,6 @@ from PIL import Image
from semantic_router.encoders import CLIPEncoder
test_model_name = "aurelio-ai/sr-test-clip"
clip_encoder = CLIPEncoder(name=test_model_name)
embed_dim = 64
if torch.cuda.is_available():
......@@ -43,35 +43,59 @@ class TestVitEncoder:
with pytest.raises(ImportError):
CLIPEncoder()
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_clip_encoder_initialization(self):
clip_encoder = CLIPEncoder(name=test_model_name)
assert clip_encoder.name == test_model_name
assert clip_encoder.type == "huggingface"
assert clip_encoder.score_threshold == 0.2
assert clip_encoder.device == device
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_clip_encoder_call_text(self):
clip_encoder = CLIPEncoder(name=test_model_name)
embeddings = clip_encoder(["hello", "world"])
assert len(embeddings) == 2
assert len(embeddings[0]) == embed_dim
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_clip_encoder_call_image(self, dummy_pil_image):
clip_encoder = CLIPEncoder(name=test_model_name)
encoded_images = clip_encoder([dummy_pil_image] * 3)
assert len(encoded_images) == 3
assert set(map(len, encoded_images)) == {embed_dim}
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_clip_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
clip_encoder = CLIPEncoder(name=test_model_name)
encoded_images = clip_encoder([dummy_pil_image, misshaped_pil_image])
assert len(encoded_images) == 2
assert set(map(len, encoded_images)) == {embed_dim}
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_clip_device(self):
clip_encoder = CLIPEncoder(name=test_model_name)
device = clip_encoder._model.device.type
assert device == device
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_clip_encoder_ensure_rgb(self, dummy_black_and_white_img):
clip_encoder = CLIPEncoder(name=test_model_name)
rgb_image = clip_encoder._ensure_rgb(dummy_black_and_white_img)
assert rgb_image.mode == "RGB"
......
from unittest.mock import patch
import os
import numpy as np
import pytest
from semantic_router.encoders.huggingface import HuggingFaceEncoder
test_model_name = "aurelio-ai/sr-test-huggingface"
encoder = HuggingFaceEncoder(name=test_model_name)
class TestHuggingFaceEncoder:
......@@ -26,7 +26,11 @@ class TestHuggingFaceEncoder:
assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value)
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_huggingface_encoder_mean_pooling(self):
encoder = HuggingFaceEncoder(name=test_model_name)
test_docs = ["This is a test", "This is another test"]
embeddings = encoder(test_docs, pooling_strategy="mean")
assert isinstance(embeddings, list)
......@@ -34,7 +38,11 @@ class TestHuggingFaceEncoder:
assert all(isinstance(embedding, list) for embedding in embeddings)
assert all(len(embedding) > 0 for embedding in embeddings)
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_huggingface_encoder_max_pooling(self):
encoder = HuggingFaceEncoder(name=test_model_name)
test_docs = ["This is a test", "This is another test"]
embeddings = encoder(test_docs, pooling_strategy="max")
assert isinstance(embeddings, list)
......@@ -42,7 +50,11 @@ class TestHuggingFaceEncoder:
assert all(isinstance(embedding, list) for embedding in embeddings)
assert all(len(embedding) > 0 for embedding in embeddings)
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_huggingface_encoder_normalized_embeddings(self):
encoder = HuggingFaceEncoder(name=test_model_name)
docs = ["This is a test document.", "Another test document."]
unnormalized_embeddings = encoder(docs, normalize_embeddings=False)
normalized_embeddings = encoder(docs, normalize_embeddings=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment