From 7abb86d532fa2853f64bc97e6f605f6d214999bc Mon Sep 17 00:00:00 2001 From: Ismail Ashraq <issey1455@gmail.com> Date: Tue, 9 Jan 2024 18:26:10 +0500 Subject: [PATCH] make default bm25 params optional --- semantic_router/encoders/bm25.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 68150cb7..b1298b37 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -9,7 +9,12 @@ class BM25Encoder(BaseEncoder): idx_mapping: dict[int, int] | None = None type: str = "sparse" - def __init__(self, name: str = "bm25", score_threshold: float = 0.82): + def __init__( + self, + name: str = "bm25", + score_threshold: float = 0.82, + use_default_params: bool = True, + ): super().__init__(name=name, score_threshold=score_threshold) try: from pinecone_text.sparse import BM25Encoder as encoder @@ -18,9 +23,15 @@ class BM25Encoder(BaseEncoder): "Please install pinecone-text to use BM25Encoder. " "You can install it with: `pip install semantic-router[hybrid]`" ) - logger.info("Downloading and initializing BM25 model parameters.") - self.model = encoder.default() + self.model = encoder() + + if use_default_params: + logger.info("Downloading and initializing default BM25 model parameters.") + self.model = encoder.default() + self._set_idx_mapping() + + def _set_idx_mapping(self): params = self.model.get_params() doc_freq = params["doc_freq"] if isinstance(doc_freq, dict): @@ -53,3 +64,4 @@ class BM25Encoder(BaseEncoder): if self.model is None: raise ValueError("Model is not initialized.") self.model.fit(docs) + self._set_idx_mapping() -- GitLab