From 99b2d075b7d9bcdd7f7c5fdf48a573c2753559e1 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:08:35 +0100 Subject: [PATCH] threshold applies to encoder --- semantic_router/splitters/consecutive_sim.py | 1 + semantic_router/splitters/cumulative_sim.py | 1 + semantic_router/text.py | 5 +++++ 3 files changed, 7 insertions(+) diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py index 2994d7ff..6bd08845 100644 --- a/semantic_router/splitters/consecutive_sim.py +++ b/semantic_router/splitters/consecutive_sim.py @@ -18,6 +18,7 @@ class ConsecutiveSimSplitter(BaseSplitter): score_threshold: float = 0.45, ): super().__init__(name=name, score_threshold=score_threshold, encoder=encoder) + encoder.score_threshold = score_threshold def __call__(self, docs: List[str]): # Check if there's only a single document diff --git a/semantic_router/splitters/cumulative_sim.py b/semantic_router/splitters/cumulative_sim.py index 960eeb29..ba8f4bd3 100644 --- a/semantic_router/splitters/cumulative_sim.py +++ b/semantic_router/splitters/cumulative_sim.py @@ -18,6 +18,7 @@ class CumulativeSimSplitter(BaseSplitter): score_threshold: float = 0.45, ): super().__init__(name=name, score_threshold=score_threshold, encoder=encoder) + encoder.score_threshold = score_threshold def __call__(self, docs: List[str]): total_docs = len(docs) diff --git a/semantic_router/text.py b/semantic_router/text.py index dfa0ecf7..3d56e374 100644 --- a/semantic_router/text.py +++ b/semantic_router/text.py @@ -50,6 +50,11 @@ class Conversation(BaseModel): def add_new_messages(self, new_messages: List[Message]): + """Adds new messages to the conversation. + + :param messages: The new messages to be added to the conversation. + :type messages: List[Message] + """ self.messages.extend(new_messages) def remove_topics(self): -- GitLab