Skip to content
Snippets Groups Projects
Commit cf729799 authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

return triggered flag in response

parent 8597362b
No related branches found
No related tags found
No related merge requests found
from typing import Dict, List, Literal, Optional

import numpy as np
from pydantic.v1 import BaseModel

from semantic_router.encoders import BaseEncoder
class DocumentSplit(BaseModel):
    """One contiguous group of documents produced by the semantic splitter."""

    # Documents belonging to this split, in their original order.
    docs: List[str]
    # True when the split was closed because a similarity score fell below
    # the threshold; a split built without that trigger keeps the default.
    is_triggered: bool = False
    # Similarity score that caused the split, or None when not triggered.
    # Spelled with Optional (not ``float | None``) because the annotation is
    # evaluated at class-creation time and ``|`` on types fails before 3.10.
    triggered_score: Optional[float] = None
def semantic_splitter(
encoder: BaseEncoder,
docs: List[str],
......@@ -12,7 +19,7 @@ def semantic_splitter(
split_method: Literal[
"consecutive_similarity_drop", "cumulative_similarity_drop"
] = "consecutive_similarity_drop",
) -> Dict[str, List[str]]:
) -> List[DocumentSplit]:
"""
Splits a list of documents based on semantic similarity changes.
......@@ -33,7 +40,7 @@ def semantic_splitter(
List[DocumentSplit]: Splits with corresponding documents.
"""
total_docs = len(docs)
splits = {}
splits = []
curr_split_start_idx = 0
curr_split_num = 1
......@@ -43,8 +50,15 @@ def semantic_splitter(
sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
for idx in range(1, total_docs):
if idx < len(sim_matrix) and sim_matrix[idx - 1][idx] < threshold:
splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:idx]
curr_sim_score = sim_matrix[idx - 1][idx]
if idx < len(sim_matrix) and curr_sim_score < threshold:
splits.append(
DocumentSplit(
docs=docs[curr_split_start_idx:idx],
is_triggered=True,
triggered_score=curr_sim_score,
)
)
curr_split_start_idx = idx
curr_split_num += 1
......@@ -63,9 +77,13 @@ def semantic_splitter(
)
if similarity < threshold:
splits[f"split {curr_split_num}"] = docs[
curr_split_start_idx : idx + 1
]
splits.append(
DocumentSplit(
docs=docs[curr_split_start_idx : idx + 1],
is_triggered=True,
triggered_score=curr_sim_score,
)
)
curr_split_start_idx = idx + 1
curr_split_num += 1
......@@ -75,5 +93,5 @@ def semantic_splitter(
" 'cumulative_similarity_drop'."
)
splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:]
splits.append(DocumentSplit(docs=docs[curr_split_start_idx:]))
return splits
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment