From 0f1ecf70c6a6f3225b47af6e9fb39bbd5e2bc5b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= <Danielgriffiths1790@gmail.com> Date: Fri, 5 Jan 2024 09:54:57 +0000 Subject: [PATCH] fix: added default values to tfidf tests --- tests/unit/encoders/test_tfidf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py index 9e4cae29..5aa8dfc8 100644 --- a/tests/unit/encoders/test_tfidf.py +++ b/tests/unit/encoders/test_tfidf.py @@ -1,6 +1,7 @@ import pytest from semantic_router.encoders import TfidfEncoder from semantic_router.route import Route +import numpy as np @pytest.fixture @@ -10,8 +11,8 @@ def tfidf_encoder(): class TestTfidfEncoder: def test_initialization(self, tfidf_encoder): - assert tfidf_encoder.word_index is None - assert tfidf_encoder.idf is None + assert tfidf_encoder.word_index == {} + assert (tfidf_encoder.idf == np.array([])).all() def test_fit(self, tfidf_encoder): routes = [ @@ -21,8 +22,8 @@ class TestTfidfEncoder: ) ] tfidf_encoder.fit(routes) - assert tfidf_encoder.word_index is not None - assert tfidf_encoder.idf is not None + assert tfidf_encoder.word_index != {} + assert not np.array_equal(tfidf_encoder.idf, np.array([])) def test_call_method(self, tfidf_encoder): routes = [ -- GitLab