From 53ae469f37efc47191a992bd8c0ac8126b7c5633 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Thu, 1 Feb 2024 17:21:05 +0400
Subject: [PATCH] Fixed Bugs with Splitter in Conversation

split_by_topic was re-adding the last message from the previous split to the output of the current split, resulting in duplicate messages in the topics list.

I also made the code more readable.

Also moved Conversation from schema.py to text.py to avoid cyclic imports.
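
Example of the intended flow (illustrative; `encoder` stands in for any
configured BaseEncoder and the message batches are placeholders):

    convo = Conversation()
    convo.configure_splitter(encoder=encoder, threshold=0.5)
    convo.add_new_messages(first_batch)
    convo.split_by_topic()   # assigns topics to the new messages
    convo.add_new_messages(second_batch)
    convo.split_by_topic()   # previously re-emitted the last assigned message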
---
 semantic_router/schema.py            |  68 +-------------
 semantic_router/splitters/cav_sim.py |  86 ------------------
 semantic_router/text.py              |  99 ++++++++++++++++++++++++
 3 files changed, 101 insertions(+), 152 deletions(-)
 delete mode 100644 semantic_router/splitters/cav_sim.py
 create mode 100644 semantic_router/text.py

diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 0876bae4..6fcca01f 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Literal, Optional, Tuple
+from typing import List, Optional
 
 from pydantic.v1 import BaseModel
 from pydantic.v1.dataclasses import dataclass
@@ -10,10 +10,6 @@ from semantic_router.encoders import (
     FastEmbedEncoder,
     OpenAIEncoder,
 )
-# from semantic_router.utils.splitters import DocumentSplit, semantic_splitter
-# from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
-# from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
-# from semantic_router.splitters.cav_sim import CAVSimSplitter
 
 
 class EncoderType(Enum):
@@ -73,64 +69,4 @@ class Message(BaseModel):
 class DocumentSplit(BaseModel):
     docs: List[str]
     is_triggered: bool = False
-    triggered_score: Optional[float] = None
-
-class Conversation(BaseModel):
-    messages: List[Message]
-    # topics: List[Tuple[int, str]] = []
-    # splitter = None
-
-    def add_new_messages(self, new_messages: List[Message]):
-        self.messages.extend(new_messages)
-
-    # def configure_splitter(
-    #     self,
-    #     encoder: BaseEncoder,
-    #     threshold: float = 0.5,
-    #     split_method: Literal[
-    #         "consecutive_similarity", "cumulative_similarity", "cav_similarity"
-    #     ] = "consecutive_similarity",
-    # ):
-    #     if split_method == "consecutive_similarity":
-    #         self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
-    #     elif split_method == "cumulative_similarity":
-    #         self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
-    #     elif split_method == "cav_similarity":
-    #         self.splitter = CAVSimSplitter(encoder=encoder, similarity_threshold=threshold)
-    #     else:
-    #         raise ValueError(f"Invalid split method: {split_method}")
-
-    # def split_by_topic(self):
-    #     if self.splitter is None:
-    #         raise ValueError("Splitter is not configured. Please call configure_splitter first.")
-        
-    #     # Get the messages that haven't been clustered into topics yet
-    #     unclustered_messages = self.messages[len(self.topics):]
-        
-    #     # Check if there are any messages that have been assigned topics
-    #     if len(self.topics) >= 1:
-    #         # Include the last message in the docs
-    #         docs = [self.topics[-1][1]]
-    #     else:
-    #         # No messages have been assigned topics yet
-    #         docs = []
-        
-    #     # Add the unclustered messages to the docs
-    #     docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
-        
-    #     # Use the splitter to split the documents
-    #     new_topics = self.splitter(docs)
-        
-    #     # Check if the first new topic includes the first new message.
-    #     # This means that the first new message shares the same topic as the last old message to have been assigned a topic.
-    #     if docs[-len(unclustered_messages)] in new_topics[0].docs:
-    #         start = self.topics[-1][0]
-    #     else:
-    #         start = self.topics[-1][0] + 1
-        
-    #     # Add the new topics to the list of topics with unique IDs
-    #     for i, topic in enumerate(new_topics, start=start):
-    #         for message in topic.docs:
-    #             self.topics.append((i, message))
-        
-    #     return new_topics
\ No newline at end of file
+    triggered_score: Optional[float] = None
\ No newline at end of file
diff --git a/semantic_router/splitters/cav_sim.py b/semantic_router/splitters/cav_sim.py
deleted file mode 100644
index dd5c4a53..00000000
--- a/semantic_router/splitters/cav_sim.py
+++ /dev/null
@@ -1,86 +0,0 @@
-from typing import List
-from semantic_router.splitters.base import BaseSplitter
-import numpy as np
-from semantic_router.schema import DocumentSplit
-from semantic_router.encoders import BaseEncoder
-
-class CAVSimSplitter(BaseSplitter):
-    
-    """
-    The CAVSimSplitter class is a document splitter that uses the concept of Cumulative Average Vectors (CAV) to determine where to split a sequence of documents based on their semantic similarity.
-
-    For example, consider a sequence of documents [A, B, C, D, E, F]. The CAVSimSplitter works as follows:
-
-    1. It starts with the first document (A) and calculates the cosine similarity between the embedding of A and the average embedding of the next two documents (B, C) if they exist, or the next one document (B) if only one exists.
-    - Cosine Similarity: cos_sim(A, avg(B, C))
-
-    2. It then moves to the next document (B), calculates the average embedding of the current documents (A, B), and calculates the cosine similarity with the average embedding of the next two documents (C, D) if they exist, or the next one document (C) if only one exists.
-    - Cosine Similarity: cos_sim(avg(A, B), avg(C, D))
-
-    3. This process continues, with the average embedding being calculated for the current cumulative documents and the next one or two documents. For example, at document C:
-    - Cosine Similarity: cos_sim(avg(A, B, C), avg(D, E))
-
-    4. If the similarity score between the average embedding of the current cumulative documents and the average embedding of the next one or two documents falls below the specified similarity threshold, a split is triggered. In our example, let's say the similarity score falls below the threshold between the average of documents A, B, C and the average of D, E. The splitter will then create a split, resulting in two groups of documents: [A, B, C] and [D, E].
-
-    5. After a split occurs, the process restarts with the next document in the sequence. For example, after the split between C and D, the process restarts with D and calculates the cosine similarity between the embedding of D and the average embedding of the next two documents if they exist.
-    - Cosine Similarity: cos_sim(D, avg(E, F))
-
-    6. Then we start accumulating and averaging from the left again. On the right there is only one more document left, F:
-    - Cosine Similarity: cos_sim(avg(D, E), F)
-
-    7. The process continues until all documents have been processed.
-
-    The result is a list of DocumentSplit objects, each representing a group of semantically similar documents.
-    """
-
-    def __init__(
-        self,
-        encoder: BaseEncoder,
-        name: str = "cav_similarity_splitter",
-        similarity_threshold: float = 0.45,
-    ):
-        super().__init__(
-            name=name, 
-            similarity_threshold=similarity_threshold,
-            encoder=encoder
-            )
-
-    def __call__(self, docs: List[str]):
-        total_docs = len(docs)
-        splits = []
-        curr_split_start_idx = 0
-        curr_split_num = 1
-        doc_embeds = self.encoder(docs)
-
-        for idx in range(1, total_docs):
-            curr_split_docs_embeds = doc_embeds[curr_split_start_idx : idx + 1]
-            avg_embedding = np.mean(curr_split_docs_embeds, axis=0)
-
-            # Compute the average embedding for the next two documents, if available
-            if idx + 3 <= total_docs:  # Check if the next two indices are within the range
-                next_doc_embeds = doc_embeds[idx + 1 : idx + 3]
-                next_avg_embed = np.mean(next_doc_embeds, axis=0)
-            elif idx + 2 <= total_docs:  # Check if the next index is within the range
-                next_avg_embed = doc_embeds[idx + 1]
-            else:
-                next_avg_embed = None
-
-            if next_avg_embed is not None:
-                curr_sim_score = np.dot(avg_embedding, next_avg_embed) / (
-                    np.linalg.norm(avg_embedding)
-                    * np.linalg.norm(next_avg_embed)
-                )
-
-                if curr_sim_score < self.similarity_threshold:
-                    splits.append(
-                        DocumentSplit(
-                            docs=list(docs[curr_split_start_idx : idx + 1]),
-                            is_triggered=True,
-                            triggered_score=curr_sim_score,
-                        )
-                    )
-                    curr_split_start_idx = idx + 1
-                    curr_split_num += 1
-
-        splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
-        return splits
\ No newline at end of file
diff --git a/semantic_router/text.py b/semantic_router/text.py
new file mode 100644
index 00000000..0847b739
--- /dev/null
+++ b/semantic_router/text.py
@@ -0,0 +1,99 @@
+from pydantic.v1 import BaseModel, Field
+from typing import Union, List, Literal, Tuple
+from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
+from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
+from semantic_router.splitters.running_avg_sim import RunningAvgSimSplitter
+from semantic_router.encoders import BaseEncoder
+from semantic_router.schema import Message
+
+# Define a type alias for the splitter to simplify the annotation
+SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, RunningAvgSimSplitter, None]
+
+class Conversation(BaseModel):
+    messages: List[Message] = Field(default_factory=list)  # default to an empty list
+    topics: List[Tuple[int, str]] = []
+    splitter: SplitterType = None
+
+    def add_new_messages(self, new_messages: List[Message]):
+        self.messages.extend(new_messages)
+
+    def remove_topics(self):
+        self.topics = []
+
+    def configure_splitter(
+        self,
+        encoder: BaseEncoder,
+        threshold: float = 0.5,
+        split_method: Literal[
+            "consecutive_similarity", "cumulative_similarity", "running_avg_similarity"
+        ] = "consecutive_similarity",
+    ):
+        if split_method == "consecutive_similarity":
+            self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
+        elif split_method == "cumulative_similarity":
+            self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
+        elif split_method == "running_avg_similarity":
+            self.splitter = RunningAvgSimSplitter(encoder=encoder, similarity_threshold=threshold)
+        else:
+            raise ValueError(f"Invalid split method: {split_method}")
+
+    def split_by_topic(self):
+        if self.splitter is None:
+            raise ValueError("Splitter is not configured. Please call configure_splitter first.")
+        new_topics = []
+
+        # Get the messages that haven't been clustered into topics yet
+        unclustered_messages = self.messages[len(self.topics):]
+
+        # If there are no unclustered messages, return early
+        if not unclustered_messages:
+            return self.topics, new_topics
+
+        # Extract the last topic ID and message from the previous splitting, if they exist.
+        if self.topics:
+            last_topic_id_from_last_splitting, last_message_from_last_splitting = self.topics[-1]
+        else:
+            last_topic_id_from_last_splitting, last_message_from_last_splitting = None, None
+
+        # Initialize docs with the last message from the previous splitting, if it exists
+        docs = [last_message_from_last_splitting] if last_message_from_last_splitting is not None else []
+
+        # Add the unclustered messages to the docs
+        docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
+
+        # Use the splitter to split the documents
+        new_topics = self.splitter(docs)
+
+        # Ensure there are new topics before proceeding
+        if not new_topics:
+            return self.topics, []
+
+        # Determine the starting topic ID for the new topics.
+        if self.topics:
+            # If the first new topic still contains the last message assigned in the
+            # previous splitting, the new messages continue that topic, so reuse its ID.
+            if last_message_from_last_splitting is not None and last_message_from_last_splitting in new_topics[0].docs:
+                start = last_topic_id_from_last_splitting
+            else:
+                start = self.topics[-1][0] + 1
+        else:
+            start = 0  # No previous topics, so start topic IDs from 0
+
+        # The seeded last message can only reappear at the head of the first new topic; drop it there to avoid duplicates
+        if self.topics and new_topics[0].docs[0] == self.topics[-1][1]:
+            new_topics[0].docs.pop(0)
+
+        # Add the new topics to the list of topics with unique IDs
+        for i, topic in enumerate(new_topics, start=start):
+            for message in topic.docs:
+                self.topics.append((i, message))
+
+        return self.topics, new_topics
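+
+# Example usage (illustrative sketch; the concrete encoder is an assumption):
+#
+#     from semantic_router.encoders import OpenAIEncoder
+#     convo = Conversation()
+#     convo.add_new_messages([Message(role="user", content="Hello!")])
+#     convo.configure_splitter(encoder=OpenAIEncoder(), threshold=0.5)
+#     topics, new_topics = convo.split_by_topic()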
-- 
GitLab