From 6cc3e3467e356d740bc42e662956e53148ee5798 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Wed, 12 Jun 2024 13:29:53 +0800
Subject: [PATCH] fix: adjust clip import checks

---
 tests/unit/encoders/test_clip.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/tests/unit/encoders/test_clip.py b/tests/unit/encoders/test_clip.py
index 47456ff0..15acb706 100644
--- a/tests/unit/encoders/test_clip.py
+++ b/tests/unit/encoders/test_clip.py
@@ -3,6 +3,7 @@ import numpy as np
 import pytest
 import torch
 from PIL import Image
+from unittest.mock import patch
 
 from semantic_router.encoders import CLIPEncoder
 
@@ -32,16 +33,20 @@ def misshaped_pil_image():
     return Image.fromarray(np.random.rand(64, 64, 3).astype(np.uint8))
 
 
-class TestVitEncoder:
+class TestClipEncoder:
     def test_clip_encoder__import_errors_transformers(self, mocker):
-        mocker.patch.dict("sys.modules", {"transformers": None})
-        with pytest.raises(ImportError):
-            CLIPEncoder()
+        with patch.dict("sys.modules", {"transformers": None}):
+            with pytest.raises(ImportError) as error:
+                CLIPEncoder()
+        
+        assert "install transformers" in str(error.value)
 
     def test_clip_encoder__import_errors_torch(self, mocker):
-        mocker.patch.dict("sys.modules", {"torch": None})
-        with pytest.raises(ImportError):
-            CLIPEncoder()
+        with patch.dict("sys.modules", {"torch": None}):
+            with pytest.raises(ImportError) as error:
+                CLIPEncoder()
+
+        assert "install Pytorch" in str(error.value)
 
     @pytest.mark.skipif(
         os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
-- 
GitLab