-
Ismail Ashraq authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
bm25.py 2.34 KiB
from typing import Any
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger
class BM25Encoder(BaseEncoder):
    """Sparse encoder that wraps ``pinecone_text.sparse.BM25Encoder``.

    The sparse ``indices``/``values`` pairs produced by the wrapped model are
    expanded into dense ``list[float]`` rows. Column positions are given by
    ``idx_mapping``, which is built from the ``doc_freq`` indices reported by
    the wrapped model's ``get_params()``.
    """

    # Wrapped pinecone-text encoder instance (set in ``__init__``).
    model: Any | None = None
    # Maps a sparse term index to its column position in the output rows.
    idx_mapping: dict[int, int] | None = None
    type: str = "sparse"

    def __init__(
        self,
        name: str = "bm25",
        score_threshold: float = 0.82,
        use_default_params: bool = True,
    ):
        """Create the encoder and optionally load the default parameters.

        Args:
            name: Encoder name forwarded to ``BaseEncoder``.
            score_threshold: Threshold forwarded to ``BaseEncoder``.
            use_default_params: When true, replace the fresh model with
                ``BM25Encoder.default()`` and build ``idx_mapping`` from it.
                When false, ``fit`` must be called before encoding.

        Raises:
            ImportError: If ``pinecone_text`` is not installed.
        """
        super().__init__(name=name, score_threshold=score_threshold)
        try:
            from pinecone_text.sparse import BM25Encoder as encoder
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError(
                "Please install pinecone-text to use BM25Encoder. "
                "You can install it with: `pip install semantic-router[hybrid]`"
            ) from err

        self.model = encoder()

        if use_default_params:
            logger.info("Downloading and initializing default sBM25 model parameters.")
            self.model = encoder.default()
            self._set_idx_mapping()

    def _set_idx_mapping(self) -> None:
        """Rebuild ``idx_mapping`` from the model's ``doc_freq`` indices.

        Raises:
            TypeError: If the ``doc_freq`` parameter is not a dictionary.
        """
        params = self.model.get_params()
        doc_freq = params["doc_freq"]
        if not isinstance(doc_freq, dict):
            raise TypeError("Expected a dictionary for 'doc_freq'")
        # The position of each term index in ``indices`` becomes its column.
        self.idx_mapping = {
            int(term_idx): position
            for position, term_idx in enumerate(doc_freq["indices"])
        }

    def __call__(self, docs: list[str]) -> list[list[float]]:
        """Encode ``docs`` into rows of length ``len(idx_mapping)``.

        A single-element input is encoded with ``encode_queries``; longer
        inputs are encoded with ``encode_documents``.

        Args:
            docs: Texts to encode.

        Returns:
            One independent row per document. Term indices that are absent
            from ``idx_mapping`` are dropped; unset columns stay ``0.0``.

        Raises:
            ValueError: If the model or index mapping is not initialized, or
                if ``docs`` is empty.
        """
        if self.model is None or self.idx_mapping is None:
            raise ValueError("Model or index mapping is not initialized.")

        n_docs = len(docs)
        if n_docs == 0:
            raise ValueError("No documents to encode.")
        if n_docs == 1:
            sparse_dicts = self.model.encode_queries(docs)
        else:
            sparse_dicts = self.model.encode_documents(docs)

        idx_mapping = self.idx_mapping
        width = len(idx_mapping)
        # Build a fresh row per document. ``[[0.0] * width] * n_docs`` would
        # repeat one shared list, so a write to any row would show up in all.
        embeds = [[0.0] * width for _ in range(n_docs)]

        for row, output in zip(embeds, sparse_dicts):
            for idx, val in zip(output["indices"], output["values"]):
                position = idx_mapping.get(idx)
                # Position 0 is valid, so compare against None explicitly.
                if position is not None:
                    row[position] = val
        return embeds

    def fit(self, docs: list[str]) -> None:
        """Fit the wrapped model on ``docs`` and rebuild ``idx_mapping``.

        Raises:
            ValueError: If the model is not initialized.
        """
        if self.model is None:
            raise ValueError("Model is not initialized.")
        self.model.fit(docs)
        self._set_idx_mapping()