From 73258344b861da8a0111e408327b7dfd82825797 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?=
 <Danielgriffiths1790@gmail.com>
Date: Thu, 4 Jan 2024 10:48:08 +0000
Subject: [PATCH] removed none types for mypy

---
 semantic_router/encoders/tfidf.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py
index 394f32fb..68baceaa 100644
--- a/semantic_router/encoders/tfidf.py
+++ b/semantic_router/encoders/tfidf.py
@@ -7,16 +7,16 @@ import string
 
 
 class TfidfEncoder(BaseEncoder):
-    idf: dict | None = None
-    word_index: dict | None = None
+    idf: np.ndarray
+    word_index: dict
 
     def __init__(self, name: str = "tfidf"):
         super().__init__(name=name)
-        self.word_index = None
-        self.idf = None
+        self.word_index = {}
+        self.idf = np.array([])
 
     def __call__(self, docs: list[str]) -> list[list[float]]:
-        if self.word_index is None or self.idf is None:
+        if len(self.word_index) == 0 or self.idf.size == 0:
             raise ValueError("Vectorizer is not initialized.")
         if len(docs) == 0:
             raise ValueError("No documents to encode.")
@@ -43,6 +43,8 @@ class TfidfEncoder(BaseEncoder):
         return word_index
 
     def _compute_tf(self, docs: list[str]) -> np.ndarray:
+        if len(self.word_index) == 0:
+            raise ValueError("Word index is not initialized.")
         tf = np.zeros((len(docs), len(self.word_index)))
         for i, doc in enumerate(docs):
             word_counts = Counter(doc.split())
@@ -54,6 +56,8 @@ class TfidfEncoder(BaseEncoder):
         return tf
 
     def _compute_idf(self, docs: list[str]) -> np.ndarray:
+        if len(self.word_index) == 0:
+            raise ValueError("Word index is not initialized.")
         idf = np.zeros(len(self.word_index))
         for doc in docs:
             words = set(doc.split())
-- 
GitLab