diff --git a/tests/unit/encoders/test_huggingface.py b/tests/unit/encoders/test_huggingface.py index 7a3c6fc5b1daea38aaccd1c627ebffdf1ee9b74b..0aa8cb79be8cb7f2be76720501eb9f030fdb3f2b 100644 --- a/tests/unit/encoders/test_huggingface.py +++ b/tests/unit/encoders/test_huggingface.py @@ -1,20 +1,46 @@ import pytest import numpy as np +from unittest.mock import patch from semantic_router.encoders.huggingface import HuggingFaceEncoder +encoder = HuggingFaceEncoder() + + class TestHuggingFaceEncoder: - def test_huggingface_encoder(self): - encoder = HuggingFaceEncoder() + def test_huggingface_encoder_import_errors_transformers(self): + with patch.dict("sys.modules", {"transformers": None}): + with pytest.raises(ImportError) as error: + HuggingFaceEncoder() + + assert "Please install transformers to use HuggingFaceEncoder" in str( + error.value + ) + + def test_huggingface_encoder_import_errors_torch(self): + with patch.dict("sys.modules", {"torch": None}): + with pytest.raises(ImportError) as error: + HuggingFaceEncoder() + + assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value) + + def test_huggingface_encoder_mean_pooling(self): + test_docs = ["This is a test", "This is another test"] + embeddings = encoder(test_docs, pooling_strategy="mean") + assert isinstance(embeddings, list) + assert len(embeddings) == len(test_docs) + assert all(isinstance(embedding, list) for embedding in embeddings) + assert all(len(embedding) > 0 for embedding in embeddings) + + def test_huggingface_encoder_max_pooling(self): test_docs = ["This is a test", "This is another test"] - embeddings = encoder(test_docs) + embeddings = encoder(test_docs, pooling_strategy="max") assert isinstance(embeddings, list) assert len(embeddings) == len(test_docs) assert all(isinstance(embedding, list) for embedding in embeddings) assert all(len(embedding) > 0 for embedding in embeddings) def test_huggingface_encoder_normalized_embeddings(self): - encoder = HuggingFaceEncoder() docs = ["This is a test document.", "Another test document."] unnormalized_embeddings = encoder(docs, normalize_embeddings=False) normalized_embeddings = encoder(docs, normalize_embeddings=True) @@ -34,9 +60,3 @@ class TestHuggingFaceEncoder: rtol=1e-5, atol=1e-5, # Adjust tolerance levels ) - - def test_huggingface_encoder_invalid_pooling_strategy(self): - encoder = HuggingFaceEncoder() - docs = ["This is a test document.", "Another test document."] - with pytest.raises(ValueError): - encoder(docs, pooling_strategy="invalid_strategy")