diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 68150cb767faec1a8c3c5417f9d5f6db3a977cda..b1298b373e2420aacee1621634333e3f33c3c7f2 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -9,7 +9,12 @@ class BM25Encoder(BaseEncoder): idx_mapping: dict[int, int] | None = None type: str = "sparse" - def __init__(self, name: str = "bm25", score_threshold: float = 0.82): + def __init__( + self, + name: str = "bm25", + score_threshold: float = 0.82, + use_default_params: bool = True, + ): super().__init__(name=name, score_threshold=score_threshold) try: from pinecone_text.sparse import BM25Encoder as encoder @@ -18,9 +23,15 @@ class BM25Encoder(BaseEncoder): "Please install pinecone-text to use BM25Encoder. " "You can install it with: `pip install semantic-router[hybrid]`" ) - logger.info("Downloading and initializing BM25 model parameters.") - self.model = encoder.default() + self.model = encoder() + + if use_default_params: + logger.info("Downloading and initializing default sBM25 model parameters.") + self.model = encoder.default() + self._set_idx_mapping() + + def _set_idx_mapping(self): params = self.model.get_params() doc_freq = params["doc_freq"] if isinstance(doc_freq, dict): @@ -53,3 +64,4 @@ class BM25Encoder(BaseEncoder): if self.model is None: raise ValueError("Model is not initialized.") self.model.fit(docs) + self._set_idx_mapping()