diff --git a/semantic_router/splitters/base.py b/semantic_router/splitters/base.py
index ccd4e6f6b2149f5cb2183ea68e9ec1726b5ac686..c7ca37c2ed2660448b4549181e354aa0e5b8727a 100644
--- a/semantic_router/splitters/base.py
+++ b/semantic_router/splitters/base.py
@@ -7,7 +7,7 @@ from semantic_router.encoders import BaseEncoder
 class BaseSplitter(BaseModel):
     name: str
     encoder: BaseEncoder
-    similarity_threshold: float
+    score_threshold: float
 
     def __call__(self, docs: List[str]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py
index 8cafb09d1f7e9f195f2e51a373027fca6181f6e9..5cda1c0e558e558679c026b3f76f9addc4f8016f 100644
--- a/semantic_router/splitters/consecutive_sim.py
+++ b/semantic_router/splitters/consecutive_sim.py
@@ -15,10 +15,10 @@ class ConsecutiveSimSplitter(BaseSplitter):
         self,
         encoder: BaseEncoder,
         name: str = "consecutive_similarity_splitter",
-        similarity_threshold: float = 0.45,
+        score_threshold: float = 0.45,
     ):
         super().__init__(
-            name=name, similarity_threshold=similarity_threshold, encoder=encoder
+            name=name, score_threshold=score_threshold, encoder=encoder
         )
 
     def __call__(self, docs: List[str]):
@@ -38,7 +38,7 @@ class ConsecutiveSimSplitter(BaseSplitter):
 
         for idx in range(1, total_docs):
             curr_sim_score = sim_matrix[idx - 1][idx]
-            if idx < len(sim_matrix) and curr_sim_score < self.similarity_threshold:
+            if idx < len(sim_matrix) and curr_sim_score < self.score_threshold:
                 splits.append(
                     DocumentSplit(
                         docs=list(docs[curr_split_start_idx:idx]),
diff --git a/semantic_router/splitters/cumulative_sim.py b/semantic_router/splitters/cumulative_sim.py
index 97ab5b21e474fc8e270e823adf1ee333b5c4ce8a..1dbd80b04370e10871b2f0d5b2c8d48954894a00 100644
--- a/semantic_router/splitters/cumulative_sim.py
+++ b/semantic_router/splitters/cumulative_sim.py
@@ -15,10 +15,10 @@ class CumulativeSimSplitter(BaseSplitter):
         self,
         encoder: BaseEncoder,
         name: str = "cumulative_similarity_splitter",
-        similarity_threshold: float = 0.45,
+        score_threshold: float = 0.45,
     ):
         super().__init__(
-            name=name, similarity_threshold=similarity_threshold, encoder=encoder
+            name=name, score_threshold=score_threshold, encoder=encoder
         )
 
     def __call__(self, docs: List[str]):
@@ -49,7 +49,7 @@ class CumulativeSimSplitter(BaseSplitter):
                     * np.linalg.norm(next_doc_embed)
                 )
                 # Decision to split based on similarity score.
-                if curr_sim_score < self.similarity_threshold:
+                if curr_sim_score < self.score_threshold:
                     splits.append(
                         DocumentSplit(
                             docs=list(docs[curr_split_start_idx : idx + 1]),
diff --git a/semantic_router/text.py b/semantic_router/text.py
index 2717f99af79b38f9259fc695ac79cd8888217ab1..003f8c368447e9b857faa1c3a0fb9fb98565fd60 100644
--- a/semantic_router/text.py
+++ b/semantic_router/text.py
@@ -61,11 +61,11 @@ class Conversation(BaseModel):
 
         if split_method == "consecutive_similarity":
             self.splitter = ConsecutiveSimSplitter(
-                encoder=encoder, similarity_threshold=threshold
+                encoder=encoder, score_threshold=threshold
             )
         elif split_method == "cumulative_similarity":
             self.splitter = CumulativeSimSplitter(
-                encoder=encoder, similarity_threshold=threshold
+                encoder=encoder, score_threshold=threshold
             )
         else:
             raise ValueError(f"Invalid split method: {split_method}")
diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index bef137d6447427c67dfc072cda2b2e2982b4446a..21146aaf09b5343be1674607c9fadd7a378963bb 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -22,7 +22,7 @@ def test_consecutive_sim_splitter():
         input_type="",
     )
     # Instantiate the ConsecutiveSimSplitter with the mock encoder
-    splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9)
+    splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, score_threshold=0.9)
     splitter.encoder = mock_encoder
 
     # Define some documents
@@ -58,7 +58,7 @@ def test_cumulative_sim_splitter():
         cohere_api_key="",
         input_type="",
     )
-    splitter = CumulativeSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9)
+    splitter = CumulativeSimSplitter(encoder=cohere_encoder, score_threshold=0.9)
     splitter.encoder = mock_encoder
 
     # Define some documents
@@ -163,7 +163,7 @@ def test_consecutive_similarity_splitter_single_doc():
     # Assuming any return value since it should not reach the point of using the encoder
     mock_encoder.return_value = np.array([[0.5, 0]])
 
-    splitter = ConsecutiveSimSplitter(encoder=mock_encoder, similarity_threshold=0.5)
+    splitter = ConsecutiveSimSplitter(encoder=mock_encoder, score_threshold=0.5)
 
     docs = ["doc1"]
     with pytest.raises(ValueError) as excinfo:
@@ -176,7 +176,7 @@ def test_cumulative_similarity_splitter_single_doc():
     # Assuming any return value since it should not reach the point of using the encoder
     mock_encoder.return_value = np.array([[0.5, 0]])
 
-    splitter = CumulativeSimSplitter(encoder=mock_encoder, similarity_threshold=0.5)
+    splitter = CumulativeSimSplitter(encoder=mock_encoder, score_threshold=0.5)
 
    docs = ["doc1"]
    with pytest.raises(ValueError) as excinfo:
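
Below is a minimal usage sketch of the splitters after this rename. It is illustrative only: the encoder configuration (model name, API key, input type) and the import path for CohereEncoder are assumptions, not prescribed by this patch; only the `score_threshold` keyword, the splitter module paths, and the `DocumentSplit.docs` attribute come from the diff above.

    # Sketch: constructing the splitters with the renamed `score_threshold`
    # keyword (previously `similarity_threshold`). Encoder settings below are
    # placeholder assumptions, not part of this patch.
    from semantic_router.encoders import CohereEncoder
    from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
    from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter

    encoder = CohereEncoder(
        name="embed-english-v3.0",        # placeholder model name
        cohere_api_key="<YOUR_API_KEY>",  # placeholder key
        input_type="search_document",     # placeholder input type
    )

    # Both splitters default to score_threshold=0.45 and expect at least
    # two documents; a single document raises ValueError (see the tests above).
    consecutive = ConsecutiveSimSplitter(encoder=encoder, score_threshold=0.45)
    cumulative = CumulativeSimSplitter(encoder=encoder, score_threshold=0.45)

    docs = [
        "The weather is sunny today.",
        "It should stay warm for the rest of the week.",
        "Quarterly earnings were reported this morning.",
    ]

    # Each splitter returns a list of DocumentSplit objects; `split.docs`
    # holds the documents grouped into that split.
    for split in consecutive(docs):
        print(split.docs)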