Skip to content
Snippets Groups Projects
Unverified Commit 3f9c3845 authored by James Briggs's avatar James Briggs Committed by GitHub
Browse files

Merge pull request #72 from aurelio-labs/ashraq/semantic-splitter

feat: Add Semantic Splitter
parents 412d74cf dd2b5750
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,8 @@ from semantic_router.encoders import (
OpenAIEncoder,
)
from semantic_router.utils.splitters import semantic_splitter
class EncoderType(Enum):
HUGGINGFACE = "huggingface"
......@@ -41,3 +43,23 @@ class Encoder:
def __call__(self, texts: list[str]) -> list[list[float]]:
    """Embed *texts* by delegating to the wrapped encoder model.

    Returns one embedding (list of floats) per input text, in order.
    """
    embeddings = self.model(texts)
    return embeddings
class Message(BaseModel):
    """A single chat message: a speaker role paired with its text content."""

    role: str  # speaker label, e.g. "User" or "Bot"; rendered verbatim as "role: content"
    content: str  # the message text
class Conversation(BaseModel):
    """An ordered sequence of chat messages that can be segmented by topic."""

    messages: list[Message]  # conversation turns, oldest first

    def split_by_topic(
        self,
        encoder: BaseEncoder,
        threshold: float = 0.5,
        split_method: str = "consecutive_similarity_drop",
    ):
        """Group consecutive messages into topical splits.

        Each message is rendered as "<role>: <content>" and the rendered
        documents are handed to ``semantic_splitter`` unchanged.
        """
        docs = []
        for m in self.messages:
            docs.append(f"{m.role}: {m.content}")
        return semantic_splitter(
            encoder=encoder,
            docs=docs,
            threshold=threshold,
            split_method=split_method,
        )
import numpy as np
from semantic_router.encoders import BaseEncoder
def semantic_splitter(
    encoder: "BaseEncoder",
    docs: list[str],
    threshold: float,
    split_method: str = "consecutive_similarity_drop",
) -> dict[str, list[str]]:
    """
    Splits a list of documents based on semantic similarity changes.

    Method 1: "consecutive_similarity_drop" - embeds every document once and
    starts a new split whenever the cosine similarity between two consecutive
    documents drops below ``threshold``.

    Method 2: "cumulative_similarity_drop" - joins the documents accumulated in
    the current split and starts a new split when the similarity between that
    joined text and the next document drops below ``threshold``.

    Args:
        encoder (BaseEncoder): Encoder for document embeddings; called with a
            list of texts and expected to return one embedding per text.
        docs (list[str]): Documents to split.
        threshold (float): The similarity drop value that will trigger a new
            document split.
        split_method (str): The method to use for splitting.

    Returns:
        dict[str, list[str]]: Mapping of split labels ("split 1", ...) to the
        documents belonging to each split.

    Raises:
        ValueError: If ``split_method`` is not one of the two supported names.
    """
    if not docs:
        # Nothing to split; also avoids calling the encoder / np.linalg.norm
        # on an empty batch, which would raise.
        return {}

    total_docs = len(docs)
    splits = {}
    curr_split_start_idx = 0
    curr_split_num = 1

    if split_method == "consecutive_similarity_drop":
        # One encoder call for all docs; cosine similarities come from the
        # dot-product matrix of the row-normalized embeddings.
        doc_embeds = np.asarray(encoder(docs))
        norm_embeds = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)
        sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
        for idx in range(1, total_docs):
            # The length guard protects against an encoder that returns fewer
            # embeddings than documents.
            if idx < len(sim_matrix) and sim_matrix[idx - 1][idx] < threshold:
                splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:idx]
                curr_split_start_idx = idx
                curr_split_num += 1
    elif split_method == "cumulative_similarity_drop":
        for idx in range(1, total_docs):
            if idx + 1 < total_docs:
                # Compare everything accumulated so far (joined into one text)
                # against the next document.
                curr_split_docs = "\n".join(docs[curr_split_start_idx : idx + 1])
                next_doc = docs[idx + 1]
                curr_split_docs_embed = encoder([curr_split_docs])[0]
                next_doc_embed = encoder([next_doc])[0]
                similarity = np.dot(curr_split_docs_embed, next_doc_embed) / (
                    np.linalg.norm(curr_split_docs_embed)
                    * np.linalg.norm(next_doc_embed)
                )
                if similarity < threshold:
                    splits[f"split {curr_split_num}"] = docs[
                        curr_split_start_idx : idx + 1
                    ]
                    curr_split_start_idx = idx + 1
                    curr_split_num += 1
    else:
        raise ValueError(
            "Invalid 'split_method'. Choose either 'consecutive_similarity_drop' or 'cumulative_similarity_drop'."
        )

    # Whatever remains after the last detected drop forms the final split.
    splits[f"split {curr_split_num}"] = docs[curr_split_start_idx:]
    return splits
