Skip to content
Snippets Groups Projects
Commit 21912a6a authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

add more tests

parent 7edeb32e
No related branches found
No related tags found
No related merge requests found
import pytest
import numpy as np
from unittest.mock import patch
from semantic_router.encoders.huggingface import HuggingFaceEncoder
encoder = HuggingFaceEncoder()
class TestHuggingFaceEncoder:
def test_huggingface_encoder(self):
encoder = HuggingFaceEncoder()
def test_huggingface_encoder_import_errors_transformers(self):
with patch.dict("sys.modules", {"transformers": None}):
with pytest.raises(ImportError) as error:
HuggingFaceEncoder()
assert "Please install transformers to use HuggingFaceEncoder" in str(
error.value
)
def test_huggingface_encoder_import_errors_torch(self):
with patch.dict("sys.modules", {"torch": None}):
with pytest.raises(ImportError) as error:
HuggingFaceEncoder()
assert "Please install Pytorch to use HuggingFaceEncoder" in str(error.value)
def test_huggingface_encoder_mean_pooling(self):
test_docs = ["This is a test", "This is another test"]
embeddings = encoder(test_docs, pooling_strategy="mean")
assert isinstance(embeddings, list)
assert len(embeddings) == len(test_docs)
assert all(isinstance(embedding, list) for embedding in embeddings)
assert all(len(embedding) > 0 for embedding in embeddings)
def test_huggingface_encoder_max_pooling(self):
test_docs = ["This is a test", "This is another test"]
embeddings = encoder(test_docs)
embeddings = encoder(test_docs, pooling_strategy="max")
assert isinstance(embeddings, list)
assert len(embeddings) == len(test_docs)
assert all(isinstance(embedding, list) for embedding in embeddings)
assert all(len(embedding) > 0 for embedding in embeddings)
def test_huggingface_encoder_normalized_embeddings(self):
encoder = HuggingFaceEncoder()
docs = ["This is a test document.", "Another test document."]
unnormalized_embeddings = encoder(docs, normalize_embeddings=False)
normalized_embeddings = encoder(docs, normalize_embeddings=True)
......@@ -34,9 +60,3 @@ class TestHuggingFaceEncoder:
rtol=1e-5,
atol=1e-5, # Adjust tolerance levels
)
def test_huggingface_encoder_invalid_pooling_strategy(self):
encoder = HuggingFaceEncoder()
docs = ["This is a test document.", "Another test document."]
with pytest.raises(ValueError):
encoder(docs, pooling_strategy="invalid_strategy")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment