From 671cd0738bce63cfb5312cdfe9b66f6711a8051e Mon Sep 17 00:00:00 2001
From: Simonas <20096648+simjak@users.noreply.github.com>
Date: Wed, 13 Dec 2023 12:37:31 +0200
Subject: [PATCH] added test + lint + codecov

---
 Makefile                         |  1 +
 README.md                        |  1 +
 coverage.xml                     | 99 +++++++++++++++++++-------------
 poetry.lock                      | 49 +++++++++++++++-
 pyproject.toml                   |  4 ++
 semantic_router/encoders/base.py |  2 +-
 semantic_router/encoders/bm25.py | 25 +++++---
 semantic_router/hybrid_layer.py  | 10 +++-
 semantic_router/layer.py         | 10 +++-
 semantic_router/schema.py        |  2 +-
 tests/unit/encoders/test_bm25.py | 19 ++++++
 11 files changed, 164 insertions(+), 58 deletions(-)

diff --git a/Makefile b/Makefile
index 3a3c42cd..8de202fa 100644
--- a/Makefile
+++ b/Makefile
@@ -9,6 +9,7 @@ lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep
 lint lint_diff:
 	poetry run black $(PYTHON_FILES) --check
 	poetry run ruff .
+	poetry run mypy $(PYTHON_FILES)
 
 test:
 	poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100
diff --git a/README.md b/README.md
index 9dac4222..b4b3c0e3 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@
 <img alt="" src="https://img.shields.io/github/repo-size/aurelio-labs/semantic-router" />
 <img alt="GitHub Issues" src="https://img.shields.io/github/issues/aurelio-labs/semantic-router" />
 <img alt="GitHub Pull Requests" src="https://img.shields.io/github/issues-pr/aurelio-labs/semantic-router" />
+<img src="https://codecov.io/gh/aurelio-labs/semantic-router/graph/badge.svg?token=H8OOMV2TUF" />
 <img alt="Github License" src="https://img.shields.io/badge/License-MIT-yellow.svg" />
 </p>
 
diff --git a/coverage.xml b/coverage.xml
index 755c321e..8e6ca91d 100644
--- a/coverage.xml
+++ b/coverage.xml
@@ -1,5 +1,5 @@
 <?xml version="1.0" ?>
-<coverage version="7.3.2" timestamp="1702462041712" lines-valid="317" lines-covered="317" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
+<coverage version="7.3.2" timestamp="1702463592393" lines-valid="334" lines-covered="334" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
 	<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.2 -->
 	<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
 	<sources>
@@ -22,8 +22,8 @@
 						<line number="1" hits="1"/>
 						<line number="2" hits="1"/>
 						<line number="3" hits="1"/>
-						<line number="4" hits="1"/>
-						<line number="6" hits="1"/>
+						<line number="5" hits="1"/>
+						<line number="11" hits="1"/>
 						<line number="12" hits="1"/>
 						<line number="15" hits="1"/>
 						<line number="16" hits="1"/>
@@ -102,10 +102,13 @@
 						<line number="131" hits="1"/>
 						<line number="132" hits="1"/>
 						<line number="135" hits="1"/>
-						<line number="137" hits="1"/>
+						<line number="136" hits="1"/>
 						<line number="138" hits="1"/>
 						<line number="139" hits="1"/>
 						<line number="141" hits="1"/>
+						<line number="142" hits="1"/>
+						<line number="143" hits="1"/>
+						<line number="145" hits="1"/>
 					</lines>
 				</class>
 				<class name="layer.py" filename="layer.py" complexity="0" line-rate="1" branch-rate="0">
@@ -115,68 +118,73 @@
 						<line number="3" hits="1"/>
 						<line number="8" hits="1"/>
 						<line number="9" hits="1"/>
-						<line number="12" hits="1"/>
+						<line number="10" hits="1"/>
 						<line number="13" hits="1"/>
 						<line number="14" hits="1"/>
 						<line number="15" hits="1"/>
