Commit ad2a7a8f authored by Simonas

fix: Split index error

parent 004620b8
+from dataclasses import dataclass
 from typing import List

 import numpy as np
@@ -9,6 +10,31 @@ from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
 from semantic_router.utils.logger import logger
+
+@dataclass
+class SplitStatistics:
+    total_documents: int
+    total_splits: int
+    splits_by_threshold: int
+    splits_by_max_chunk_size: int
+    splits_by_last_split: int
+    min_token_size: int
+    max_token_size: int
+    splits_by_similarity_ratio: float
+
+    def __str__(self):
+        return (
+            f"Splitting Statistics:\n"
+            f" - Total Documents: {self.total_documents}\n"
+            f" - Total Splits: {self.total_splits}\n"
+            f" - Splits by Threshold: {self.splits_by_threshold}\n"
+            f" - Splits by Max Chunk Size: {self.splits_by_max_chunk_size}\n"
+            f" - Last Split: {self.splits_by_last_split}\n"
+            f" - Minimum Token Size of Split: {self.min_token_size}\n"
+            f" - Maximum Token Size of Split: {self.max_token_size}\n"
+            f" - Similarity Split Ratio: {self.splits_by_similarity_ratio:.2f}"
+        )
+

 class RollingWindowSplitter(BaseSplitter):
     def __init__(
         self,
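The new SplitStatistics dataclass replaces the ad-hoc f-string summary that split_documents built inline (removed further down), and its __str__ reproduces the same report with a document count added. A minimal sketch of how it behaves, with made-up numbers:

    stats = SplitStatistics(
        total_documents=12,
        total_splits=4,
        splits_by_threshold=3,
        splits_by_max_chunk_size=0,
        splits_by_last_split=1,
        min_token_size=85,
        max_token_size=470,
        splits_by_similarity_ratio=0.75,
    )
    print(stats)  # renders the multi-line "Splitting Statistics:" report from __str__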
@@ -32,6 +58,7 @@ class RollingWindowSplitter(BaseSplitter):
         self.min_split_tokens = min_split_tokens
         self.max_split_tokens = max_split_tokens
         self.split_tokens_tolerance = split_tokens_tolerance
+        self.statistics: SplitStatistics

     def encode_documents(self, docs: List[str]) -> np.ndarray:
         try:
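Note that self.statistics: SplitStatistics is a bare annotation: it records the intended type but assigns nothing, so the attribute only exists once split_documents sets it. A tiny standalone illustration of that Python behaviour:

    class Demo:
        def __init__(self) -> None:
            self.value: int  # annotation only; no value is stored

    Demo().value  # raises AttributeError until something assigns it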
@@ -55,15 +82,20 @@ class RollingWindowSplitter(BaseSplitter):
     def find_split_indices(self, similarities: List[float]) -> List[int]:
         split_indices = []
-        for idx in range(1, len(similarities)):
-            if similarities[idx] < self.calculated_threshold:
-                split_indices.append(idx)
+        for idx, score in enumerate(similarities):
+            logger.debug(f"Similarity score at index {idx}: {score}")
+            if score < self.calculated_threshold:
+                logger.debug(
+                    f"Adding to split_indices due to score < threshold: "
+                    f"{score} < {self.calculated_threshold}"
+                )
+                # Split after the document at idx
+                split_indices.append(idx + 1)
         return split_indices

-    def find_optimal_threshold(self, docs: List[str], encoded_docs: np.ndarray):
+    def find_optimal_threshold(self, docs: List[str], similarity_scores: List[float]):
         token_counts = [tiktoken_length(doc) for doc in docs]
         cumulative_token_counts = np.cumsum([0] + token_counts)
-        similarity_scores = self.calculate_similarity_scores(encoded_docs)

         # Analyze the distribution of similarity scores to set initial bounds
         median_score = np.median(similarity_scores)
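This is the commit's namesake fix: similarities[idx] scores the boundary between document idx and document idx + 1, so the comment's "split after the document at idx" translates to appending idx + 1, and enumerating from zero stops the old range(1, ...) loop from skipping the very first boundary. A minimal sketch of the corrected indexing (scores are illustrative):

    similarities = [0.91, 0.22, 0.87]  # score i compares document i with document i+1
    threshold = 0.5
    split_indices = [i + 1 for i, s in enumerate(similarities) if s < threshold]
    print(split_indices)  # [2] -> the topic shifts after the second document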
@@ -77,10 +109,10 @@ class RollingWindowSplitter(BaseSplitter):
         median_tokens = 0
         while low <= high:
             self.calculated_threshold = (low + high) / 2
+            split_indices = self.find_split_indices(similarity_scores)
             logger.debug(
                 f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
             )
-            split_indices = self.find_split_indices(similarity_scores)

             # Calculate the token counts for each split using the cumulative sums
             split_token_counts = [
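The surrounding loop (mostly elided here) binary-searches the threshold: probe the midpoint of the current bounds, recompute split indices there, and tighten the bounds until chunk sizes land in the target band; this hunk only moves the find_split_indices call ahead of the debug log. A hedged sketch of the general mechanism, where median_chunk_tokens and the exact update rule are hypothetical stand-ins for the elided code:

    low, high, step = 0.2, 0.8, 0.01   # illustrative bounds from the score distribution
    while low <= high:
        threshold = (low + high) / 2
        median_tokens = median_chunk_tokens(threshold)  # hypothetical helper
        if median_tokens < min_split_tokens:
            high = threshold - step    # chunks too small: split less aggressively
        elif median_tokens > max_split_tokens:
            low = threshold + step     # chunks too large: split more aggressively
        else:
            break                      # median chunk size is within tolerance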
@@ -140,7 +172,8 @@ class RollingWindowSplitter(BaseSplitter):
         for doc_idx, doc in enumerate(docs):
             doc_token_count = token_counts[doc_idx]
             logger.debug(f"Accumulative token count: {current_tokens_count} tokens")
             logger.debug(f"Document token count: {doc_token_count} tokens")
+            # Check if current index is a split point based on similarity
             if doc_idx + 1 in split_indices:
                 if current_tokens_count + doc_token_count >= self.min_split_tokens:
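The nested condition puts a floor under chunk sizes: a similarity split point is honoured only once the running chunk has accumulated at least min_split_tokens. A standalone illustration of that rule (all numbers made up):

    min_split_tokens = 50
    token_counts = [30, 30, 40]   # tokens per sentence
    split_indices = [2]           # similarity says: split after the second sentence
    current = 0
    for i, n in enumerate(token_counts):
        current += n
        if i + 1 in split_indices and current >= min_split_tokens:
            print(f"split after sentence {i}: {current} tokens")  # 60 tokens
            current = 0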
@@ -226,17 +259,16 @@ class RollingWindowSplitter(BaseSplitter):
         min_token_size = min(split.token_count for split in splits) if splits else 0
         max_token_size = max(split.token_count for split in splits) if splits else 0

-        statistics_summary = (
-            f"Splitting Statistics:\n"
-            f" - Total Splits: {total_splits}\n"
-            f" - Splits by Threshold: {splits_by_threshold}\n"
-            f" - Splits by Max Chunk Size: {splits_by_max_chunk_size}\n"
-            f" - Last Split: {splits_by_last_split}\n"
-            f" - Minimum Token Size of Split: {min_token_size}\n"
-            f" - Maximum Token Size of Split: {max_token_size}\n"
-            f" - Similarity Split Ratio: {splits_by_similarity_ratio:.2f}"
-        )
-        logger.info(statistics_summary)
+        self.statistics = SplitStatistics(
+            total_documents=len(docs),
+            total_splits=total_splits,
+            splits_by_threshold=splits_by_threshold,
+            splits_by_max_chunk_size=splits_by_max_chunk_size,
+            splits_by_last_split=splits_by_last_split,
+            min_token_size=min_token_size,
+            max_token_size=max_token_size,
+            splits_by_similarity_ratio=splits_by_similarity_ratio,
+        )

         return splits
@@ -376,12 +408,13 @@ class RollingWindowSplitter(BaseSplitter):
             )
         docs = split_to_sentences(docs[0])
         encoded_docs = self.encode_documents(docs)
+        similarities = self.calculate_similarity_scores(encoded_docs)
         if self.dynamic_threshold:
-            self.find_optimal_threshold(docs, encoded_docs)
+            self.find_optimal_threshold(docs, similarities)
         else:
             self.calculated_threshold = self.encoder.score_threshold
-        similarities = self.calculate_similarity_scores(encoded_docs)
         split_indices = self.find_split_indices(similarities=similarities)
         splits = self.split_documents(docs, split_indices, similarities)
         self.plot_similarity_scores(similarities, split_indices, splits)
+        logger.info(self.statistics)
         return splits
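Hoisting calculate_similarity_scores above the branch means the scores are computed once and shared by both the dynamic-threshold search and the final split, and find_optimal_threshold no longer re-derives them from the raw encodings. A rough usage sketch of the resulting flow; the constructor arguments below are inferred from the visible assignments (the full signature is truncated in this diff), and OpenAIEncoder stands in for any BaseEncoder:

    from semantic_router.encoders import OpenAIEncoder
    from semantic_router.splitters import RollingWindowSplitter

    splitter = RollingWindowSplitter(
        encoder=OpenAIEncoder(),   # assumes an OpenAI API key in the environment
        min_split_tokens=100,
        max_split_tokens=500,
    )
    splits = splitter(["One long document, passed as a single-element list ..."])
    print(splitter.statistics)     # the SplitStatistics object is now kept on the instance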
@@ -2,8 +2,9 @@ import re
 from typing import Any

 from colorama import Fore, Style
-from semantic_router.splitters import RollingWindowSplitter
 from semantic_router.encoders import BaseEncoder
+from semantic_router.splitters import RollingWindowSplitter


 class UnstructuredSemanticSplitter:
@@ -126,4 +127,4 @@ class UnstructuredSemanticSplitter:
         return chunks_with_title

     async def __call__(self, elements: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        return await self.split_grouped_elements(elements, self.splitter)
\ No newline at end of file
+        return await self.split_grouped_elements(elements, self.splitter)
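UnstructuredSemanticSplitter.__call__ is a coroutine, so it must be awaited; the only change in this hunk is adding the missing trailing newline (plus the import reorder above). A hedged calling sketch, where the splitter's construction and the shape of elements are both assumptions since neither appears in this diff:

    import asyncio

    # hypothetical: elements as produced by an Unstructured-style partition step
    elements = [{"type": "NarrativeText", "text": "First paragraph ..."}]

    async def main() -> None:
        # semantic_splitter: a configured UnstructuredSemanticSplitter instance
        chunks = await semantic_splitter(elements)
        print(len(chunks))

    asyncio.run(main())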