diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index 9002ec23dcf6f5c3dac0c5f295b51c8304ba5b32..44ae58018635e67a10d37150a1de1c783a4cee97 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -27,7 +27,7 @@ class VitEncoder(BaseEncoder): from transformers import ViTImageProcessor, ViTModel except ImportError: raise ImportError( - "Please install transformers to use HuggingFaceEncoder. " + "Please install transformers to use VitEncoder. " "You can install it with: " "`pip install semantic-router[vision]`" ) @@ -37,7 +37,7 @@ class VitEncoder(BaseEncoder): import torchvision.transforms as T except ImportError: raise ImportError( - "Please install Pytorch to use HuggingFaceEncoder. " + "Please install Pytorch to use VitEncoder. " "You can install it with: " "`pip install semantic-router[vision]`" ) @@ -46,7 +46,7 @@ class VitEncoder(BaseEncoder): from PIL import Image except ImportError: raise ImportError( - "Please install PIL to use HuggingFaceEncoder. " + "Please install PIL to use VitEncoder. " "You can install it with: " "`pip install semantic-router[vision]`" ) diff --git a/tests/unit/encoders/test_vit.py b/tests/unit/encoders/test_vit.py index 85e66a9f961aa6b68d5b8f330988d713faab5f1f..b598a818a3af0753a0627fead4801a56c5b03f54 100644 --- a/tests/unit/encoders/test_vit.py +++ b/tests/unit/encoders/test_vit.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import os import numpy as np import pytest @@ -33,15 +35,12 @@ def misshaped_pil_image(): class TestVitEncoder: - def test_vit_encoder__import_errors_transformers(self, mocker): - mocker.patch.dict("sys.modules", {"transformers": None}) - with pytest.raises(ImportError): - VitEncoder() - - def test_vit_encoder__import_errors_torch(self, mocker): - mocker.patch.dict("sys.modules", {"torch": None}) - with pytest.raises(ImportError): - VitEncoder() + def test_vit_encoder__import_errors_transformers(self): + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + VitEncoder() + + assert "Please install transformers to use VitEncoder" in str(error.value) @pytest.mark.skipif( os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"