import numpy as np
from semantic_router.encoders import BaseEncoder


def semantic_splitter(
    encoder: BaseEncoder,
    docs: list[str],
    threshold: float,
    split_method: str = "consecutive_similarity_drop",
) -> dict[str, list[str]]:
    """
    Splits a list of documents based on semantic similarity changes.

    Method 1: "consecutive_similarity_drop" - Compares each document with the
    document immediately before it (cosine similarity of their embeddings)
    and starts a new split whenever that similarity is below ``threshold``.

    Method 2: "cumulative_similarity_drop" - Compares the embedding of all
    documents in the current split (joined with newlines) against the next
    document and starts a new split whenever that similarity is below
    ``threshold``.

    Args:
        encoder (BaseEncoder): Encoder for document embeddings. It is called
            with a list of strings and is expected to return one embedding
            per string.
        docs (list[str]): Documents to split, in order.
        threshold (float): Cosine-similarity value below which a new split
            is started.
        split_method (str): The method to use for splitting.

    Returns:
        dict[str, list[str]]: Splits keyed "split 1", "split 2", ... in
        document order; every document appears in exactly one split.

    Raises:
        ValueError: If ``split_method`` is not a supported method.
    """
    if split_method not in (
        "consecutive_similarity_drop",
        "cumulative_similarity_drop",
    ):
        raise ValueError(
            "Invalid 'split_method'. Choose either 'consecutive_similarity_drop' or 'cumulative_similarity_drop'."
        )

    total_docs = len(docs)
    splits: dict[str, list[str]] = {}
    curr_split_start_idx = 0
    curr_split_num = 1

    # With fewer than two documents there is nothing to compare; skip the
    # encoder (the row-wise normalisation below needs a 2-D embedding matrix).
    if total_docs < 2:
        splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:]
        return splits

    # Both methods need the individual document embeddings; encode in one batch.
    doc_embeds = np.asarray(encoder(docs))
    doc_norms = np.linalg.norm(doc_embeds, axis=1)

    if split_method == "consecutive_similarity_drop":
        norm_embeds = doc_embeds / doc_norms[:, np.newaxis]
        # Only neighbouring pairs are needed: row-wise dot products of
        # (doc i-1, doc i) instead of the full N x N similarity matrix.
        consecutive_sims = np.sum(norm_embeds[:-1] * norm_embeds[1:], axis=1)
        for idx, similarity in enumerate(consecutive_sims, start=1):
            if similarity < threshold:
                splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:idx]
                curr_split_start_idx = idx
                curr_split_num += 1
    else:
        # Compare the current split docs[start..idx] with docs[idx + 1].
        # Starting at idx = 0 lets docs[1] begin a new split as well.
        for idx in range(total_docs - 1):
            if idx == curr_split_start_idx:
                # A one-document split joins to that document's own text,
                # so its batch embedding can be reused.
                curr_split_docs_embed = doc_embeds[idx]
                curr_split_norm = doc_norms[idx]
            else:
                curr_split_docs = "\n".join(docs[curr_split_start_idx : idx + 1])
                curr_split_docs_embed = np.asarray(encoder([curr_split_docs])[0])
                curr_split_norm = np.linalg.norm(curr_split_docs_embed)
            similarity = np.dot(curr_split_docs_embed, doc_embeds[idx + 1]) / (
                curr_split_norm * doc_norms[idx + 1]
            )
            if similarity < threshold:
                splits[f"split {curr_split_num}"] = docs[
                    curr_split_start_idx : idx + 1
                ]
                curr_split_start_idx = idx + 1
                curr_split_num += 1

    # Whatever remains after the last split point forms the final split.
    splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:]
    return splits