Skip to content
Snippets Groups Projects
Commit 21912a6a authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

add more tests

parent 7edeb32e
Branches
Tags
No related merge requests found
import pytest import pytest
import numpy as np import numpy as np
from unittest.mock import patch
from semantic_router.encoders.huggingface import HuggingFaceEncoder from semantic_router.encoders.huggingface import HuggingFaceEncoder
encoder = HuggingFaceEncoder()
class TestHuggingFaceEncoder: class TestHuggingFaceEncoder:
def test_huggingface_encoder(self): def test_huggingface_encoder_import_errors_transformers(self):
encoder = HuggingFaceEncoder() with patch.dict("sys.modules", {"transformers": None}):
with pytest.raises(ImportError) as error:
HuggingFaceEncoder()
assert "Please install transformers to use HuggingFaceEncoder" in str(
error.value
)
def test_huggingface_encoder_import_errors_torch(self):
with patch.dict("sys.modules", {"torch": None}):
with pytest.raises(ImportError) as error:
HuggingFaceEncoder()
assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value)
def test_huggingface_encoder_mean_pooling(self):
test_docs = ["This is a test", "This is another test"]
embeddings = encoder(test_docs, pooling_strategy="mean")
assert isinstance(embeddings, list)
assert len(embeddings) == len(test_docs)
assert all(isinstance(embedding, list) for embedding in embeddings)
assert all(len(embedding) > 0 for embedding in embeddings)
def test_huggingface_encoder_max_pooling(self):
test_docs = ["This is a test", "This is another test"] test_docs = ["This is a test", "This is another test"]
embeddings = encoder(test_docs) embeddings = encoder(test_docs, pooling_strategy="max")
assert isinstance(embeddings, list) assert isinstance(embeddings, list)
assert len(embeddings) == len(test_docs) assert len(embeddings) == len(test_docs)
assert all(isinstance(embedding, list) for embedding in embeddings) assert all(isinstance(embedding, list) for embedding in embeddings)
assert all(len(embedding) > 0 for embedding in embeddings) assert all(len(embedding) > 0 for embedding in embeddings)
def test_huggingface_encoder_normalized_embeddings(self): def test_huggingface_encoder_normalized_embeddings(self):
encoder = HuggingFaceEncoder()
docs = ["This is a test document.", "Another test document."] docs = ["This is a test document.", "Another test document."]
unnormalized_embeddings = encoder(docs, normalize_embeddings=False) unnormalized_embeddings = encoder(docs, normalize_embeddings=False)
normalized_embeddings = encoder(docs, normalize_embeddings=True) normalized_embeddings = encoder(docs, normalize_embeddings=True)
...@@ -34,9 +60,3 @@ class TestHuggingFaceEncoder: ...@@ -34,9 +60,3 @@ class TestHuggingFaceEncoder:
rtol=1e-5, rtol=1e-5,
atol=1e-5, # Adjust tolerance levels atol=1e-5, # Adjust tolerance levels
) )
def test_huggingface_encoder_invalid_pooling_strategy(self):
encoder = HuggingFaceEncoder()
docs = ["This is a test document.", "Another test document."]
with pytest.raises(ValueError):
encoder(docs, pooling_strategy="invalid_strategy")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment