# rolling_window.py
# (Repository-viewer page chrome -- "Skip to content", "Snippets Groups Projects",
# "Code owners", and the file-size banner -- was captured together with this
# listing; it is not part of the module.)
from typing import List

import numpy as np

from semantic_router.encoders.base import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter
from semantic_router.splitters.utils import split_to_sentences, tiktoken_length
from semantic_router.utils.logger import logger


class RollingWindowSplitter(BaseSplitter):
    def __init__(
        self,
        encoder: BaseEncoder,
        threshold_adjustment=0.01,
        window_size=5,
        min_split_tokens=100,
        max_split_tokens=300,
        split_tokens_tolerance=10,
        plot_splits=False,
        name="rolling_window_splitter",
    ):
        """Store the splitter configuration.

        Args:
            encoder: Callable invoked by ``encode_documents`` to embed texts.
            threshold_adjustment: Step by which ``find_optimal_threshold``
                moves a search bound past the threshold it just tried.
            window_size: Maximum number of preceding documents averaged into
                the context vector in ``calculate_similarity_scores``.
            min_split_tokens: Lower end of the target token range of a split.
            max_split_tokens: Upper end of the target token range of a split.
            split_tokens_tolerance: Slack added to both ends of the target
                range when ``find_optimal_threshold`` accepts a threshold.
            plot_splits: When true, ``plot_similarity_scores`` draws a chart.
            name: Forwarded to the base-class constructor.
        """
        super().__init__(name=name, encoder=encoder)

        # Declared only; the value is assigned by ``find_optimal_threshold``.
        self.calculated_threshold: float

        # Similarity / threshold-search settings.
        self.encoder = encoder
        self.window_size = window_size
        self.threshold_adjustment = threshold_adjustment

        # Token budget of a single split.
        self.min_split_tokens = min_split_tokens
        self.max_split_tokens = max_split_tokens
        self.split_tokens_tolerance = split_tokens_tolerance

        # Diagnostics.
        self.plot_splits = plot_splits

    def encode_documents(self, docs: List[str]) -> np.ndarray:
        try:
            embeddings = self.encoder(docs)
            return np.array(embeddings)
        except Exception as e:
            logger.error(f"Error encoding documents {docs}: {e}")
            raise

    def calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
        raw_similarities = []
        for idx in range(1, len(encoded_docs)):
            window_start = max(0, idx - self.window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                np.linalg.norm(cumulative_context) * np.linalg.norm(encoded_docs[idx])
                + 1e-10
            )
            raw_similarities.append(curr_sim_score)
        return raw_similarities

    def find_split_indices(self, similarities: List[float]) -> List[int]:
        split_indices = []
        for idx in range(1, len(similarities)):
            if similarities[idx] < self.calculated_threshold:
                split_indices.append(idx + 1)
        return split_indices

    def find_optimal_threshold(self, docs: List[str], encoded_docs: np.ndarray):
        """Search for a threshold whose splits have a median size in range.

        The search is a bisection over the similarity threshold. Its initial
        bounds are the median similarity score plus/minus one standard
        deviation, clipped to ``[0.0, 1.0]``. Each candidate threshold is
        turned into split indices, the median token count of the resulting
        splits is computed, and the bounds are moved by
        ``self.threshold_adjustment`` past the candidate until the median
        lands inside ``[min_split_tokens - tolerance,
        max_split_tokens + tolerance]`` or the bounds cross.

        Args:
            docs: The documents being split; used only for token counting.
            encoded_docs: Embeddings of ``docs``, in the same order.

        Returns:
            The last threshold tried, which is also stored in
            ``self.calculated_threshold``.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        cumulative_token_counts = np.cumsum([0] + token_counts)
        similarity_scores = self.calculate_similarity_scores(encoded_docs)

        if len(similarity_scores) > 0:
            # Analyze the distribution of similarity scores to set initial bounds
            median_score = np.median(similarity_scores)
            std_dev = np.std(similarity_scores)

            # Set initial bounds based on median and standard deviation
            low = max(0.0, float(median_score - std_dev))
            high = min(1.0, float(median_score + std_dev))
        else:
            # With fewer than two documents there is nothing to measure.
            # Use the full range directly rather than taking statistics of an
            # empty list, which only yields NaN and a NumPy RuntimeWarning.
            low, high = 0.0, 1.0

        # Fallback in case the clipped bounds already cross (for example when
        # every score is negative) and the loop below never runs; without it
        # ``self.calculated_threshold`` would be read while still unassigned.
        self.calculated_threshold = low

        iteration = 0
        while low <= high:
            self.calculated_threshold = (low + high) / 2
            logger.info(
                f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
            )
            split_indices = self.find_split_indices(similarity_scores)

            # Calculate the token counts for each split using the cumulative sums
            split_token_counts = [
                cumulative_token_counts[end] - cumulative_token_counts[start]
                for start, end in zip(
                    [0] + split_indices, split_indices + [len(token_counts)]
                )
            ]

            # Calculate the median token count for the splits
            median_tokens = np.median(split_token_counts)
            logger.info(
                f"Iteration {iteration}: Median tokens per split: {median_tokens}"
            )
            if (
                self.min_split_tokens - self.split_tokens_tolerance
                <= median_tokens
                <= self.max_split_tokens + self.split_tokens_tolerance
            ):
                logger.info(
                    f"Iteration {iteration}: "
                    f"Optimal threshold {self.calculated_threshold} found "
                    f"with median tokens ({median_tokens}) in target range "
                    f" {self.min_split_tokens}-{self.max_split_tokens}."
                )
                break
            elif median_tokens < self.min_split_tokens:
                # Splits are too small: lower the threshold so fewer scores
                # fall below it.
                high = self.calculated_threshold - self.threshold_adjustment
                logger.info(f"Iteration {iteration}: Adjusting high to {high}")
            else:
                # Splits are too large: raise the threshold so more scores
                # fall below it.
                low = self.calculated_threshold + self.threshold_adjustment
                logger.info(f"Iteration {iteration}: Adjusting low to {low}")
            iteration += 1

        logger.info(f"Final optimal threshold: {self.calculated_threshold}")
        return self.calculated_threshold

    def split_documents(
        self, docs: List[str], split_indices: List[int], similarities: List[float]
    ) -> List[DocumentSplit]:
        """Group ``docs`` into splits using the similarity split points.

        Documents are accumulated in order. A split is closed in two cases:

        * The document's 1-based position is in ``split_indices`` and the
          accumulated tokens including this document reach
          ``self.min_split_tokens``. The document is included in the split,
          which is marked ``is_triggered=True`` with the document's
          similarity score.
        * Adding the document would push the accumulated tokens above
          ``self.max_split_tokens`` and the tokens accumulated so far already
          reach ``self.min_split_tokens``. The split is closed without the
          document, which then starts the next split.

        Whatever remains after the last document forms a final split.

        Args:
            docs: Documents to group, in order.
            split_indices: Split points as returned by ``find_split_indices``.
            similarities: Similarity scores; ``similarities[i]`` is reported
                as the trigger score of a split closed at document ``i``.

        Returns:
            The list of ``DocumentSplit`` objects, in document order.

        Raises:
            ValueError: If the token total of the produced splits differs
                from the token total of ``docs``.
        """
        token_counts = [tiktoken_length(doc) for doc in docs]
        # Set membership is O(1); scanning the list for every document would
        # make the loop quadratic in the number of documents.
        split_points = set(split_indices)
        splits, current_split = [], []
        current_tokens_count = 0

        for doc_idx, doc in enumerate(docs):
            doc_token_count = token_counts[doc_idx]

            # Check if current index is a split point based on similarity
            if doc_idx + 1 in split_points:
                if current_tokens_count + doc_token_count >= self.min_split_tokens:
                    # The current document is included before splitting; the
                    # max token limit is not consulted on this path.
                    current_split.append(doc)
                    current_tokens_count += doc_token_count

                    triggered_score = (
                        similarities[doc_idx] if doc_idx < len(similarities) else None
                    )
                    splits.append(
                        DocumentSplit(
                            docs=current_split.copy(),
                            is_triggered=True,
                            triggered_score=triggered_score,
                            token_count=current_tokens_count,
                        )
                    )
                    logger.info(
                        f"Split finalized with {current_tokens_count} tokens due to "
                        f"threshold {self.calculated_threshold}."
                    )
                    current_split, current_tokens_count = [], 0
                    continue  # Move to the next document after splitting

            # Check if adding the current document exceeds the max token limit
            if current_tokens_count + doc_token_count > self.max_split_tokens:
                # ``current_split`` must be non-empty: with
                # ``min_split_tokens == 0`` an oversized first document would
                # otherwise produce a split that contains no documents.
                if current_split and current_tokens_count >= self.min_split_tokens:
                    splits.append(
                        DocumentSplit(
                            docs=current_split.copy(),
                            is_triggered=False,
                            triggered_score=None,
                            token_count=current_tokens_count,
                        )
                    )
                    logger.info(
                        f"Split finalized with {current_tokens_count} tokens due to "
                        f"exceeding token limit of {self.max_split_tokens}."
                    )
                    current_split, current_tokens_count = [], 0

            current_split.append(doc)
            current_tokens_count += doc_token_count

        # Handle the last split
        if current_split:
            splits.append(
                DocumentSplit(
                    docs=current_split.copy(),
                    is_triggered=False,
                    triggered_score=None,
                    token_count=current_tokens_count,
                )
            )
            logger.info(
                f"Final split added with {current_tokens_count} "
                "tokens due to remaining documents."
            )

        # Validation to ensure no tokens are lost during the split. The
        # tokens are recounted from the documents actually stored in the
        # splits rather than trusting the running totals.
        original_token_count = sum(token_counts)
        split_token_count = sum(
            tiktoken_length(doc) for split in splits for doc in split.docs
        )
        if original_token_count != split_token_count:
            logger.error(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )
            raise ValueError(
                f"Token count mismatch: {original_token_count} != {split_token_count}"
            )

        return splits

    def plot_similarity_scores(
        self, similarities: List[float], split_indices: List[int]
    ):
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning(
                "Plotting is disabled. Please `pip install "
                "semantic-router[processing]`."
            )
            return

        if not self.plot_splits:
            return
        plt.figure(figsize=(12, 6))
        plt.plot(similarities, label="Similarity Scores", marker="o")
        for split_index in split_indices:
            plt.axvline(
                x=split_index - 1,
                color="r",
                linestyle="--",
                label="Split" if split_index == split_indices[0] else "",
            )
        plt.axhline(
            y=self.calculated_threshold,
            color="g",
            linestyle="-.",
            label="Threshold Similarity Score",
        )

        # Annotating each similarity score
        for i, score in enumerate(similarities):
            plt.annotate(
                f"{score:.2f}",  # Formatting to two decimal places
                (i, score),
                textcoords="offset points",
                xytext=(0, 10),  # Positioning the text above the point
                ha="center",
            )  # Center-align the text

        plt.xlabel("Document Segment Index")
        plt.ylabel("Similarity Score")
        plt.title(
            f"Threshold: {self.calculated_threshold} |"
            f" Window Size: {self.window_size}",
            loc="right",
            fontsize=10,
        )
        plt.suptitle("Document Similarity Scores", fontsize=14)
        plt.legend()
        plt.show()

    def plot_sentence_similarity_scores(
        self, docs: List[str], threshold: float, window_size: int
    ):
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            logger.warning("Plotting is disabled. Please `pip install matplotlib`.")
            return
        """
        Computes similarity scores between the average of the last
        'window_size' sentences and the next one,
        plots a graph of these similarity scores, and prints the first
        sentence after a similarity score below
        a specified threshold.
        """
        sentences = [sentence for doc in docs for sentence in split_to_sentences(doc)]
        encoded_sentences = self.encode_documents(sentences)
        similarity_scores = []

        for i in range(window_size, len(encoded_sentences)):
            window_avg_encoding = np.mean(
                encoded_sentences[i - window_size : i], axis=0
            )
            sim_score = np.dot(window_avg_encoding, encoded_sentences[i]) / (
                np.linalg.norm(window_avg_encoding)
                * np.linalg.norm(encoded_sentences[i])
                + 1e-10
            )
            similarity_scores.append(sim_score)

        plt.figure(figsize=(10, 8))
        plt.plot(similarity_scores, marker="o", linestyle="-", color="b")
        plt.title("Sliding Window Sentence Similarity Scores")
        plt.xlabel("Sentence Index")
        plt.ylabel("Similarity Score")
        plt.grid(True)
        plt.axhline(y=threshold, color="r", linestyle="--", label="Threshold")
        plt.show()

        for i, score in enumerate(similarity_scores):
            if score < threshold:
                print(
                    f"First sentence after similarity score "
                    f"below {threshold}: {sentences[i + window_size]}"
                )

    def __call__(self, docs: List[str]) -> List[DocumentSplit]:
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        if len(docs) == 1:
            token_count = tiktoken_length(docs[0])
            if token_count > self.max_split_tokens:
                logger.warning(
                    f"Single document exceeds the maximum token limit "
                    f"of {self.max_split_tokens}. "
                    "Splitting to sentences before semantically splitting."
                )
            docs = split_to_sentences(docs[0])
        encoded_docs = self.encode_documents(docs)
        self.find_optimal_threshold(docs, encoded_docs)
        similarities = self.calculate_similarity_scores(encoded_docs)
        split_indices = self.find_split_indices(similarities=similarities)
        splits = self.split_documents(docs, split_indices, similarities)
        self.plot_similarity_scores(similarities, split_indices)
        return splits