From 73258344b861da8a0111e408327b7dfd82825797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= <Danielgriffiths1790@gmail.com> Date: Thu, 4 Jan 2024 10:48:08 +0000 Subject: [PATCH] removed none types for mypy --- semantic_router/encoders/tfidf.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index 394f32fb..68baceaa 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -7,16 +7,16 @@ import string class TfidfEncoder(BaseEncoder): - idf: dict | None = None - word_index: dict | None = None + idf: np.ndarray + word_index: dict def __init__(self, name: str = "tfidf"): super().__init__(name=name) - self.word_index = None - self.idf = None + self.word_index = {} + self.idf = np.array([]) def __call__(self, docs: list[str]) -> list[list[float]]: - if self.word_index is None or self.idf is None: + if len(self.word_index) == 0 or self.idf.size == 0: raise ValueError("Vectorizer is not initialized.") if len(docs) == 0: raise ValueError("No documents to encode.") @@ -43,6 +43,8 @@ class TfidfEncoder(BaseEncoder): return word_index def _compute_tf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") tf = np.zeros((len(docs), len(self.word_index))) for i, doc in enumerate(docs): word_counts = Counter(doc.split()) @@ -54,6 +56,8 @@ class TfidfEncoder(BaseEncoder): return tf def _compute_idf(self, docs: list[str]) -> np.ndarray: + if len(self.word_index) == 0: + raise ValueError("Word index is not initialized.") idf = np.zeros(len(self.word_index)) for doc in docs: words = set(doc.split()) -- GitLab