From 0f1ecf70c6a6f3225b47af6e9fb39bbd5e2bc5b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?=
 <Danielgriffiths1790@gmail.com>
Date: Fri, 5 Jan 2024 09:54:57 +0000
Subject: [PATCH] fix: added default values to tfidf tests

---
 tests/unit/encoders/test_tfidf.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/unit/encoders/test_tfidf.py b/tests/unit/encoders/test_tfidf.py
index 9e4cae29..5aa8dfc8 100644
--- a/tests/unit/encoders/test_tfidf.py
+++ b/tests/unit/encoders/test_tfidf.py
@@ -1,6 +1,7 @@
 import pytest
 from semantic_router.encoders import TfidfEncoder
 from semantic_router.route import Route
+import numpy as np
 
 
 @pytest.fixture
@@ -10,8 +11,8 @@ def tfidf_encoder():
 
 class TestTfidfEncoder:
     def test_initialization(self, tfidf_encoder):
-        assert tfidf_encoder.word_index is None
-        assert tfidf_encoder.idf is None
+        assert tfidf_encoder.word_index == {}
+        assert (tfidf_encoder.idf == np.array([])).all()
 
     def test_fit(self, tfidf_encoder):
         routes = [
@@ -21,8 +22,8 @@ class TestTfidfEncoder:
             )
         ]
         tfidf_encoder.fit(routes)
-        assert tfidf_encoder.word_index is not None
-        assert tfidf_encoder.idf is not None
+        assert tfidf_encoder.word_index != {}
+        assert not np.array_equal(tfidf_encoder.idf, np.array([]))
 
     def test_call_method(self, tfidf_encoder):
         routes = [
-- 
GitLab