Commit 1eb0fc70 authored by Simonas

chore: Added comments

parent 6f9fc0e2
@@ -18,99 +18,89 @@ class DynamicCumulativeSplitter(BaseSplitter):
         self,
         encoder: BaseEncoder,
         name: str = "dynamic_cumulative_similarity_splitter",
-        score_threshold: float = 0.3,
+        score_threshold: float = 0.9,
     ):
         super().__init__(name=name, encoder=encoder, score_threshold=score_threshold)
+        # Log the initialization details
         logger.info(
             f"Initialized {self.name} with score threshold: {self.score_threshold}"
         )
 
     def encode_documents(self, docs: List[str]) -> np.ndarray:
+        # Encode the documents using the provided encoder and return as a numpy array
         encoded_docs = self.encoder(docs)
-        encoded_docs_array = np.array(encoded_docs)
         logger.info(f"Encoded {len(docs)} documents")
-        return encoded_docs_array
+        return np.array(encoded_docs)
 
     def adjust_threshold(self, similarities):
-        if len(similarities) > 5:
-            # Calculate recent mean similarity and standard deviation
-            recent_similarities = similarities[-5:]
-            mean_similarity = np.mean(recent_similarities)
-            std_dev_similarity = np.std(recent_similarities)
-            logger.debug(
-                f"Recent mean similarity: {mean_similarity}, "
-                f"std dev: {std_dev_similarity}"
-            )
-
-            # Calculate the rate of change (delta) for mean
-            # similarity and standard deviation
-            if len(similarities) > 10:
-                previous_similarities = similarities[-10:-5]
-                previous_mean_similarity = np.mean(previous_similarities)
-                previous_std_dev_similarity = np.std(previous_similarities)
-                delta_mean = mean_similarity - previous_mean_similarity
-                delta_std_dev = std_dev_similarity - previous_std_dev_similarity
-            else:
-                delta_mean = delta_std_dev = 0
-
-            # Adjust the threshold based on the deviation from the mean similarity
-            # and the rate of change in mean similarity and standard deviation
-            adjustment_factor = (
-                std_dev_similarity + abs(delta_mean) + abs(delta_std_dev)
-            )
-            adjusted_threshold = mean_similarity - adjustment_factor
-
-            # Dynamically set the lower bound based on the rate of change
-            dynamic_lower_bound = max(0.2, 0.2 + delta_mean - delta_std_dev)
-
-            # Introduce a minimum split threshold that is higher than the
-            # dynamic lower bound
-            min_split_threshold = 0.3
-
-            # Ensure the new threshold is within a sensible range,
-            # dynamically adjusting the lower bound
-            # and considering the minimum split threshold
-            new_threshold = max(
-                np.clip(adjusted_threshold, dynamic_lower_bound, self.score_threshold),
-                min_split_threshold,
-            )
-            logger.debug(
-                f"Adjusted threshold to {new_threshold}, with dynamic lower "
-                f"bound {dynamic_lower_bound}"
-            )
-            return new_threshold
-        return self.score_threshold
+        # Adjust the similarity threshold based on recent similarities
+        if len(similarities) <= 5:
+            # If not enough data, return the default score threshold
+            return self.score_threshold
+
+        # Calculate mean and standard deviation of the last 5 similarities
+        recent_similarities = similarities[-5:]
+        mean_similarity, std_dev_similarity = np.mean(recent_similarities), np.std(
+            recent_similarities
+        )
+
+        # Calculate the change in mean and standard deviation if enough data is
+        # available
+        delta_mean = delta_std_dev = 0
+        if len(similarities) > 10:
+            previous_similarities = similarities[-10:-5]
+            delta_mean = mean_similarity - np.mean(previous_similarities)
+            delta_std_dev = std_dev_similarity - np.std(previous_similarities)
+
+        # Adjust the threshold based on the calculated metrics
+        adjustment_factor = std_dev_similarity + abs(delta_mean) + abs(delta_std_dev)
+        adjusted_threshold = mean_similarity - adjustment_factor
+        dynamic_lower_bound = max(0.2, 0.2 + delta_mean - delta_std_dev)
+        min_split_threshold = 0.3
+
+        # Ensure the new threshold is within a sensible range
+        new_threshold = max(
+            np.clip(adjusted_threshold, dynamic_lower_bound, self.score_threshold),
+            min_split_threshold,
+        )
+        logger.debug(
+            f"Adjusted threshold to {new_threshold}, with dynamic lower "
+            f"bound {dynamic_lower_bound}"
+        )
+        return new_threshold
 
     def calculate_dynamic_context_similarity(self, encoded_docs):
-        split_indices = [0]
-        similarities = []
-        dynamic_window_size = 5  # Starting window size
-
-        norms = np.linalg.norm(
-            encoded_docs, axis=1
-        )  # Pre-calculate norms for efficiency
+        # Calculate the dynamic context similarity to determine split indices
+        split_indices, similarities = [0], []
+        dynamic_window_size = 5  # Initial window size
+
+        norms = np.linalg.norm(encoded_docs, axis=1)
         for idx in range(1, len(encoded_docs)):
-            # Adjust window size based on recent variability
+            # Adjust window size based on the standard deviation of recent similarities
             if len(similarities) > 10:
                 std_dev_recent = np.std(similarities[-10:])
                 dynamic_window_size = 5 if std_dev_recent < 0.05 else 10
 
+            # Calculate the similarity for the current document
             window_start = max(0, idx - dynamic_window_size)
             cumulative_context = np.mean(encoded_docs[window_start:idx], axis=0)
             cumulative_norm = np.linalg.norm(cumulative_context)
             curr_sim_score = np.dot(cumulative_context, encoded_docs[idx]) / (
                 cumulative_norm * norms[idx] + 1e-10
             )
             similarities.append(curr_sim_score)
 
-            dynamic_threshold = self.adjust_threshold(similarities)
-            if curr_sim_score < dynamic_threshold:
+            # If the similarity is below the dynamically adjusted threshold,
+            # mark a new split
+            if curr_sim_score < self.adjust_threshold(similarities):
                 split_indices.append(idx)
 
         return split_indices, similarities
 
     def __call__(self, docs: List[str]):
+        # Main method to split the documents
         logger.info(f"Splitting {len(docs)} documents")
         encoded_docs = self.encode_documents(docs)
         split_indices, similarities = self.calculate_dynamic_context_similarity(
@@ -118,6 +108,7 @@ class DynamicCumulativeSplitter(BaseSplitter):
         )
         splits = []
+        # Create DocumentSplit objects for each identified split
         last_idx = 0
         for idx in split_indices:
             if idx == 0:
...
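
For context, a minimal usage sketch of the splitter changed in this commit. It is hypothetical and not part of the diff: the stand-in DummyEncoder and the sample documents are assumptions, and a real caller would pass a concrete BaseEncoder subclass (if BaseSplitter validates the encoder type, subclass BaseEncoder instead of using a plain class).

import numpy as np

class DummyEncoder:
    # Stand-in encoder for illustration only: maps each document to a
    # fixed-size random vector. A real BaseEncoder would return semantic
    # embeddings, so the split points produced here are not meaningful.
    def __call__(self, docs):
        rng = np.random.default_rng(0)
        return [rng.normal(size=128) for _ in docs]

splitter = DynamicCumulativeSplitter(encoder=DummyEncoder())
docs = [
    "First sentence about topic A.",
    "More detail on topic A.",
    "An unrelated sentence about topic B.",
]
encoded = splitter.encode_documents(docs)
split_indices, similarities = splitter.calculate_dynamic_context_similarity(encoded)
print(split_indices)  # always starts with 0; later indices mark new segments

With fewer than six similarity scores accumulated, adjust_threshold falls back to the default score_threshold (now 0.9), which is why short inputs like this one split aggressively; the dynamic adjustment only takes effect on longer documents.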