From 7a534ced70c65740b560fff70fc419f425403fc3 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Mon, 26 Aug 2024 12:15:44 +0200 Subject: [PATCH] fix: tests and errors messages --- semantic_router/encoders/vit.py | 6 +++--- tests/unit/encoders/test_vit.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index 9002ec23..44ae5801 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -27,7 +27,7 @@ class VitEncoder(BaseEncoder): from transformers import ViTImageProcessor, ViTModel except ImportError: raise ImportError( - "Please install transformers to use HuggingFaceEncoder. " + "Please install transformers to use VitEncoder. " "You can install it with: " "`pip install semantic-router[vision]`" ) @@ -37,7 +37,7 @@ class VitEncoder(BaseEncoder): import torchvision.transforms as T except ImportError: raise ImportError( - "Please install Pytorch to use HuggingFaceEncoder. " + "Please install Pytorch to use VitEncoder. " "You can install it with: " "`pip install semantic-router[vision]`" ) @@ -46,7 +46,7 @@ class VitEncoder(BaseEncoder): from PIL import Image except ImportError: raise ImportError( - "Please install PIL to use HuggingFaceEncoder. " + "Please install PIL to use VitEncoder. " "You can install it with: " "`pip install semantic-router[vision]`" ) diff --git a/tests/unit/encoders/test_vit.py b/tests/unit/encoders/test_vit.py index 85e66a9f..f1180985 100644 --- a/tests/unit/encoders/test_vit.py +++ b/tests/unit/encoders/test_vit.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import os import numpy as np import pytest @@ -33,15 +35,14 @@ def misshaped_pil_image(): class TestVitEncoder: - def test_vit_encoder__import_errors_transformers(self, mocker): - mocker.patch.dict("sys.modules", {"transformers": None}) - with pytest.raises(ImportError): - VitEncoder() - - def test_vit_encoder__import_errors_torch(self, mocker): - mocker.patch.dict("sys.modules", {"torch": None}) - with pytest.raises(ImportError): - VitEncoder() + def test_vit_encoder__import_errors_transformers(self): + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + VitEncoder() + + assert "Please install transformers to use VitEncoder" in str( + error.value + ) @pytest.mark.skipif( os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run" -- GitLab