Unverified commit e5183afb authored by James Briggs, committed by GitHub

Merge pull request #175 from aurelio-labs/simonas/splitter-fixes

fix: RollingWindowSplitter fixes
Parents: 2f7d92c6, 21887050
+from dataclasses import dataclass
 from typing import List

 import numpy as np

@@ -9,6 +10,31 @@ from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
 from semantic_router.utils.logger import logger


+@dataclass
+class SplitStatistics:
+    total_documents: int
+    total_splits: int
+    splits_by_threshold: int
+    splits_by_max_chunk_size: int
+    splits_by_last_split: int
+    min_token_size: int
+    max_token_size: int
+    splits_by_similarity_ratio: float
+
+    def __str__(self):
+        return (
+            f"Splitting Statistics:\n"
+            f" - Total Documents: {self.total_documents}\n"
+            f" - Total Splits: {self.total_splits}\n"
+            f" - Splits by Threshold: {self.splits_by_threshold}\n"
+            f" - Splits by Max Chunk Size: {self.splits_by_max_chunk_size}\n"
+            f" - Last Split: {self.splits_by_last_split}\n"
+            f" - Minimum Token Size of Split: {self.min_token_size}\n"
+            f" - Maximum Token Size of Split: {self.max_token_size}\n"
+            f" - Similarity Split Ratio: {self.splits_by_similarity_ratio:.2f}"
+        )
+
+
 class RollingWindowSplitter(BaseSplitter):
     def __init__(
         self,
@@ -32,6 +58,7 @@ class RollingWindowSplitter(BaseSplitter):
         self.min_split_tokens = min_split_tokens
         self.max_split_tokens = max_split_tokens
         self.split_tokens_tolerance = split_tokens_tolerance
+        self.statistics: SplitStatistics

     def encode_documents(self, docs: List[str]) -> np.ndarray:
         try:
@@ -55,15 +82,20 @@ class RollingWindowSplitter(BaseSplitter):
     def find_split_indices(self, similarities: List[float]) -> List[int]:
         split_indices = []
-        for idx in range(1, len(similarities)):
-            if similarities[idx] < self.calculated_threshold:
+        for idx, score in enumerate(similarities):
+            logger.debug(f"Similarity score at index {idx}: {score}")
+            if score < self.calculated_threshold:
+                logger.debug(
+                    f"Adding to split_indices due to score < threshold: "
+                    f"{score} < {self.calculated_threshold}"
+                )
+                # Split after the document at idx
                 split_indices.append(idx + 1)
         return split_indices

-    def find_optimal_threshold(self, docs: List[str], encoded_docs: np.ndarray):
+    def find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
         token_counts = [tiktoken_length(doc) for doc in docs]
         cumulative_token_counts = np.cumsum([0] + token_counts)
-        similarity_scores = self.calculate_similarity_scores(encoded_docs)

         # Analyze the distribution of similarity scores to set initial bounds
         median_score = np.median(similarity_scores)
@@ -74,12 +106,13 @@ class RollingWindowSplitter(BaseSplitter):
         high = min(1.0, float(median_score + std_dev))

         iteration = 0
+        median_tokens = 0
         while low <= high:
             self.calculated_threshold = (low + high) / 2
-            logger.info(
+            split_indices = self.find_split_indices(similarity_scores)
+            logger.debug(
                 f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
             )
-            split_indices = self.find_split_indices(similarity_scores)

             # Calculate the token counts for each split using the cumulative sums
             split_token_counts = [
@@ -91,7 +124,7 @@ class RollingWindowSplitter(BaseSplitter):
             # Calculate the median token count for the splits
             median_tokens = np.median(split_token_counts)
-            logger.info(
+            logger.debug(
                 f"Iteration {iteration}: Median tokens per split: {median_tokens}"
             )
             if (
@@ -99,22 +132,22 @@ class RollingWindowSplitter(BaseSplitter):
                 <= median_tokens
                 <= self.max_split_tokens + self.split_tokens_tolerance
             ):
-                logger.info(
-                    f"Iteration {iteration}: "
-                    f"Optimal threshold {self.calculated_threshold} found "
-                    f"with median tokens ({median_tokens}) in target range "
-                    f" {self.min_split_tokens}-{self.max_split_tokens}."
-                )
+                logger.debug("Median tokens in target range. Stopping iteration.")
                 break
             elif median_tokens < self.min_split_tokens:
                 high = self.calculated_threshold - self.threshold_adjustment
-                logger.info(f"Iteration {iteration}: Adjusting high to {high}")
+                logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
             else:
                 low = self.calculated_threshold + self.threshold_adjustment
-                logger.info(f"Iteration {iteration}: Adjusting low to {low}")
+                logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
             iteration += 1

-        logger.info(f"Final optimal threshold: {self.calculated_threshold}")
+        logger.info(
+            f"Optimal threshold {self.calculated_threshold} found "
+            f"with median tokens ({median_tokens}) in target range "
+            f"({self.min_split_tokens}-{self.max_split_tokens})."
+        )
         return self.calculated_threshold

     def split_documents(
@@ -132,9 +165,15 @@ class RollingWindowSplitter(BaseSplitter):
         splits, current_split = [], []
         current_tokens_count = 0

+        # Statistics
+        splits_by_threshold = 0
+        splits_by_max_chunk_size = 0
+        splits_by_last_split = 0
         for doc_idx, doc in enumerate(docs):
             doc_token_count = token_counts[doc_idx]
+            logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
+            logger.debug(f"Document token count: {doc_token_count} tokens")
             # Check if current index is a split point based on similarity
             if doc_idx + 1 in split_indices:
                 if current_tokens_count + doc_token_count >= self.min_split_tokens:
@@ -154,11 +193,12 @@ class RollingWindowSplitter(BaseSplitter):
                             token_count=current_tokens_count,
                         )
                     )
-                    logger.info(
+                    logger.debug(
                         f"Split finalized with {current_tokens_count} tokens due to "
                         f"threshold {self.calculated_threshold}."
                     )
                     current_split, current_tokens_count = [], 0
+                    splits_by_threshold += 1
                     continue  # Move to the next document after splitting

             # Check if adding the current document exceeds the max token limit
@@ -172,7 +212,8 @@ class RollingWindowSplitter(BaseSplitter):
                         token_count=current_tokens_count,
                     )
                 )
-                logger.info(
+                splits_by_max_chunk_size += 1
+                logger.debug(
                     f"Split finalized with {current_tokens_count} tokens due to "
                     f"exceeding token limit of {self.max_split_tokens}."
                 )
@@ -191,7 +232,8 @@ class RollingWindowSplitter(BaseSplitter):
                    token_count=current_tokens_count,
                )
            )
-            logger.info(
+            splits_by_last_split += 1
+            logger.debug(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )
@@ -209,10 +251,38 @@ class RollingWindowSplitter(BaseSplitter):
                 f"Token count mismatch: {original_token_count} != {split_token_count}"
             )

+        # Statistics
+        total_splits = len(splits)
+        splits_by_similarity_ratio = (
+            splits_by_threshold / total_splits if total_splits else 0
+        )
+        min_token_size = max_token_size = 0
+        if splits:
+            token_counts = [
+                split.token_count for split in splits if split.token_count is not None
+            ]
+            min_token_size, max_token_size = min(token_counts, default=0), max(
+                token_counts, default=0
+            )
+
+        self.statistics = SplitStatistics(
+            total_documents=len(docs),
+            total_splits=total_splits,
+            splits_by_threshold=splits_by_threshold,
+            splits_by_max_chunk_size=splits_by_max_chunk_size,
+            splits_by_last_split=splits_by_last_split,
+            min_token_size=min_token_size,
+            max_token_size=max_token_size,
+            splits_by_similarity_ratio=splits_by_similarity_ratio,
+        )
+
         return splits

     def plot_similarity_scores(
-        self, similarities: List[float], split_indices: List[int]
+        self,
+        similarities: List[float],
+        split_indices: List[int],
+        splits: list[DocumentSplit],
     ):
         try:
             from matplotlib import pyplot as plt
@@ -225,16 +295,18 @@ class RollingWindowSplitter(BaseSplitter):
         if not self.plot_splits:
             return

-        plt.figure(figsize=(12, 6))
-        plt.plot(similarities, label="Similarity Scores", marker="o")
+        fig, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots
+
+        # Plot 1: Similarity Scores
+        axs[0].plot(similarities, label="Similarity Scores", marker="o")
         for split_index in split_indices:
-            plt.axvline(
+            axs[0].axvline(
                 x=split_index - 1,
                 color="r",
                 linestyle="--",
                 label="Split" if split_index == split_indices[0] else "",
             )
-        plt.axhline(
+        axs[0].axhline(
             y=self.calculated_threshold,
             color="g",
             linestyle="-.",
@@ -243,7 +315,7 @@ class RollingWindowSplitter(BaseSplitter):
         # Annotating each similarity score
         for i, score in enumerate(similarities):
-            plt.annotate(
+            axs[0].annotate(
                 f"{score:.2f}",  # Formatting to two decimal places
                 (i, score),
                 textcoords="offset points",
@@ -251,16 +323,35 @@ class RollingWindowSplitter(BaseSplitter):
                 ha="center",
             )  # Center-align the text

-        plt.xlabel("Document Segment Index")
-        plt.ylabel("Similarity Score")
-        plt.title(
+        axs[0].set_xlabel("Document Segment Index")
+        axs[0].set_ylabel("Similarity Score")
+        axs[0].set_title(
             f"Threshold: {self.calculated_threshold} |"
             f" Window Size: {self.window_size}",
             loc="right",
             fontsize=10,
         )
-        plt.suptitle("Document Similarity Scores", fontsize=14)
-        plt.legend()
+        axs[0].legend()
+
+        # Plot 2: Split Token Size Distribution
+        token_counts = [split.token_count for split in splits]
+        axs[1].bar(range(len(token_counts)), token_counts, color="lightblue")
+        axs[1].set_title("Split Token Sizes")
+        axs[1].set_xlabel("Split Index")
+        axs[1].set_ylabel("Token Count")
+        axs[1].set_xticks(range(len(token_counts)))
+        axs[1].set_xticklabels([str(i) for i in range(len(token_counts))])
+        axs[1].grid(True)
+
+        # Annotate each bar with the token size
+        for idx, token_count in enumerate(token_counts):
+            if not token_count:
+                continue
+            axs[1].text(
+                idx, token_count + 0.01, str(token_count), ha="center", va="bottom"
+            )
+
+        plt.tight_layout()
         plt.show()

     def plot_sentence_similarity_scores(
@@ -323,12 +414,13 @@ class RollingWindowSplitter(BaseSplitter):
             )
             docs = split_to_sentences(docs[0])
         encoded_docs = self.encode_documents(docs)
+        similarities = self.calculate_similarity_scores(encoded_docs)
         if self.dynamic_threshold:
-            self.find_optimal_threshold(docs, encoded_docs)
+            self.find_optimal_threshold(docs, similarities)
         else:
             self.calculated_threshold = self.encoder.score_threshold
-        similarities = self.calculate_similarity_scores(encoded_docs)
         split_indices = self.find_split_indices(similarities=similarities)

         splits = self.split_documents(docs, split_indices, similarities)
-        self.plot_similarity_scores(similarities, split_indices)
+        self.plot_similarity_scores(similarities, split_indices, splits)
+        logger.info(self.statistics)
         return splits
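
For orientation, here is a minimal usage sketch (not part of this merge) showing where the new statistics surface after a run. It assumes the package's public import paths for RollingWindowSplitter and OpenAIEncoder, that the splitter is invoked through its __call__ with a list of documents (as the final hunk above implies), and purely illustrative parameter values.

# Usage sketch under the assumptions stated above; not part of the merged diff.
from semantic_router.encoders import OpenAIEncoder
from semantic_router.splitters import RollingWindowSplitter

encoder = OpenAIEncoder()  # assumes OPENAI_API_KEY is set in the environment

splitter = RollingWindowSplitter(
    encoder=encoder,
    window_size=2,         # sentence window size (shown in the plot title)
    min_split_tokens=100,  # lower bound targeted by find_optimal_threshold
    max_split_tokens=500,  # upper bound; exceeding it forces a split
    plot_splits=False,     # True renders the two diagnostic plots added here
)

# A single long document is sentence-split internally before scoring.
splits = splitter(docs=["very long document ..."])

# split_documents() now populates a SplitStatistics instance, which __call__
# also logs via logger.info(); it can be inspected directly as well.
print(splitter.statistics)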