diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py
index 9f0c4704de778b78d03371f71cc73637698213db..10b957086cadb562aaf72151e893b033d75d442d 100644
--- a/semantic_router/utils/splitters.py
+++ b/semantic_router/utils/splitters.py
@@ -1,10 +1,17 @@
-from typing import Dict, List, Literal
+from typing import List, Literal
 
 import numpy as np
+from pydantic.v1 import BaseModel
 
 from semantic_router.encoders import BaseEncoder
 
 
+class DocumentSplit(BaseModel):
+    docs: List[str]
+    is_triggered: bool = False
+    triggered_score: float | None = None
+
+
 def semantic_splitter(
     encoder: BaseEncoder,
     docs: List[str],
@@ -12,7 +19,7 @@ def semantic_splitter(
     split_method: Literal[
         "consecutive_similarity_drop", "cumulative_similarity_drop"
     ] = "consecutive_similarity_drop",
-) -> Dict[str, List[str]]:
+) -> List[DocumentSplit]:
     """
     Splits a list of documents base on semantic similarity changes.
 
@@ -33,7 +40,7 @@ def semantic_splitter(
         Dict[str, List[str]]: Splits with corresponding documents.
     """
     total_docs = len(docs)
-    splits = {}
+    splits = []
     curr_split_start_idx = 0
     curr_split_num = 1
 
@@ -43,8 +50,15 @@ def semantic_splitter(
         sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
 
         for idx in range(1, total_docs):
-            if idx < len(sim_matrix) and sim_matrix[idx - 1][idx] < threshold:
-                splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:idx]
+            curr_sim_score = sim_matrix[idx - 1][idx]
+            if idx < len(sim_matrix) and curr_sim_score < threshold:
+                splits.append(
+                    DocumentSplit(
+                        docs=docs[curr_split_start_idx:idx],
+                        is_triggered=True,
+                        triggered_score=curr_sim_score,
+                    )
+                )
                 curr_split_start_idx = idx
                 curr_split_num += 1
 
@@ -63,9 +77,13 @@ def semantic_splitter(
                 )
 
                 if similarity < threshold:
-                    splits[f"split {curr_split_num}"] = docs[
-                        curr_split_start_idx : idx + 1
-                    ]
+                    splits.append(
+                        DocumentSplit(
+                            docs=docs[curr_split_start_idx : idx + 1],
+                            is_triggered=True,
+                            triggered_score=similarity,
+                        )
+                    )
                     curr_split_start_idx = idx + 1
                     curr_split_num += 1
 
@@ -75,5 +93,5 @@ def semantic_splitter(
             " 'cumulative_similarity_drop'."
        )
 
-    splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:]
+    splits.append(DocumentSplit(docs=docs[curr_split_start_idx:]))
     return splits
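
For orientation, a minimal usage sketch of the changed API follows. It assumes OpenAIEncoder from semantic_router.encoders (any BaseEncoder subclass should work) and made-up document strings; the relevant point is that callers now iterate over a list of DocumentSplit objects instead of reading a dict keyed by "split N".

    from semantic_router.encoders import OpenAIEncoder  # assumed encoder; any BaseEncoder subclass works
    from semantic_router.utils.splitters import semantic_splitter

    # Illustrative documents only.
    docs = [
        "The cat sat on the mat.",
        "Cats are popular household pets.",
        "The stock market closed higher today.",
    ]

    encoder = OpenAIEncoder()  # expects OPENAI_API_KEY in the environment

    splits = semantic_splitter(
        encoder=encoder,
        docs=docs,
        threshold=0.5,
        split_method="consecutive_similarity_drop",
    )

    # Return value is now List[DocumentSplit] rather than Dict[str, List[str]].
    for i, split in enumerate(splits, start=1):
        print(f"split {i}: {split.docs}")
        if split.is_triggered:
            # triggered_score is only set on splits created by a similarity drop.
            print(f"  triggered at similarity {split.triggered_score:.3f}")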