From 6cc3e3467e356d740bc42e662956e53148ee5798 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Wed, 12 Jun 2024 13:29:53 +0800 Subject: [PATCH] fix: adjust clip import checks --- tests/unit/encoders/test_clip.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py index 47456ff0..15acb706 100644 --- a/tests/unit/encoders/test_clip.py +++ b/tests/unit/encoders/test_clip.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch from PIL import Image +from unittest.mock import patch from semantic_router.encoders import CLIPEncoder @@ -32,16 +33,20 @@ def misshaped_pil_image(): return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8)) -class TestVitEncoder: +class TestClipEncoder: def test_clip_encoder__import_errors_transformers(self, mocker): - mocker.patch.dict("sys.modules", {"transformers": None}) - with pytest.raises(ImportError): - CLIPEncoder() + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + CLIPEncoder() + + assert "install transformers" in str(error.value) def test_clip_encoder__import_errors_torch(self, mocker): - mocker.patch.dict("sys.modules", {"torch": None}) - with pytest.raises(ImportError): - CLIPEncoder() + with patch.dict("sys.modules", {"torch": None}): + with pytest.raises(ImportError) as error: + CLIPEncoder() + + assert "install Pytorch" in str(error.value) @pytest.mark.skipif( os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" -- GitLab