Unverified commit 12af7a50, authored by James Briggs, committed by GitHub

Merge pull request #220 from aurelio-labs/simonas/reduce-logging-in-splitter

chore: Reduction of logs for splitter
Parents: 2adf1245 679c1b7f
@@ -39,6 +39,7 @@ class RollingWindowSplitter(BaseSplitter):
     def __init__(
         self,
         encoder: BaseEncoder,
+        name="rolling_window_splitter",
         threshold_adjustment=0.01,
         dynamic_threshold: bool = True,
         window_size=5,
@@ -46,7 +47,7 @@ class RollingWindowSplitter(BaseSplitter):
         max_split_tokens=300,
         split_tokens_tolerance=10,
         plot_splits=False,
-        name="rolling_window_splitter",
+        enable_statistics=False,
     ):
         super().__init__(name=name, encoder=encoder)
         self.calculated_threshold: float
@@ -58,6 +59,7 @@ class RollingWindowSplitter(BaseSplitter):
         self.min_split_tokens = min_split_tokens
         self.max_split_tokens = max_split_tokens
         self.split_tokens_tolerance = split_tokens_tolerance
+        self.enable_statistics = enable_statistics
         self.statistics: SplitStatistics
 
     def __call__(self, docs: List[str]) -> List[DocumentSplit]:
@@ -88,8 +90,13 @@ class RollingWindowSplitter(BaseSplitter):
         self.calculated_threshold = self.encoder.score_threshold
         split_indices = self._find_split_indices(similarities=similarities)
         splits = self._split_documents(docs, split_indices, similarities)
-        self.plot_similarity_scores(similarities, split_indices, splits)
-        logger.info(self.statistics)
+
+        if self.plot_splits:
+            self.plot_similarity_scores(similarities, split_indices, splits)
+
+        if self.enable_statistics:
+            print(self.statistics)
+
         return splits
 
     def _encode_documents(self, docs: List[str]) -> np.ndarray:
@@ -174,7 +181,7 @@ class RollingWindowSplitter(BaseSplitter):
                 logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
             iteration += 1
 
-        logger.info(
+        logger.debug(
             f"Optimal threshold {self.calculated_threshold} found "
             f"with median tokens ({median_tokens}) in target range "
             f"({self.min_split_tokens}-{self.max_split_tokens})."
@@ -325,9 +332,7 @@ class RollingWindowSplitter(BaseSplitter):
             )
             return
-        if not self.plot_splits:
-            return
-        fig, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots
+        _, axs = plt.subplots(2, 1, figsize=(12, 12))  # Adjust for two plots
 
         # Plot 1: Similarity Scores
         axs[0].plot(similarities, label="Similarity Scores", marker="o")
...
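
Net effect of the diff: the plot_splits check moves out of plot_similarity_scores() and into __call__, split statistics are printed only when the new enable_statistics flag is set, and the optimal-threshold message drops from info to debug level. A minimal usage sketch of the updated constructor follows; the import paths and the OpenAIEncoder choice are assumptions for illustration, not taken from this commit.

    # Sketch only: import paths and encoder choice are assumptions, not from this commit.
    from semantic_router.encoders import OpenAIEncoder
    from semantic_router.splitters import RollingWindowSplitter

    encoder = OpenAIEncoder()  # requires an OpenAI API key in the environment

    splitter = RollingWindowSplitter(
        encoder=encoder,
        name="rolling_window_splitter",  # now listed before the threshold arguments
        plot_splits=False,               # set True to draw the similarity-score plots
        enable_statistics=False,         # set True to print SplitStatistics after splitting
    )

    # Splitting a batch of documents returns a List[DocumentSplit], as before.
    splits = splitter(["First long document...", "Second long document..."])

With both flags left at their False defaults, calling the splitter produces splits without plotting or printing statistics.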