import pytest
from unittest.mock import Mock
from semantic_router.utils.splitters import semantic_splitter
from semantic_router.schema import Conversation, Message
def test_semantic_splitter_consecutive_similarity_drop():
    """The consecutive method splits where adjacent-doc similarity drops."""
    encoder = Mock()
    # First three embeddings share one direction; the last two are orthogonal
    # to them, so the only similarity drop is between doc3 and doc4.
    encoder.return_value = [[0.5, 0], [0.5, 0], [0.5, 0], [0, 0.5], [0, 0.5]]

    result = semantic_splitter(
        encoder,
        ["doc1", "doc2", "doc3", "doc4", "doc5"],
        0.5,
        "consecutive_similarity_drop",
    )

    assert result == {"split 1": ["doc1", "doc2", "doc3"], "split 2": ["doc4", "doc5"]}
def test_semantic_splitter_cumulative_similarity_drop():
    """Cumulative method: split when the joined-so-far text diverges from the next doc."""
    # Mock the BaseEncoder
    mock_encoder = Mock()
    # NOTE: `x` is the *list* of texts passed to the encoder, so these are
    # list-membership checks against whole strings: only the exact texts
    # "doc1" or "doc1\ndoc2" embed to [0.5, 0]; everything else embeds to
    # [0, 0.5]. The drop therefore fires when "doc1\ndoc2" is compared
    # against "doc3".
    mock_encoder.side_effect = (
        lambda x: [[0.5, 0]] if "doc1" in x or "doc1\ndoc2" in x else [[0, 0.5]]
    )
    docs = ["doc1", "doc2", "doc3", "doc4", "doc5"]
    threshold = 0.5
    split_method = "cumulative_similarity_drop"
    result = semantic_splitter(mock_encoder, docs, threshold, split_method)
    assert result == {"split 1": ["doc1", "doc2"], "split 2": ["doc3", "doc4", "doc5"]}
def test_semantic_splitter_invalid_method():
    """An unknown split_method must raise ValueError before any encoding."""
    encoder = Mock()
    with pytest.raises(ValueError):
        semantic_splitter(
            encoder,
            ["doc1", "doc2", "doc3", "doc4", "doc5"],
            0.5,
            "invalid_method",
        )
def test_split_by_topic():
    """split_by_topic renders messages as 'role: content' before splitting."""
    encoder = Mock()
    # Orthogonal embeddings force a split between the two messages.
    encoder.return_value = [[0.5, 0], [0, 0.5]]

    conversation = Conversation(
        messages=[
            Message(role="User", content="What is the latest news?"),
            Message(role="Bot", content="How is the weather today?"),
        ]
    )

    result = conversation.split_by_topic(
        encoder=encoder, threshold=0.5, split_method="consecutive_similarity_drop"
    )

    assert result == {
        "split 1": ["User: What is the latest news?"],
        "split 2": ["Bot: How is the weather today?"],
    }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment