diff --git a/semantic_router/splitters/rolling_window.py b/semantic_router/splitters/rolling_window.py index 092433fe3a91c5fb139e097e5fcb90256f036930..2f80ff3b7782ed4247f6624d00262715d6140e60 100644 --- a/semantic_router/splitters/rolling_window.py +++ b/semantic_router/splitters/rolling_window.py @@ -39,6 +39,7 @@ class RollingWindowSplitter(BaseSplitter): def __init__( self, encoder: BaseEncoder, + name="rolling_window_splitter", threshold_adjustment=0.01, dynamic_threshold: bool = True, window_size=5, @@ -46,7 +47,7 @@ class RollingWindowSplitter(BaseSplitter): max_split_tokens=300, split_tokens_tolerance=10, plot_splits=False, - name="rolling_window_splitter", + enable_statistics=False, ): super().__init__(name=name, encoder=encoder) self.calculated_threshold: float @@ -58,6 +59,7 @@ class RollingWindowSplitter(BaseSplitter): self.min_split_tokens = min_split_tokens self.max_split_tokens = max_split_tokens self.split_tokens_tolerance = split_tokens_tolerance + self.enable_statistics = enable_statistics self.statistics: SplitStatistics def __call__(self, docs: List[str]) -> List[DocumentSplit]: @@ -88,8 +90,13 @@ class RollingWindowSplitter(BaseSplitter): self.calculated_threshold = self.encoder.score_threshold split_indices = self._find_split_indices(similarities=similarities) splits = self._split_documents(docs, split_indices, similarities) - self.plot_similarity_scores(similarities, split_indices, splits) - logger.info(self.statistics) + + if self.plot_splits: + self.plot_similarity_scores(similarities, split_indices, splits) + + if self.enable_statistics: + print(self.statistics) + return splits def _encode_documents(self, docs: List[str]) -> np.ndarray: @@ -174,7 +181,7 @@ class RollingWindowSplitter(BaseSplitter): logger.debug(f"Iteration {iteration}: Adjusting low to {low}") iteration += 1 - logger.info( + logger.debug( f"Optimal threshold {self.calculated_threshold} found " f"with median tokens ({median_tokens}) in target range " f"({self.min_split_tokens}-{self.max_split_tokens})." @@ -325,9 +332,7 @@ class RollingWindowSplitter(BaseSplitter): ) return - if not self.plot_splits: - return - fig, axs = plt.subplots(2, 1, figsize=(12, 12)) # Adjust for two plots + _, axs = plt.subplots(2, 1, figsize=(12, 12)) # Adjust for two plots # Plot 1: Similarity Scores axs[0].plot(similarities, label="Similarity Scores", marker="o")