Skip to content
Snippets Groups Projects
Unverified Commit 1916fecd authored by James Briggs's avatar James Briggs Committed by GitHub
Browse files

Merge pull request #319 from aurelio-labs/james/update-pinecone-test

fix: add skip to vit
parents 5120edcd c4685d8f
No related branches found
No related tags found
No related merge requests found
import os
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
...@@ -48,29 +49,44 @@ class TestVitEncoder: ...@@ -48,29 +49,44 @@ class TestVitEncoder:
with pytest.raises(ImportError): with pytest.raises(ImportError):
VitEncoder() VitEncoder()
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_initialization(self): def test_vit_encoder_initialization(self):
assert vit_encoder.name == test_model_name assert vit_encoder.name == test_model_name
assert vit_encoder.type == "huggingface" assert vit_encoder.type == "huggingface"
assert vit_encoder.score_threshold == 0.5 assert vit_encoder.score_threshold == 0.5
assert vit_encoder.device == device assert vit_encoder.device == device
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call(self, dummy_pil_image): def test_vit_encoder_call(self, dummy_pil_image):
encoded_images = vit_encoder([dummy_pil_image] * 3) encoded_images = vit_encoder([dummy_pil_image] * 3)
assert len(encoded_images) == 3 assert len(encoded_images) == 3
assert set(map(len, encoded_images)) == {embed_dim} assert set(map(len, encoded_images)) == {embed_dim}
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image): def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image]) encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])
assert len(encoded_images) == 2 assert len(encoded_images) == 2
assert set(map(len, encoded_images)) == {embed_dim} assert set(map(len, encoded_images)) == {embed_dim}
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_process_images_device(self, dummy_pil_image): def test_vit_encoder_process_images_device(self, dummy_pil_image):
imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"] imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]
assert imgs.device.type == device assert imgs.device.type == device
@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img): def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img):
rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img) rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment