diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py index 47456ff0121b2e312034844b1eca1b8e96a7a0dc..15acb70638738c94b6c0f5637f9b3859db769147 100644 --- a/tests/unit/encoders/test_clip.py +++ b/tests/unit/encoders/test_clip.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch from PIL import Image +from unittest.mock import patch from semantic_router.encoders import CLIPEncoder @@ -32,16 +33,20 @@ def misshaped_pil_image(): return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8)) -class TestVitEncoder: +class TestClipEncoder: def test_clip_encoder__import_errors_transformers(self, mocker): - mocker.patch.dict("sys.modules", {"transformers": None}) - with pytest.raises(ImportError): - CLIPEncoder() + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + CLIPEncoder() + + assert "install transformers" in str(error.value) def test_clip_encoder__import_errors_torch(self, mocker): - mocker.patch.dict("sys.modules", {"torch": None}) - with pytest.raises(ImportError): - CLIPEncoder() + with patch.dict("sys.modules", {"torch": None}): + with pytest.raises(ImportError) as error: + CLIPEncoder() + + assert "install Pytorch" in str(error.value) @pytest.mark.skipif( os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"