-						<line number="17" hits="1"/>
+						<line number="16" hits="1"/>
 						<line number="18" hits="1"/>
-						<line number="20" hits="1"/>
+						<line number="19" hits="1"/>
 						<line number="21" hits="1"/>
 						<line number="22" hits="1"/>
 						<line number="23" hits="1"/>
-						<line number="25" hits="1"/>
-						<line number="27" hits="1"/>
-						<line number="29" hits="1"/>
-						<line number="31" hits="1"/>
+						<line number="24" hits="1"/>
+						<line number="26" hits="1"/>
+						<line number="28" hits="1"/>
+						<line number="30" hits="1"/>
 						<line number="32" hits="1"/>
 						<line number="33" hits="1"/>
 						<line number="34" hits="1"/>
 						<line number="35" hits="1"/>
 						<line number="36" hits="1"/>
-						<line number="38" hits="1"/>
-						<line number="40" hits="1"/>
-						<line number="42" hits="1"/>
-						<line number="45" hits="1"/>
+						<line number="37" hits="1"/>
+						<line number="39" hits="1"/>
+						<line number="41" hits="1"/>
+						<line number="43" hits="1"/>
 						<line number="46" hits="1"/>
-						<line number="48" hits="1"/>
+						<line number="47" hits="1"/>
 						<line number="49" hits="1"/>
-						<line number="51" hits="1"/>
+						<line number="50" hits="1"/>
 						<line number="52" hits="1"/>
-						<line number="54" hits="1"/>
+						<line number="53" hits="1"/>
 						<line number="55" hits="1"/>
-						<line number="57" hits="1"/>
-						<line number="59" hits="1"/>
-						<line number="62" hits="1"/>
-						<line number="65" hits="1"/>
+						<line number="56" hits="1"/>
+						<line number="58" hits="1"/>
+						<line number="60" hits="1"/>
+						<line number="63" hits="1"/>
 						<line number="66" hits="1"/>
 						<line number="67" hits="1"/>
-						<line number="74" hits="1"/>
+						<line number="68" hits="1"/>
 						<line number="75" hits="1"/>
-						<line number="81" hits="1"/>
-						<line number="86" hits="1"/>
+						<line number="76" hits="1"/>
+						<line number="82" hits="1"/>
 						<line number="87" hits="1"/>
-						<line number="89" hits="1"/>
-						<line number="91" hits="1"/>
+						<line number="88" hits="1"/>
+						<line number="90" hits="1"/>
 						<line number="92" hits="1"/>
-						<line number="94" hits="1"/>
+						<line number="93" hits="1"/>
 						<line number="95" hits="1"/>
-						<line number="97" hits="1"/>
+						<line number="96" hits="1"/>
+						<line number="98" hits="1"/>
 						<line number="99" hits="1"/>
-						<line number="100" hits="1"/>
 						<line number="101" hits="1"/>
 						<line number="102" hits="1"/>
 						<line number="103" hits="1"/>
 						<line number="104" hits="1"/>
 						<line number="105" hits="1"/>
+						<line number="106" hits="1"/>
 						<line number="107" hits="1"/>
-						<line number="110" hits="1"/>
-						<line number="111" hits="1"/>
-						<line number="114" hits="1"/>
+						<line number="109" hits="1"/>
+						<line number="112" hits="1"/>
+						<line number="113" hits="1"/>
 						<line number="116" hits="1"/>
 						<line number="117" hits="1"/>
-						<line number="118" hits="1"/>
+						<line number="119" hits="1"/>
 						<line number="120" hits="1"/>
+						<line number="122" hits="1"/>
+						<line number="123" hits="1"/>
+						<line number="124" hits="1"/>
+						<line number="126" hits="1"/>
 					</lines>
 				</class>
 				<class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0">
