diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 9e4cae29a781e5c5dd30ff9d7775fc15f5a9ff4c..5aa8dfc8f3e257f42c4ac3544a3d7d3ba747acb5 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,6 +1,7 @@ import pytest from semantic_router.encoders import TfidfEncoder from semantic_router.route import Route +import numpy as np @pytest.fixture @@ -10,8 +11,8 @@ def tfidf_encoder(): class TestTfidfEncoder: def test_initialization(self, tfidf_encoder): - assert tfidf_encoder.word_index is None - assert tfidf_encoder.idf is None + assert tfidf_encoder.word_index == {} + assert (tfidf_encoder.idf == np.array([])).all() def test_fit(self, tfidf_encoder): routes = [ @@ -21,8 +22,8 @@ class TestTfidfEncoder: ) ] tfidf_encoder.fit(routes) - assert tfidf_encoder.word_index is not None - assert tfidf_encoder.idf is not None + assert tfidf_encoder.word_index != {} + assert not np.array_equal(tfidf_encoder.idf, np.array([])) def test_call_method(self, tfidf_encoder): routes = [