Skip to content
Snippets Groups Projects
Unverified Commit 1dd0220b authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Removed Running Avg Sim Splitter

We need to figure out how to check if first two messages belong to the same topic before implementing.

At the moment they code is written such that the first two messages automatically are sorted into the same topic.
parent 1c1be294
Branches
Tags
No related merge requests found
from typing import List
from semantic_router.splitters.base import BaseSplitter
from semantic_router.encoders import BaseEncoder
import numpy as np
from semantic_router.schema import DocumentSplit
class RunningAvgSimSplitter(BaseSplitter):
def __init__(
self,
encoder: BaseEncoder,
name: str = "consecutive_similarity_splitter",
similarity_threshold: float = 0.04,
):
super().__init__(
name=name,
similarity_threshold=similarity_threshold,
encoder=encoder
)
def __call__(self, docs: List[str]):
doc_embeds = self.encoder(docs)
norm_embeds = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)
sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
total_docs = len(docs)
splits = []
curr_split_start_idx = 0
# Initialize an empty list to store similarity scores for the current topic segment
segment_sim_scores = []
for idx in range(total_docs - 1):
curr_sim_score = sim_matrix[idx][idx + 1]
segment_sim_scores.append(curr_sim_score)
# Calculate running average of similarity scores for the current segment
running_avg = np.mean(segment_sim_scores)
# Check for a significant drop in similarity compared to the running average
similarity_drop = running_avg - curr_sim_score
if idx > 0 and similarity_drop > self.similarity_threshold:
splits.append(
DocumentSplit(
docs=list(docs[curr_split_start_idx:idx + 1]), # Include current doc in the split
is_triggered=True,
triggered_score=similarity_drop,
)
)
curr_split_start_idx = idx + 1
# Reset the similarity scores for the new segment
segment_sim_scores = [curr_sim_score]
# Add the last split
if curr_split_start_idx < total_docs:
splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
return splits
\ No newline at end of file
......@@ -2,7 +2,6 @@ from pydantic.v1 import BaseModel, Field
from typing import Union, List, Literal, Tuple
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.splitters.running_avg_sim import RunningAvgSimSplitter
from semantic_router.encoders import BaseEncoder
from semantic_router.schema import Message
from semantic_router.schema import DocumentSplit
......@@ -40,7 +39,7 @@ class Conversation(BaseModel):
encoder: BaseEncoder,
threshold: float = 0.5,
split_method: Literal[
"consecutive_similarity", "cumulative_similarity", "running_avg_similarity"
"consecutive_similarity", "cumulative_similarity"
] = "consecutive_similarity",
):
......@@ -53,8 +52,8 @@ class Conversation(BaseModel):
:type encoder: BaseEncoder
:param threshold: The similarity threshold to be used by the splitter. Defaults to 0.5.
:type threshold: float
:param split_method: The method to be used for splitting the conversation into topics. Can be one of "consecutive_similarity", "cumulative_similarity", or "running_avg_similarity". Defaults to "consecutive_similarity".
:type split_method: Literal["consecutive_similarity", "cumulative_similarity", "running_avg_similarity"]
:param split_method: The method to be used for splitting the conversation into topics. Can be one of "consecutive_similarity" or "cumulative_similarity". Defaults to "consecutive_similarity".
:type split_method: Literal["consecutive_similarity", "cumulative_similarity"]
:raises ValueError: If an invalid split method is provided.
"""
......@@ -62,8 +61,6 @@ class Conversation(BaseModel):
self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
elif split_method == "cumulative_similarity":
self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
elif split_method == "running_avg_similarity":
self.splitter = RunningAvgSimSplitter(encoder=encoder, similarity_threshold=threshold)
else:
raise ValueError(f"Invalid split method: {split_method}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment