Skip to content
Snippets Groups Projects
Unverified Commit d9853d82 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Tidying Up Conversation Class

The splitting code needed tidying and docstrings.

Delegated tasks in split_by_topic() to newly created methods to reduce number of lines of code in split_by_topic().
parent d29263d2
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,7 @@ from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter ...@@ -5,6 +5,7 @@ from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.splitters.running_avg_sim import RunningAvgSimSplitter from semantic_router.splitters.running_avg_sim import RunningAvgSimSplitter
from semantic_router.encoders import BaseEncoder from semantic_router.encoders import BaseEncoder
from semantic_router.schema import Message from semantic_router.schema import Message
from semantic_router.schema import DocumentSplit
# Define a type alias for the splitter to simplify the annotation # Define a type alias for the splitter to simplify the annotation
SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, RunningAvgSimSplitter, None] SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, RunningAvgSimSplitter, None]
...@@ -20,6 +21,20 @@ class Conversation(BaseModel): ...@@ -20,6 +21,20 @@ class Conversation(BaseModel):
def remove_topics(self): def remove_topics(self):
self.topics = [] self.topics = []
def print_topics(self):
    """
    Prints the stored topics, grouping consecutive messages that share a topic ID.

    Each group is introduced by a "Topic N:" heading, where N is the stored topic ID plus one, and groups are separated by a blank line. If there are no topics, a short notice is printed instead.
    """
    if self.topics:
        print("Topics:")
        previous_id = None
        for entry in self.topics:
            topic_id, text = entry
            starts_new_topic = topic_id != previous_id
            # Blank separator line between groups, but not before the first one.
            if starts_new_topic and previous_id is not None:
                print("\n", end="")
            if starts_new_topic:
                print(f"Topic {topic_id + 1}:")
                previous_id = topic_id
            print(f" - {text}")
    else:
        print("No topics to display.")
def configure_splitter( def configure_splitter(
self, self,
encoder: BaseEncoder, encoder: BaseEncoder,
...@@ -28,6 +43,21 @@ class Conversation(BaseModel): ...@@ -28,6 +43,21 @@ class Conversation(BaseModel):
"consecutive_similarity", "cumulative_similarity", "running_avg_similarity" "consecutive_similarity", "cumulative_similarity", "running_avg_similarity"
] = "consecutive_similarity", ] = "consecutive_similarity",
): ):
"""
Configures the splitter for the conversation based on the specified method.
This method sets the splitter attribute of the Conversation class to an instance of the appropriate splitter class, based on the `split_method` parameter. It uses the provided encoder and similarity threshold to initialize the splitter.
:param encoder: The encoder to be used by the splitter for encoding messages.
:type encoder: BaseEncoder
:param threshold: The similarity threshold to be used by the splitter. Defaults to 0.5.
:type threshold: float
:param split_method: The method to be used for splitting the conversation into topics. Can be one of "consecutive_similarity", "cumulative_similarity", or "running_avg_similarity". Defaults to "consecutive_similarity".
:type split_method: Literal["consecutive_similarity", "cumulative_similarity", "running_avg_similarity"]
:raises ValueError: If an invalid split method is provided.
"""
if split_method == "consecutive_similarity": if split_method == "consecutive_similarity":
self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold) self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
elif split_method == "cumulative_similarity": elif split_method == "cumulative_similarity":
...@@ -38,57 +68,104 @@ class Conversation(BaseModel): ...@@ -38,57 +68,104 @@ class Conversation(BaseModel):
raise ValueError(f"Invalid split method: {split_method}") raise ValueError(f"Invalid split method: {split_method}")
def split_by_topic(self): def get_last_message_and_topic_id(self):
"""
Retrieves the last message and its corresponding topic ID from the list of topics.
This method scans the list of topics, if any, and returns the topic ID and message of the last entry. If there are no topics, it returns None for both the topic ID and message.
The last message from a previous splitting is useful because it can be passed to the splitter along with new messages, and if the first new message is assigned the same topic as the last message, then we know that the new message should continue with the same topic ID as the last message.
:return: A tuple containing the topic ID (int) and message (str) of the last topic, or (None, None) if there are no topics.
:rtype: tuple[int | None, str | None]
"""
if self.topics:
return self.topics[-1]
else:
return None, None
def determine_topic_start_index(self, new_topics, last_topic_id, last_message):
    """
    Determines the starting index for new topics based on existing topics and the last message.

    Topic IDs are zero-based (print_topics displays them as ``topic_id + 1``), so numbering starts at 0 when no topics have been stored yet.

    :param new_topics: The list of new topics generated by the splitter.
    :type new_topics: List[DocumentSplit]
    :param last_topic_id: The topic ID of the last message from the previous splitting.
    :type last_topic_id: int, optional
    :param last_message: The last message from the previous splitting.
    :type last_message: str, optional
    :return: The starting index for new topics.
    :rtype: int
    """
    # No previously stored topics (or nothing new to number): start from 0.
    if not self.topics or not new_topics:
        return 0
    # Compare against None explicitly so that a last topic ID of 0 still counts.
    # If the splitter grouped the previous last message with the first new
    # topic, the new messages continue that topic and reuse its ID.
    if last_topic_id is not None and last_message and last_message in new_topics[0].docs:
        return last_topic_id
    # Otherwise the new topics follow on from the highest ID stored so far.
    return self.topics[-1][0] + 1
