From cf72979985e436148ef6b410a37e17296c689788 Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <issey1455@gmail.com>
Date: Wed, 24 Jan 2024 16:54:25 +0500
Subject: [PATCH] return triggered flag in response

---
 semantic_router/utils/splitters.py | 38 ++++++++++++++++++++++--------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py
index 9f0c4704..10b95708 100644
--- a/semantic_router/utils/splitters.py
+++ b/semantic_router/utils/splitters.py
@@ -1,10 +1,17 @@
-from typing import Dict, List, Literal
+from typing import List, Literal, Optional
 
 import numpy as np
+from pydantic.v1 import BaseModel
 
 from semantic_router.encoders import BaseEncoder
 
 
+class DocumentSplit(BaseModel):
+    docs: List[str]
+    is_triggered: bool = False
+    triggered_score: Optional[float] = None
+
+
 def semantic_splitter(
     encoder: BaseEncoder,
     docs: List[str],
@@ -12,7 +19,7 @@ def semantic_splitter(
     split_method: Literal[
         "consecutive_similarity_drop", "cumulative_similarity_drop"
     ] = "consecutive_similarity_drop",
-) -> Dict[str, List[str]]:
+) -> List[DocumentSplit]:
     """
     Splits a list of documents base on semantic similarity changes.
 
@@ -33,7 +40,7 @@ def semantic_splitter(
-        Dict[str, List[str]]: Splits with corresponding documents.
+        List[DocumentSplit]: Splits with corresponding documents.
     """
     total_docs = len(docs)
-    splits = {}
+    splits = []
     curr_split_start_idx = 0
     curr_split_num = 1
 
@@ -43,8 +50,15 @@ def semantic_splitter(
         sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
 
         for idx in range(1, total_docs):
-            if idx < len(sim_matrix) and sim_matrix[idx - 1][idx] < threshold:
-                splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:idx]
+            if idx < len(sim_matrix) and sim_matrix[idx - 1][idx] < threshold:
+                curr_sim_score = sim_matrix[idx - 1][idx]
+                splits.append(
+                    DocumentSplit(
+                        docs=docs[curr_split_start_idx:idx],
+                        is_triggered=True,
+                        triggered_score=curr_sim_score,
+                    )
+                )
                 curr_split_start_idx = idx
                 curr_split_num += 1
 
@@ -63,9 +77,13 @@ def semantic_splitter(
             )
 
             if similarity < threshold:
-                splits[f"split {curr_split_num}"] = docs[
-                    curr_split_start_idx : idx + 1
-                ]
+                splits.append(
+                    DocumentSplit(
+                        docs=docs[curr_split_start_idx : idx + 1],
+                        is_triggered=True,
+                        triggered_score=similarity,
+                    )
+                )
                 curr_split_start_idx = idx + 1
                 curr_split_num += 1
 
@@ -75,5 +93,5 @@ def semantic_splitter(
             " 'cumulative_similarity_drop'."
         )
 
-    splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:]
+    splits.append(DocumentSplit(docs=docs[curr_split_start_idx:]))
     return splits
-- 
GitLab