@@ -271,31 +279,40 @@
 					<lines>
 						<line number="1" hits="1"/>
 						<line number="3" hits="1"/>
-						<line number="6" hits="1"/>
-						<line number="7" hits="1"/>
+						<line number="5" hits="1"/>
 						<line number="8" hits="1"/>
+						<line number="9" hits="1"/>
 						<line number="10" hits="1"/>
-						<line number="11" hits="1"/>
+						<line number="12" hits="1"/>
 						<line number="13" hits="1"/>
 						<line number="14" hits="1"/>
+						<line number="16" hits="1"/>
+						<line number="17" hits="1"/>
+						<line number="18" hits="1"/>
 						<line number="19" hits="1"/>
 						<line number="20" hits="1"/>
-						<line number="21" hits="1"/>
 						<line number="22" hits="1"/>
-						<line number="23" hits="1"/>
+						<line number="24" hits="1"/>
 						<line number="25" hits="1"/>
+						<line number="26" hits="1"/>
 						<line number="27" hits="1"/>
 						<line number="28" hits="1"/>
 						<line number="29" hits="1"/>
 						<line number="30" hits="1"/>
-						<line number="31" hits="1"/>
 						<line number="32" hits="1"/>
-						<line number="33" hits="1"/>
 						<line number="34" hits="1"/>
+						<line number="35" hits="1"/>
 						<line number="36" hits="1"/>
 						<line number="37" hits="1"/>
+						<line number="38" hits="1"/>
 						<line number="39" hits="1"/>
 						<line number="40" hits="1"/>
+						<line number="41" hits="1"/>
+						<line number="42" hits="1"/>
+						<line number="44" hits="1"/>
+						<line number="45" hits="1"/>
+						<line number="46" hits="1"/>
+						<line number="47" hits="1"/>
 					</lines>
 				</class>
 				<class name="cohere.py" filename="encoders/cohere.py" complexity="0" line-rate="1" branch-rate="0">
