From 99b2d075b7d9bcdd7f7c5fdf48a573c2753559e1 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Mon, 5 Feb 2024 13:08:35 +0100
Subject: [PATCH] threshold applies to encoder

---
 semantic_router/splitters/consecutive_sim.py | 1 +
 semantic_router/splitters/cumulative_sim.py  | 1 +
 semantic_router/text.py                      | 5 +++++
 3 files changed, 7 insertions(+)

diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py
index 2994d7ff..6bd08845 100644
--- a/semantic_router/splitters/consecutive_sim.py
+++ b/semantic_router/splitters/consecutive_sim.py
@@ -18,6 +18,7 @@ class ConsecutiveSimSplitter(BaseSplitter):
         score_threshold: float = 0.45,
     ):
         super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
+        encoder.score_threshold = score_threshold
 
     def __call__(self, docs: List[str]):
         # Check if there's only a single document
diff --git a/semantic_router/splitters/cumulative_sim.py b/semantic_router/splitters/cumulative_sim.py
index 960eeb29..ba8f4bd3 100644
--- a/semantic_router/splitters/cumulative_sim.py
+++ b/semantic_router/splitters/cumulative_sim.py
@@ -18,6 +18,7 @@ class CumulativeSimSplitter(BaseSplitter):
         score_threshold: float = 0.45,
     ):
         super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
+        encoder.score_threshold = score_threshold
 
     def __call__(self, docs: List[str]):
         total_docs = len(docs)
diff --git a/semantic_router/text.py b/semantic_router/text.py
index dfa0ecf7..3d56e374 100644
--- a/semantic_router/text.py
+++ b/semantic_router/text.py
@@ -50,6 +50,11 @@ class Conversation(BaseModel):
 
 
     def add_new_messages(self, new_messages: List[Message]):
+        """Adds new messages to the conversation.
+        
+        :param messages: The new messages to be added to the conversation.
+        :type messages: List[Message]
+        """
         self.messages.extend(new_messages)
 
     def remove_topics(self):
-- 
GitLab