diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 5dc4f5e87fba43d15e5e1ec39b578e3f37a54cfe..07b460b15c65088eeca6bffe394da8751ccb92ff 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, List, Literal, Optional
+from typing import List, Literal, Optional
 
 from pydantic.v1 import BaseModel
 from pydantic.v1.dataclasses import dataclass
@@ -10,7 +10,7 @@ from semantic_router.encoders import (
     FastEmbedEncoder,
     OpenAIEncoder,
 )
-from semantic_router.utils.splitters import semantic_splitter
+from semantic_router.utils.splitters import semantic_splitter, DocumentSplit
 
 
 class EncoderType(Enum):
@@ -77,7 +77,7 @@ class Conversation(BaseModel):
         split_method: Literal[
             "consecutive_similarity_drop", "cumulative_similarity_drop"
         ] = "consecutive_similarity_drop",
-    ) -> Dict[str, List[str]]:
+    ) -> list[DocumentSplit]:
         docs = [f"{m.role}: {m.content}" for m in self.messages]
         return semantic_splitter(
             encoder=encoder, docs=docs, threshold=threshold, split_method=split_method
diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py
index 9f0c4704de778b78d03371f71cc73637698213db..0f3dc3410c2e6a0cb640b22dd7c111febd0f95cf 100644
--- a/semantic_router/utils/splitters.py
+++ b/semantic_router/utils/splitters.py
@@ -1,10 +1,17 @@
-from typing import Dict, List, Literal
+from typing import List, Literal, Optional
 
 import numpy as np
+from pydantic.v1 import BaseModel
 
 from semantic_router.encoders import BaseEncoder
 
 
+class DocumentSplit(BaseModel):
+    docs: List[str]
+    is_triggered: bool = False
+    triggered_score: Optional[float] = None
+
+
 def semantic_splitter(
     encoder: BaseEncoder,
     docs: List[str],
@@ -12,7 +19,7 @@ def semantic_splitter(
     split_method: Literal[
         "consecutive_similarity_drop", "cumulative_similarity_drop"
     ] = "consecutive_similarity_drop",
-) -> Dict[str, List[str]]:
+) -> List[DocumentSplit]:
     """
     Splits a list of documents base on semantic similarity changes.
 
@@ -33,7 +40,7 @@ def semantic_splitter(
         Dict[str, List[str]]: Splits with corresponding documents.
     """
     total_docs = len(docs)
-    splits = {}
+    splits = []
     curr_split_start_idx = 0
     curr_split_num = 1
 
@@ -43,8 +50,15 @@ def semantic_splitter(
         sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
 
         for idx in range(1, total_docs):
-            if idx < len(sim_matrix) and sim_matrix[idx - 1][idx] < threshold:
-                splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:idx]
+            curr_sim_score = sim_matrix[idx - 1][idx]
+            if idx < len(sim_matrix) and curr_sim_score < threshold:
+                splits.append(
+                    DocumentSplit(
+                        docs=docs[curr_split_start_idx:idx],
+                        is_triggered=True,
+                        triggered_score=curr_sim_score,
+                    )
+                )
                 curr_split_start_idx = idx
                 curr_split_num += 1
 
@@ -57,15 +71,19 @@ def semantic_splitter(
             curr_split_docs_embed = encoder([curr_split_docs])[0]
             next_doc_embed = encoder([next_doc])[0]
 
-            similarity = np.dot(curr_split_docs_embed, next_doc_embed) / (
+            curr_sim_score = np.dot(curr_split_docs_embed, next_doc_embed) / (
                 np.linalg.norm(curr_split_docs_embed)
                 * np.linalg.norm(next_doc_embed)
             )
 
-            if similarity < threshold:
-                splits[f"split {curr_split_num}"] = docs[
-                    curr_split_start_idx : idx + 1
-                ]
+            if curr_sim_score < threshold:
+                splits.append(
+                    DocumentSplit(
+                        docs=docs[curr_split_start_idx : idx + 1],
+                        is_triggered=True,
+                        triggered_score=curr_sim_score,
+                    )
+                )
                 curr_split_start_idx = idx + 1
                 curr_split_num += 1
 
@@ -75,5 +93,5 @@ def semantic_splitter(
             " 'cumulative_similarity_drop'."
) - splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:] + splits.append(DocumentSplit(docs=docs[curr_split_start_idx:])) return splits diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index ac9c037c7985bbce54144de4c6e4e7096162dc19..f0e8e3f32734f988e2d54a878c26f7f5f1994228 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -17,7 +17,8 @@ def test_semantic_splitter_consecutive_similarity_drop(): result = semantic_splitter(mock_encoder, docs, threshold, split_method) - assert result == {"split 1": ["doc1", "doc2", "doc3"], "split 2": ["doc4", "doc5"]} + assert result[0].docs == ["doc1", "doc2", "doc3"] + assert result[1].docs == ["doc4", "doc5"] def test_semantic_splitter_cumulative_similarity_drop(): @@ -33,7 +34,8 @@ def test_semantic_splitter_cumulative_similarity_drop(): result = semantic_splitter(mock_encoder, docs, threshold, split_method) - assert result == {"split 1": ["doc1", "doc2"], "split 2": ["doc3", "doc4", "doc5"]} + assert result[0].docs == ["doc1", "doc2"] + assert result[1].docs == ["doc3", "doc4", "doc5"] def test_semantic_splitter_invalid_method(): @@ -62,7 +64,5 @@ def test_split_by_topic(): encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop" ) - assert result == { - "split 1": ["User: What is the latest news?"], - "split 2": ["Bot: How is the weather today?"], - } + assert result[0].docs == ["User: What is the latest news?"] + assert result[1].docs == ["Bot: How is the weather today?"]