def append_new_topics(self, new_topics, start) -> None:
    """
    Appends the messages of each new topic to the stored topics as (topic ID, message) pairs.

    The first new topic receives the ID ``start`` and each following topic receives the next consecutive ID. A topic with no messages still consumes an ID.

    :param new_topics: The list of new topics generated by the splitter.
    :type new_topics: List[DocumentSplit]
    :param start: The starting index for new topics.
    :type start: int
    """
    topic_id = start
    for split in new_topics:
        self.topics.extend((topic_id, doc) for doc in split.docs)
        topic_id += 1
def split_by_topic(self) -> Tuple[List[Tuple[int, str]], List[DocumentSplit]]:
"""
Splits the messages into topics based on their semantic similarity.
This method processes unclustered messages, splits them into topics using the configured splitter, and appends the new topics to the existing list of topics with unique IDs. It ensures that messages belonging to the same topic are grouped together, even if they were not processed in the same batch.
:raises ValueError: If the splitter is not configured before calling this method.
:return: A tuple containing the updated list of topics and the list of new topics generated in this call.
:rtype: tuple[list[tuple[int, str]], list[DocumentSplit]]
"""
if self.splitter is None: if self.splitter is None:
raise ValueError("Splitter is not configured. Please call configure_splitter first.") raise ValueError("Splitter is not configured. Please call configure_splitter first.")
new_topics = [] new_topics = []
# Get the messages that haven't been clustered into topics yet # Get unclustered messages.
unclustered_messages = self.messages[len(self.topics):] unclustered_messages = self.messages[len(self.topics):]
# If there are no unclustered messages, return early
if not unclustered_messages: if not unclustered_messages:
print("No unclustered messages to process.") print("No unclustered messages to process.")
return self.topics, new_topics return self.topics, new_topics
# Extract the last topic ID and message from the previous splitting, if they exist. # Extract the last topic ID and message from the previous splitting, if they exist.
if self.topics: last_topic_id, last_message = self.get_last_message_and_topic_id()
last_topic_id_from_last_splitting, last_message_from_last_splitting = self.topics[-1]
else:
last_topic_id_from_last_splitting, last_message_from_last_splitting = None, None
# Initialize docs with the last message from the last topic if it exists # Initialize docs with the last message from the last topic if it exists, and with unclustered messages.
docs = [last_message_from_last_splitting] if last_message_from_last_splitting else [] # TODO: Currently only getting last message from last topic in previous splitting. Should we get more for more reliable continuation of topic ids?
docs = [last_message] if last_message else []
# Add the unclustered messages to the docs
docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages]) docs.extend([f"{m.role}: {m.content}" for m in unclustered_messages])
# Use the splitter to split the documents
new_topics = self.splitter(docs) new_topics = self.splitter(docs)
# Ensure there are new topics before proceeding # Ensure there are new topics before proceeding
if not new_topics: if not new_topics:
return self.topics, [] return self.topics, []
# Check if there are any previously assigned topics # If last_message and the first new message are assigned the same topic ID, then we know the new message should take last_message's original topic ID.
if self.topics and new_topics: start = self.determine_topic_start_index(new_topics, last_topic_id, last_message)
# Check if the first new topic includes the last message that was assigned a topic in the previous splitting.
# This indicates that the new messages may continue the same topic as the last message from the previous split.
if last_topic_id_from_last_splitting and last_message_from_last_splitting and last_message_from_last_splitting in new_topics[0].docs:
start = last_topic_id_from_last_splitting
else:
start = self.topics[-1][0] + 1
else:
start = 0 # Start from 0 if no previous topics
# If the last message from the previous splitting is found in the first new topic, remove it # If the last message from the previous splitting is found in the first new topic, remove it
if self.topics and new_topics[0].docs[0] == self.topics[-1][1]: if self.topics and new_topics[0].docs[0] == self.topics[-1][1]:
new_topics[0].docs.pop(0) new_topics[0].docs.pop(0)
# Add the new topics to the list of topics with unique IDs self.append_new_topics(new_topics, start)
for i, topic in enumerate(new_topics, start=start):
for message in topic.docs:
self.topics.append((i, message))
# TODO: Instead of self.topics as list of tuples should it also be a list of DocumentSplit objects?
return self.topics, new_topics return self.topics, new_topics
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment