Unverified Commit 53ae469f authored by Siraj R Aizlewood

Fixed Bugs with Splitter in Conversation

The split_by_topic method was adding the last message from the previous split to the current split's output, resulting in duplicate messages.

I also made the code more readable.

Also moved Conversation from schema.py to text.py to avoid cyclic imports.
parent 417a2b21
from enum import Enum
from typing import List, Literal, Optional, Tuple
from typing import List, Optional
from pydantic.v1 import BaseModel
from pydantic.v1.dataclasses import dataclass
@@ -10,10 +10,6 @@ from semantic_router.encoders import (
FastEmbedEncoder,
OpenAIEncoder,
)
# from semantic_router.utils.splitters import DocumentSplit, semantic_splitter
# from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
# from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
# from semantic_router.splitters.cav_sim import CAVSimSplitter
class EncoderType(Enum):
@@ -73,64 +69,4 @@ class Message(BaseModel):
class DocumentSplit(BaseModel):
docs: List[str]
is_triggered: bool = False
triggered_score: Optional[float] = None
class Conversation(BaseModel):
messages: List[Message]
# topics: List[Tuple[int, str]] = []
# splitter = None
def add_new_messages(self, new_messages: List[Message]):
self.messages.extend(new_messages)
# def configure_splitter(
# self,
# encoder: BaseEncoder,
# threshold: float = 0.5,
# split_method: Literal[
# "consecutive_similarity", "cumulative_similarity", "cav_similarity"
# ] = "consecutive_similarity",
# ):
# if split_method == "consecutive_similarity":
# self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
# elif split_method == "cumulative_similarity":
# self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
# elif split_method == "cav_similarity":
# self.splitter = CAVSimSplitter(encoder=encoder, similarity_threshold=threshold)
# else:
# raise ValueError(f"Invalid split method: {split_method}")
# def split_by_topic(self):
# if self.splitter is None:
# raise ValueError("Splitter is not configured. Please call configure_splitter first.")
# # Get the messages that haven't been clustered into topics yet
# unclustered_messages = self.messages[len(self.topics):]
# # Check if there are any messages that have been assigned topics
# if len(self.topics) >= 1:
# # Include the last message in the docs
# docs = [self.topics[-1][1]]
# else:
# # No messages have been assigned topics yet
# docs = []
# # Add the unclustered messages to the docs
# docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
# # Use the splitter to split the documents
# new_topics = self.splitter(docs)
# # Check if the first new topic includes the first new message.
# # This means that the first new message shares the same topic as the last old message to have been assigned a topic.
# if docs[-len(unclustered_messages)] in new_topics[0].docs:
# start = self.topics[-1][0]
# else:
# start = self.topics[-1][0] + 1
# # Add the new topics to the list of topics with unique IDs
# for i, topic in enumerate(new_topics, start=start):
# for message in topic.docs:
# self.topics.append((i, message))
# return new_topics
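For orientation, here is a minimal sketch of the DocumentSplit model from the hunk above in use. The message strings and score are illustrative values, not output from a real splitter run:

from semantic_router.schema import DocumentSplit

# Illustrative values only: a split that was triggered because the similarity
# between neighbouring document groups fell below the splitter's threshold.
split = DocumentSplit(
    docs=["user: How do I bake sourdough?", "assistant: Start with an active starter."],
    is_triggered=True,
    triggered_score=0.32,
)
print(split.docs)  # the group of semantically similar documents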
from typing import List
from semantic_router.splitters.base import BaseSplitter
import numpy as np
from semantic_router.schema import DocumentSplit
from semantic_router.encoders import BaseEncoder
class CAVSimSplitter(BaseSplitter):
"""
The CAVSimSplitter class is a document splitter that uses the concept of Cumulative Average Vectors (CAV) to determine where to split a sequence of documents based on their semantic similarity.
For example, consider a sequence of documents [A, B, C, D, E, F]. The CAVSimSplitter works as follows:
1. It starts with the first document (A) and calculates the cosine similarity between the embedding of A and the average embedding of the next two documents (B, C) if they exist, or the next one document (B) if only one exists.
- Cosine Similarity: cos_sim(A, avg(B, C))
2. It then moves to the next document (B), calculates the average embedding of the current documents (A, B), and calculates the cosine similarity with the average embedding of the next two documents (C, D) if they exist, or the next one document (C) if only one exists.
- Cosine Similarity: cos_sim(avg(A, B), avg(C, D))
3. This process continues, with the average embedding being calculated for the current cumulative documents and the next one or two documents. For example, at document C:
- Cosine Similarity: cos_sim(avg(A, B, C), avg(D, E))
4. If the similarity score between the average embedding of the current cumulative documents and the average embedding of the next one or two documents falls below the specified similarity threshold, a split is triggered. In our example, let's say the similarity score falls below the threshold between the average of documents A, B, C and the average of D, E. The splitter will then create a split, resulting in two groups of documents: [A, B, C] and [D, E].
5. After a split occurs, the process restarts with the next document in the sequence. For example, after the split between C and D, the process restarts with D and calculates the cosine similarity between the embedding of D and the average embedding of the next two documents if they exist.
- Cosine Similarity: cos_sim(D, avg(E, F))
6. Accumulation and averaging then continue from the new split start. Once only one document, F, remains on the right, it is compared directly:
- Cosine Similarity: cos_sim(avg(D, E), F)
7. The process continues until all documents have been processed.
The result is a list of DocumentSplit objects, each representing a group of semantically similar documents.
"""
def __init__(
self,
encoder: BaseEncoder,
name: str = "cav_similarity_splitter",
similarity_threshold: float = 0.45,
):
super().__init__(
name=name,
similarity_threshold=similarity_threshold,
encoder=encoder
)
def __call__(self, docs: List[str]):
total_docs = len(docs)
splits = []
curr_split_start_idx = 0
curr_split_num = 1
doc_embeds = self.encoder(docs)
for idx in range(1, total_docs):
curr_split_docs_embeds = doc_embeds[curr_split_start_idx : idx + 1]
avg_embedding = np.mean(curr_split_docs_embeds, axis=0)
# Compute the average embedding for the next two documents, if available
if idx + 3 <= total_docs: # Check if the next two indices are within the range
next_doc_embeds = doc_embeds[idx + 1 : idx + 3]
next_avg_embed = np.mean(next_doc_embeds, axis=0)
elif idx + 2 <= total_docs: # Check if the next index is within the range
next_avg_embed = doc_embeds[idx + 1]
else:
next_avg_embed = None
if next_avg_embed is not None:
curr_sim_score = np.dot(avg_embedding, next_avg_embed) / (
np.linalg.norm(avg_embedding)
* np.linalg.norm(next_avg_embed)
)
if curr_sim_score < self.similarity_threshold:
splits.append(
DocumentSplit(
docs=list(docs[curr_split_start_idx : idx + 1]),
is_triggered=True,
triggered_score=curr_sim_score,
)
)
curr_split_start_idx = idx + 1
curr_split_num += 1
splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
return splits
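A short usage sketch for the splitter above, assuming the module path referenced in the commented imports earlier (semantic_router.splitters.cav_sim) and an OpenAIEncoder with its API key available in the environment; the documents and threshold are illustrative:

from semantic_router.encoders import OpenAIEncoder
from semantic_router.splitters.cav_sim import CAVSimSplitter

encoder = OpenAIEncoder()  # any BaseEncoder subclass works here
splitter = CAVSimSplitter(encoder=encoder, similarity_threshold=0.45)

docs = [
    "user: How do I bake sourdough bread?",
    "assistant: Start with an active starter and a slow fermentation.",
    "user: Different topic: what's the capital of France?",
]
for split in splitter(docs):
    print(split.docs, split.is_triggered, split.triggered_score)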
from pydantic.v1 import BaseModel, Field
from typing import Union, List, Literal, Tuple
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.splitters.running_avg_sim import RunningAvgSimSplitter
from semantic_router.encoders import BaseEncoder
from semantic_router.schema import Message
# Define a type alias for the splitter to simplify the annotation
SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, RunningAvgSimSplitter, None]
class Conversation(BaseModel):
messages: List[Message] = Field(default_factory=list) # Ensure this is initialized as an empty list
topics: List[Tuple[int, str]] = []
splitter: SplitterType = None
def add_new_messages(self, new_messages: List[Message]):
self.messages.extend(new_messages)
def remove_topics(self):
self.topics = []
def configure_splitter(
self,
encoder: BaseEncoder,
threshold: float = 0.5,
split_method: Literal[
"consecutive_similarity", "cumulative_similarity", "running_avg_similarity"
] = "consecutive_similarity",
):
if split_method == "consecutive_similarity":
self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
elif split_method == "cumulative_similarity":
self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
elif split_method == "running_avg_similarity":
self.splitter = RunningAvgSimSplitter(encoder=encoder, similarity_threshold=threshold)
else:
raise ValueError(f"Invalid split method: {split_method}")
def split_by_topic(self):
if self.splitter is None:
raise ValueError("Splitter is not configured. Please call configure_splitter first.")
        new_topics = []

        # Get the messages that haven't been clustered into topics yet.
        unclustered_messages = self.messages[len(self.topics):]

        # If there are no unclustered messages, return early.
        if not unclustered_messages:
            return self.topics, new_topics
# Extract the last topic ID and message from the previous splitting, if they exist.
if self.topics:
last_topic_id_from_last_splitting, last_message_from_last_splitting = self.topics[-1]
else:
last_topic_id_from_last_splitting, last_message_from_last_splitting = None, None
        # Seed docs with the last message from the previous splitting, if one exists, so the
        # splitter can tell whether the new messages continue that topic.
        docs = [last_message_from_last_splitting] if last_message_from_last_splitting is not None else []
# Add the unclustered messages to the docs
docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
        # Use the splitter to split the documents.
        new_topics = self.splitter(docs)
# Ensure there are new topics before proceeding
if not new_topics:
return self.topics, []
        # Check if there are any previously assigned topics.
        if self.topics:  # new_topics is guaranteed non-empty at this point
            # Check if the first new topic includes the last message that was assigned a topic
            # in the previous splitting. This indicates that the new messages continue the same
            # topic as the last message from the previous split. Compare the topic ID against
            # None explicitly: a valid ID of 0 is falsy and would wrongly fail a truthiness check.
            if (
                last_topic_id_from_last_splitting is not None
                and last_message_from_last_splitting in new_topics[0].docs
            ):
                start = last_topic_id_from_last_splitting
            else:
                start = self.topics[-1][0] + 1
else:
start = 0 # Start from 0 if no previous topics
# If the last message from the previous splitting is found in the first new topic, remove it
if self.topics and new_topics[0].docs[0] == self.topics[-1][1]:
new_topics[0].docs.pop(0)
# Add the new topics to the list of topics with unique IDs
for i, topic in enumerate(new_topics, start=start):
for message in topic.docs:
self.topics.append((i, message))
return self.topics, new_topics
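Finally, a sketch of the intended calling sequence for the Conversation class above; the encoder choice, threshold, and messages are illustrative assumptions:

from semantic_router.encoders import OpenAIEncoder
from semantic_router.schema import Message

conversation = Conversation()
conversation.add_new_messages([
    Message(role="user", content="How do I bake sourdough bread?"),
    Message(role="assistant", content="Start with an active starter."),
])
conversation.configure_splitter(encoder=OpenAIEncoder(), threshold=0.5)
all_topics, new_topics = conversation.split_by_topic()

# Later calls only process messages that haven't been assigned a topic yet,
# and the duplicate-removal fix above keeps the seed message from reappearing.
conversation.add_new_messages([Message(role="user", content="Now, about travel plans...")])
all_topics, new_topics = conversation.split_by_topic()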