From 6ab5584869d5fd446c91bb3dd62b875e85aecc74 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Thu, 1 Feb 2024 13:48:16 +0400
Subject: [PATCH] Found Promising Splitter Technique

Adds two experimental splitters, ConsecutiveAvgSimSplitter and
ConsecutiveAvgSimSplitter2. Both trigger a split when the similarity
between consecutive documents drops sharply below a running average of
the consecutive-pair similarity scores; the second variant resets that
average at every split so it tracks only the current segment.

Also temporarily commented out the splitter code in the Conversation
class due to a circular import between schema.py and the splitters
module.
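
A minimal usage sketch of the second variant (encoder choice is
illustrative; assumes an OpenAI API key is available, but any
BaseEncoder should work):

    from semantic_router.encoders import OpenAIEncoder
    from semantic_router.splitters.consecutive_sim import (
        ConsecutiveAvgSimSplitter2,
    )

    encoder = OpenAIEncoder()  # assumes OPENAI_API_KEY is set

    docs = [
        "The weather is great today.",
        "Sunshine is expected for the rest of the week.",
        "Python is my favourite programming language.",
        "Its library ecosystem is huge.",
    ]

    splitter = ConsecutiveAvgSimSplitter2(encoder=encoder)

    # A split triggers when similarity to the next document falls more
    # than drop_threshold below the current segment's running average.
    splits = splitter(docs, drop_threshold=0.1)
    for split in splits:
        print(split.docs)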
---
 semantic_router/schema.py                    | 98 +++++++++----------
 semantic_router/splitters/consecutive_sim.py | 99 ++++++++++++++++++++
 2 files changed, 148 insertions(+), 49 deletions(-)

diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 129d154a..0876bae4 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -11,9 +11,9 @@ from semantic_router.encoders import (
     OpenAIEncoder,
 )
 # from semantic_router.utils.splitters import DocumentSplit, semantic_splitter
-from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
-from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
-from semantic_router.splitters.cav_sim import CAVSimSplitter
+# from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
+# from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
+# from semantic_router.splitters.cav_sim import CAVSimSplitter
 
 
 class EncoderType(Enum):
@@ -77,60 +77,60 @@ class DocumentSplit(BaseModel):
 
 class Conversation(BaseModel):
     messages: List[Message]
-    topics: List[Tuple[int, str]] = []
-    splitter = None
+    # topics: List[Tuple[int, str]] = []
+    # splitter = None
 
     def add_new_messages(self, new_messages: List[Message]):
         self.messages.extend(new_messages)
 
-    def configure_splitter(
-        self,
-        encoder: BaseEncoder,
-        threshold: float = 0.5,
-        split_method: Literal[
-            "consecutive_similarity", "cumulative_similarity", "cav_similarity"
-        ] = "consecutive_similarity",
-    ):
-        if split_method == "consecutive_similarity":
-            self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
-        elif split_method == "cumulative_similarity":
-            self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
-        elif split_method == "cav_similarity":
-            self.splitter = CAVSimSplitter(encoder=encoder, similarity_threshold=threshold)
-        else:
-            raise ValueError(f"Invalid split method: {split_method}")
-
-    def split_by_topic(self):
-        if self.splitter is None:
-            raise ValueError("Splitter is not configured. Please call configure_splitter first.")
+    # def configure_splitter(
+    #     self,
+    #     encoder: BaseEncoder,
+    #     threshold: float = 0.5,
+    #     split_method: Literal[
+    #         "consecutive_similarity", "cumulative_similarity", "cav_similarity"
+    #     ] = "consecutive_similarity",
+    # ):
+    #     if split_method == "consecutive_similarity":
+    #         self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
+    #     elif split_method == "cumulative_similarity":
+    #         self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
+    #     elif split_method == "cav_similarity":
+    #         self.splitter = CAVSimSplitter(encoder=encoder, similarity_threshold=threshold)
+    #     else:
+    #         raise ValueError(f"Invalid split method: {split_method}")
+
+    # def split_by_topic(self):
+    #     if self.splitter is None:
+    #         raise ValueError("Splitter is not configured. Please call configure_splitter first.")
         
-        # Get the messages that haven't been clustered into topics yet
-        unclustered_messages = self.messages[len(self.topics):]
+    #     # Get the messages that haven't been clustered into topics yet
+    #     unclustered_messages = self.messages[len(self.topics):]
         
-        # Check if there are any messages that have been assigned topics
-        if len(self.topics) >= 1:
-            # Include the last message in the docs
-            docs = [self.topics[-1][1]]
-        else:
-            # No messages have been assigned topics yet
-            docs = []
+    #     # Check if there are any messages that have been assigned topics
+    #     if len(self.topics) >= 1:
+    #         # Include the last message in the docs
+    #         docs = [self.topics[-1][1]]
+    #     else:
+    #         # No messages have been assigned topics yet
+    #         docs = []
         
-        # Add the unclustered messages to the docs
-        docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
+    #     # Add the unclustered messages to the docs
+    #     docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
         
-        # Use the splitter to split the documents
-        new_topics = self.splitter(docs)
+    #     # Use the splitter to split the documents
+    #     new_topics = self.splitter(docs)
         
-        # Check if the first new topic includes the first new message.
-        # This means that the first new message shares the same topic as the last old message to have been assigned a topic.
-        if docs[-len(unclustered_messages)] in new_topics[0].docs:
-            start = self.topics[-1][0]
-        else:
-            start = len(self.topics) + 1
+    #     # Check if the first new topic includes the first new message.
+    #     # This means that the first new message shares the same topic as the last old message to have been assigned a topic.
+    #     if docs[-len(unclustered_messages)] in new_topics[0].docs:
+    #         start = self.topics[-1][0]
+    #     else:
+    #         start = self.topics[-1][0] + 1
         
-        # Add the new topics to the list of topics with unique IDs
-        for i, topic in enumerate(new_topics, start=start):
-            for message in topic.docs:
-                self.topics.append((i, message))
+    #     # Add the new topics to the list of topics with unique IDs
+    #     for i, topic in enumerate(new_topics, start=start):
+    #         for message in topic.docs:
+    #             self.topics.append((i, message))
         
-        return new_topics
\ No newline at end of file
+    #     return new_topics
\ No newline at end of file
diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py
index a9038750..b61610e5 100644
--- a/semantic_router/splitters/consecutive_sim.py
+++ b/semantic_router/splitters/consecutive_sim.py
@@ -44,4 +44,103 @@ class ConsecutiveSimSplitter(BaseSplitter):
                 curr_split_start_idx = idx
                 curr_split_num += 1
         splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
+        return splits
+    
+
+class ConsecutiveAvgSimSplitter(BaseSplitter):
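+    """Split when a consecutive-pair similarity score drops more than
+    drop_threshold below the running average of all previous
+    consecutive-pair scores (the average is computed over the whole
+    sequence and never reset).
+    """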
+    def __init__(
+        self,
+        encoder: BaseEncoder,
+        name: str = "consecutive_avg_similarity_splitter",
+        similarity_threshold: float = 0.45,
+    ):
+        super().__init__(
+            name=name,
+            similarity_threshold=similarity_threshold,
+            encoder=encoder
+        )
+
+    # drop_threshold: how far a consecutive similarity score must fall
+    # below the running average to trigger a split
+    def __call__(self, docs: List[str], drop_threshold: float = 0.1):
+        doc_embeds = self.encoder(docs)
+        norm_embeds = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)
+        sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
+        total_docs = len(docs)
+        splits = []
+        curr_split_start_idx = 0
+
+        # Calculate similarity scores between consecutive documents
+        sim_scores = [sim_matrix[i][i+1] for i in range(total_docs - 1)]
+
+        # Calculate running average of similarity scores
+        running_avg = [np.mean(sim_scores[:i+1]) for i in range(len(sim_scores))]
+
+        for idx, curr_sim_score in enumerate(sim_scores):
+            # Check for a significant drop in similarity compared to the running average
+            if idx > 0 and (running_avg[idx-1] - curr_sim_score) > drop_threshold:
+                splits.append(
+                    DocumentSplit(
+                        docs=list(docs[curr_split_start_idx:idx+1]),  # Include current doc in the split
+                        is_triggered=True,
+                        triggered_score=curr_sim_score,
+                    )
+                )
+                curr_split_start_idx = idx + 1  # Update the start index for the next split
+
+        # Add the last split
+        if curr_split_start_idx < total_docs:
+            splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
+
+        return splits
+    
+
+class ConsecutiveAvgSimSplitter2(BaseSplitter):
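+    """Like ConsecutiveAvgSimSplitter, but the running average is kept
+    per segment: the score history restarts at every split (seeded with
+    the boundary score) instead of spanning the whole sequence.
+    """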
+    def __init__(
+        self,
+        encoder: BaseEncoder,
+        name: str = "consecutive_avg_similarity_splitter_2",
+        similarity_threshold: float = 0.45,
+    ):
+        super().__init__(
+            name=name,
+            similarity_threshold=similarity_threshold,
+            encoder=encoder
+        )
+
+    # drop_threshold: how far a consecutive similarity score must fall
+    # below the current segment's running average to trigger a split
+    def __call__(self, docs: List[str], drop_threshold: float = 0.1):
+        doc_embeds = self.encoder(docs)
+        norm_embeds = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)
+        sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
+        total_docs = len(docs)
+        splits = []
+        curr_split_start_idx = 0
+
+        # Initialize an empty list to store similarity scores for the current topic segment
+        segment_sim_scores = []
+
+        for idx in range(total_docs - 1):
+            curr_sim_score = sim_matrix[idx][idx + 1]
+            segment_sim_scores.append(curr_sim_score)
+
+            # Calculate running average of similarity scores for the current segment
+            running_avg = np.mean(segment_sim_scores)
+
+            # Check for a significant drop in similarity compared to the running average
+            if idx > 0 and (running_avg - curr_sim_score) > drop_threshold:
+                splits.append(
+                    DocumentSplit(
+                        docs=list(docs[curr_split_start_idx:idx + 1]),  # Include current doc in the split
+                        is_triggered=True,
+                        triggered_score=curr_sim_score,
+                    )
+                )
+                curr_split_start_idx = idx + 1
+                # Restart the running scores for the new segment, seeded
+                # with the boundary score so its average starts from this
+                # low similarity value
+                segment_sim_scores = [curr_sim_score]
+
+        # Add the last split
+        if curr_split_start_idx < total_docs:
+            splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
+
         return splits
\ No newline at end of file
-- 
GitLab