from typing import List, Tuple

import numpy as np

from semantic_router.encoders import BaseEncoder
from semantic_router.schema import DocumentSplit
from semantic_router.splitters.base import BaseSplitter
from semantic_router.utils.logger import logger


class DynamicCumulativeSplitter(BaseSplitter):
    """
    Splits documents dynamically based on the cumulative similarity of document
    embeddings, adjusting thresholds and window sizes based on recent similarities.
    """

    def __init__(
        self,
        encoder: BaseEncoder,
        name: str = "dynamic_cumulative_similarity_splitter",
        score_threshold: float = 0.9,
    ):
        super().__init__(name=name, encoder=encoder, score_threshold=score_threshold)
        # Log the initialization details
        logger.info(
            f"Initialized {self.name} with score threshold: {self.score_threshold}"
        )

    def encode_documents(self, docs: List[str]) -> np.ndarray:
        # Encode the documents using the provided encoder and return as a numpy array
        encoded_docs = self.encoder(docs)
        logger.info(f"Encoded {len(docs)} documents")
        return np.array(encoded_docs)

    def adjust_threshold(self, similarities: List[float]) -> float:
        # Adjust the similarity threshold based on recent similarities
        if len(similarities) <= 5:
            # If not enough data, return the default score threshold
            return self.score_threshold

        # Calculate mean and standard deviation of the last 5 similarities
        recent_similarities = similarities[-5:]
        mean_similarity = np.mean(recent_similarities)
        std_dev_similarity = np.std(recent_similarities)

        # Calculate the change in mean and standard deviation if enough data is
        # available
        delta_mean = delta_std_dev = 0
        if len(similarities) > 10:
            previous_similarities = similarities[-10:-5]
            delta_mean = mean_similarity - np.mean(previous_similarities)
            delta_std_dev = std_dev_similarity - np.std(previous_similarities)

        # Adjust the threshold based on the calculated metrics
        adjustment_factor = std_dev_similarity + abs(delta_mean) + abs(delta_std_dev)
        adjusted_threshold = mean_similarity - adjustment_factor
        dynamic_lower_bound = max(0.2, 0.2 + delta_mean - delta_std_dev)
        min_split_threshold = 0.3

        # Ensure the new threshold is within a sensible range
        new_threshold = max(
            np.clip(adjusted_threshold, dynamic_lower_bound, self.score_threshold),
            min_split_threshold,
        )
        logger.debug(
            f"Adjusted threshold to {new_threshold}, with dynamic lower "
            f"bound {dynamic_lower_bound}"
        )
        return new_threshold

    def calculate_dynamic_context_similarity(
        self, encoded_docs: np.ndarray
    ) -> Tuple[List[int], List[float]]:
        # Calculate the dynamic context similarity to determine split indices
        split_indices, similarities = [0], []
        dynamic_window_size = 5  # Initial window size
        norms = np.linalg.norm(
            encoded_docs, axis=1
        )  # Pre-calculate norms for efficiency

        for idx in range(1, len(encoded_docs)):
            # Adjust window size based on the standard deviation of recent similarities
            if len(similarities) > 10:
                std_dev_recent = np.std(similarities[-10:])
                dynamic_window_size = 5 if std_dev_recent < 0.05 else 10

            # Calculate the similarity for the current document
            window_start = max(0, idx - dynamic_window_size)
            cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
            cumulative_norm = np.linalg.norm(cumulative_context)
            curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                cumulative_norm * norms[idx] + 1e-10
            )

            similarities.append(curr_sim_score)
            # If the similarity is below the dynamically adjusted threshold,
            # mark a new split
            if curr_sim_score < self.adjust_threshold(similarities):
                split_indices.append(idx)

        return split_indices, similarities

    def __call__(self, docs: List[str]) -> List[DocumentSplit]:
        # Main method to split the documents
        logger.info(f"Splitting {len(docs)} documents")
        encoded_docs = self.encode_documents(docs)
        split_indices, similarities = self.calculate_dynamic_context_similarity(
            encoded_docs
        )
        splits = []

        # Create DocumentSplit objects for each identified split
        last_idx = 0
        for idx in split_indices:
            if idx == 0:
                continue
            splits.append(
                DocumentSplit(
                    docs=docs[last_idx:idx],
                    is_triggered=(idx - last_idx > 1),
                    triggered_score=(
                        similarities[idx - 1] if idx - 1 < len(similarities) else None
                    ),
                )
            )
            last_idx = idx
        splits.append(
            DocumentSplit(
                docs=docs[last_idx:],
                is_triggered=(len(docs) - last_idx > 1),
                triggered_score=similarities[-1] if similarities else None,
            )
        )
        logger.info(f"Completed splitting documents into {len(splits)} splits")

        return splits
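

if __name__ == "__main__":
    # Minimal usage sketch, not part of the splitter itself. It assumes an
    # OpenAIEncoder is available from semantic_router.encoders and that the
    # required API key is configured in the environment; substitute any
    # BaseEncoder subclass you actually use.
    from semantic_router.encoders import OpenAIEncoder

    example_docs = [
        "The cat sat on the mat.",
        "Cats are common household pets.",
        "The stock market closed higher today.",
        "Investors reacted to the latest earnings reports.",
    ]

    splitter = DynamicCumulativeSplitter(encoder=OpenAIEncoder())
    for split in splitter(example_docs):
        # Each DocumentSplit groups consecutive documents whose cumulative
        # similarity stayed above the dynamically adjusted threshold.
        print(split.docs, split.is_triggered, split.triggered_score)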