diff --git a/poetry.lock b/poetry.lock
index 3bedc8de..b459e6ba 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1065,6 +1065,53 @@ files = [
     {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
 ]
 
+[[package]]
+name = "mypy"
+version = "1.7.1"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "mypy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12cce78e329838d70a204293e7b29af9faa3ab14899aec397798a4b41be7f340"},
+    {file = "mypy-1.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1484b8fa2c10adf4474f016e09d7a159602f3239075c7bf9f1627f5acf40ad49"},
+    {file = "mypy-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31902408f4bf54108bbfb2e35369877c01c95adc6192958684473658c322c8a5"},
+    {file = "mypy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f2c2521a8e4d6d769e3234350ba7b65ff5d527137cdcde13ff4d99114b0c8e7d"},
+    {file = "mypy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcd2572dd4519e8a6642b733cd3a8cfc1ef94bafd0c1ceed9c94fe736cb65b6a"},
+    {file = "mypy-1.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b901927f16224d0d143b925ce9a4e6b3a758010673eeded9b748f250cf4e8f7"},
+    {file = "mypy-1.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7f6985d05a4e3ce8255396df363046c28bea790e40617654e91ed580ca7c51"},
+    {file = "mypy-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:944bdc21ebd620eafefc090cdf83158393ec2b1391578359776c00de00e8907a"},
+    {file = "mypy-1.7.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c7ac372232c928fff0645d85f273a726970c014749b924ce5710d7d89763a28"},
+    {file = "mypy-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:f6efc9bd72258f89a3816e3a98c09d36f079c223aa345c659622f056b760ab42"},
+    {file = "mypy-1.7.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6dbdec441c60699288adf051f51a5d512b0d818526d1dcfff5a41f8cd8b4aaf1"},
+    {file = "mypy-1.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fc3d14ee80cd22367caaaf6e014494415bf440980a3045bf5045b525680ac33"},
+    {file = "mypy-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c6e4464ed5f01dc44dc9821caf67b60a4e5c3b04278286a85c067010653a0eb"},
+    {file = "mypy-1.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d9b338c19fa2412f76e17525c1b4f2c687a55b156320acb588df79f2e6fa9fea"},
+    {file = "mypy-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:204e0d6de5fd2317394a4eff62065614c4892d5a4d1a7ee55b765d7a3d9e3f82"},
+    {file = "mypy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:84860e06ba363d9c0eeabd45ac0fde4b903ad7aa4f93cd8b648385a888e23200"},
+    {file = "mypy-1.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c5091ebd294f7628eb25ea554852a52058ac81472c921150e3a61cdd68f75a7"},
+    {file = "mypy-1.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40716d1f821b89838589e5b3106ebbc23636ffdef5abc31f7cd0266db936067e"},
+    {file = "mypy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"},
+    {file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"},
+    {file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"},
+    {file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"},
+    {file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"},
+    {file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"},
+    {file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"},
+    {file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"},
+    {file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=1.0.0"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = ">=4.1.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
 [[package]]
 name = "mypy-extensions"
 version = "1.0.0"
@@ -2055,4 +2102,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8"
+content-hash = "58bf19052f05863cb4623e85a73de5758d581ff539cfb69f0920e57f6cb035d0"
diff --git a/pyproject.toml b/pyproject.toml
index 61a95510..5a8e18e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ pytest = "^7.4.3"
 pytest-mock = "^3.12.0"
 pytest-cov = "^4.1.0"
 pytest-xdist = "^3.5.0"
+mypy = "^1.7.1"
 
 [build-system]
 requires = ["poetry-core"]
@@ -36,3 +37,6 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.ruff.per-file-ignores]
 "*.ipynb" = ["E402"]
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index b6de1f89..632ebc79 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -7,5 +7,5 @@ class BaseEncoder(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    def __call__(self, docs: list[str]) -> list[float]:
+    def __call__(self, docs: list[str]) -> list[list[float]]:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py
index 0d498197..c9da628e 100644
--- a/semantic_router/encoders/bm25.py
+++ b/semantic_router/encoders/bm25.py
@@ -1,29 +1,36 @@
+from typing import Any
+
 from pinecone_text.sparse import BM25Encoder as encoder
 
 from semantic_router.encoders import BaseEncoder
 
 
 class BM25Encoder(BaseEncoder):
-    model: encoder | None = None
+    model: Any | None = None
     idx_mapping: dict[int, int] | None = None
 
     def __init__(self, name: str = "bm25"):
         super().__init__(name=name)
-        # initialize BM25 encoder with default params (trained on MSMarco)
         self.model = encoder.default()
-        self.idx_mapping = {
-            idx: i
-            for i, idx in enumerate(self.model.get_params()["doc_freq"]["indices"])
-        }
+
+        params = self.model.get_params()
+        doc_freq = params["doc_freq"]
+        if isinstance(doc_freq, dict):
+            indices = doc_freq["indices"]
+            self.idx_mapping = {int(idx): i for i, idx in enumerate(indices)}
+        else:
+            raise TypeError("Expected a dictionary for 'doc_freq'")
 
     def __call__(self, docs: list[str]) -> list[list[float]]:
+        if self.model is None or self.idx_mapping is None:
+            raise ValueError("Model or index mapping is not initialized.")
         if len(docs) == 1:
             sparse_dicts = self.model.encode_queries(docs)
         elif len(docs) > 1:
             sparse_dicts = self.model.encode_documents(docs)
         else:
             raise ValueError("No documents to encode.")
-        # convert sparse dict to sparse vector
+
         embeds = [[0.0] * len(self.idx_mapping)] * len(docs)
         for i, output in enumerate(sparse_dicts):
             indices = output["indices"]
@@ -32,9 +39,9 @@ class BM25Encoder(BaseEncoder):
                 if idx in self.idx_mapping:
                     position = self.idx_mapping[idx]
                     embeds[i][position] = val
-                else:
-                    print(idx, "not in encoder.idx_mapping")
         return embeds
 
     def fit(self, docs: list[str]):
+        if self.model is None:
+            raise ValueError("Model is not initialized.")
         self.model.fit(docs)
diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index a0452a31..dec6336e 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -1,7 +1,6 @@
 import numpy as np
 from numpy.linalg import norm
 from tqdm.auto import tqdm
-from semantic_router.utils.logger import logger
 
 from semantic_router.encoders import (
     BaseEncoder,
@@ -10,6 +9,7 @@ from semantic_router.encoders import (
     OpenAIEncoder,
 )
 from semantic_router.schema import Route
+from semantic_router.utils.logger import logger
 
 
 class HybridRouteLayer:
@@ -118,7 +118,7 @@ class HybridRouteLayer:
         return dense, sparse
 
     def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
-        scores_by_class = {}
+        scores_by_class: dict[str, list[float]] = {}
         for result in query_results:
             score = result["score"]
             route = result["route"]
@@ -132,7 +132,11 @@ class HybridRouteLayer:
         top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
 
         # Return the top class and its associated scores
-        return str(top_class), scores_by_class.get(top_class, [])
+        if top_class is not None:
+            return str(top_class), scores_by_class.get(top_class, [])
+        else:
+            logger.warning("No classification found for semantic classifier.")
+            return "", []
 
     def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
         if scores:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index efa4862d..cb408c5c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -7,6 +7,7 @@ from semantic_router.encoders import (
 )
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.schema import Route
+from semantic_router.utils.logger import logger
 
 
 class RouteLayer:
@@ -94,10 +95,11 @@ class RouteLayer:
             routes = self.categories[idx] if self.categories is not None else []
             return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
         else:
+            logger.warning("No index found for route layer.")
             return []
 
     def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
-        scores_by_class = {}
+        scores_by_class: dict[str, list[float]] = {}
         for result in query_results:
             score = result["score"]
             route = result["route"]
@@ -111,7 +113,11 @@ class RouteLayer:
         top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
 
         # Return the top class and its associated scores
-        return str(top_class), scores_by_class.get(top_class, [])
+        if top_class is not None:
+            return str(top_class), scores_by_class.get(top_class, [])
+        else:
+            logger.warning("No classification found for semantic classifier.")
+            return "", []
 
     def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
         if scores:
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 3763db03..007cddcb 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -38,7 +38,7 @@ class Encoder:
         elif self.type == EncoderType.COHERE:
             self.model = CohereEncoder(name)
 
-    def __call__(self, texts: list[str]) -> list[float]:
+    def __call__(self, texts: list[str]) -> list[list[float]]:
         return self.model(texts)
 
 
diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py
index c1987151..e654d7bb 100644
--- a/tests/unit/encoders/test_bm25.py
+++ b/tests/unit/encoders/test_bm25.py
@@ -33,3 +33,22 @@ class TestBM25Encoder:
         assert all(
             isinstance(sublist, list) for sublist in result
         ), "Each item in result should be a list"
+
+    def test_init_with_non_dict_doc_freq(self, mocker):
+        mock_encoder = mocker.MagicMock()
+        mock_encoder.get_params.return_value = {"doc_freq": "not a dict"}
+        mocker.patch(
+            "pinecone_text.sparse.BM25Encoder.default", return_value=mock_encoder
+        )
+        with pytest.raises(TypeError):
+            BM25Encoder()
+
+    def test_call_method_with_uninitialized_model_or_mapping(self, bm25_encoder):
+        bm25_encoder.model = None
+        with pytest.raises(ValueError):
+            bm25_encoder(["test"])
+
+    def test_fit_with_uninitialized_model(self, bm25_encoder):
+        bm25_encoder.model = None
+        with pytest.raises(ValueError):
+            bm25_encoder.fit(["test"])
-- 
GitLab