Skip to content
Snippets Groups Projects
Unverified Commit 091b256e authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

similarity_threshold -> score_threshold

parent 1226b96c
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ from semantic_router.encoders import BaseEncoder ...@@ -7,7 +7,7 @@ from semantic_router.encoders import BaseEncoder
class BaseSplitter(BaseModel): class BaseSplitter(BaseModel):
name: str name: str
encoder: BaseEncoder encoder: BaseEncoder
similarity_threshold: float score_threshold: float
def __call__(self, docs: List[str]) -> List[List[float]]: def __call__(self, docs: List[str]) -> List[List[float]]:
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
...@@ -15,10 +15,10 @@ class ConsecutiveSimSplitter(BaseSplitter): ...@@ -15,10 +15,10 @@ class ConsecutiveSimSplitter(BaseSplitter):
self, self,
encoder: BaseEncoder, encoder: BaseEncoder,
name: str = "consecutive_similarity_splitter", name: str = "consecutive_similarity_splitter",
similarity_threshold: float = 0.45, score_threshold: float = 0.45,
): ):
super().__init__( super().__init__(
name=name, similarity_threshold=similarity_threshold, encoder=encoder name=name, score_threshold=score_threshold, encoder=encoder
) )
def __call__(self, docs: List[str]): def __call__(self, docs: List[str]):
...@@ -38,7 +38,7 @@ class ConsecutiveSimSplitter(BaseSplitter): ...@@ -38,7 +38,7 @@ class ConsecutiveSimSplitter(BaseSplitter):
for idx in range(1, total_docs): for idx in range(1, total_docs):
curr_sim_score = sim_matrix[idx - 1][idx] curr_sim_score = sim_matrix[idx - 1][idx]
if idx < len(sim_matrix) and curr_sim_score < self.similarity_threshold: if idx < len(sim_matrix) and curr_sim_score < self.score_threshold:
splits.append( splits.append(
DocumentSplit( DocumentSplit(
docs=list(docs[curr_split_start_idx:idx]), docs=list(docs[curr_split_start_idx:idx]),
......
...@@ -15,10 +15,10 @@ class CumulativeSimSplitter(BaseSplitter): ...@@ -15,10 +15,10 @@ class CumulativeSimSplitter(BaseSplitter):
self, self,
encoder: BaseEncoder, encoder: BaseEncoder,
name: str = "cumulative_similarity_splitter", name: str = "cumulative_similarity_splitter",
similarity_threshold: float = 0.45, score_threshold: float = 0.45,
): ):
super().__init__( super().__init__(
name=name, similarity_threshold=similarity_threshold, encoder=encoder name=name, score_threshold=score_threshold, encoder=encoder
) )
def __call__(self, docs: List[str]): def __call__(self, docs: List[str]):
...@@ -49,7 +49,7 @@ class CumulativeSimSplitter(BaseSplitter): ...@@ -49,7 +49,7 @@ class CumulativeSimSplitter(BaseSplitter):
* np.linalg.norm(next_doc_embed) * np.linalg.norm(next_doc_embed)
) )
# Decision to split based on similarity score. # Decision to split based on similarity score.
if curr_sim_score < self.similarity_threshold: if curr_sim_score < self.score_threshold:
splits.append( splits.append(
DocumentSplit( DocumentSplit(
docs=list(docs[curr_split_start_idx : idx + 1]), docs=list(docs[curr_split_start_idx : idx + 1]),
......
...@@ -61,11 +61,11 @@ class Conversation(BaseModel): ...@@ -61,11 +61,11 @@ class Conversation(BaseModel):
if split_method == "consecutive_similarity": if split_method == "consecutive_similarity":
self.splitter = ConsecutiveSimSplitter( self.splitter = ConsecutiveSimSplitter(
encoder=encoder, similarity_threshold=threshold encoder=encoder, score_threshold=threshold
) )
elif split_method == "cumulative_similarity": elif split_method == "cumulative_similarity":
self.splitter = CumulativeSimSplitter( self.splitter = CumulativeSimSplitter(
encoder=encoder, similarity_threshold=threshold encoder=encoder, score_threshold=threshold
) )
else: else:
raise ValueError(f"Invalid split method: {split_method}") raise ValueError(f"Invalid split method: {split_method}")
......
...@@ -22,7 +22,7 @@ def test_consecutive_sim_splitter(): ...@@ -22,7 +22,7 @@ def test_consecutive_sim_splitter():
input_type="", input_type="",
) )
# Instantiate the ConsecutiveSimSplitter with the mock encoder # Instantiate the ConsecutiveSimSplitter with the mock encoder
splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9) splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, score_threshold=0.9)
splitter.encoder = mock_encoder splitter.encoder = mock_encoder
# Define some documents # Define some documents
...@@ -58,7 +58,7 @@ def test_cumulative_sim_splitter(): ...@@ -58,7 +58,7 @@ def test_cumulative_sim_splitter():
cohere_api_key="", cohere_api_key="",
input_type="", input_type="",
) )
splitter = CumulativeSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9) splitter = CumulativeSimSplitter(encoder=cohere_encoder, score_threshold=0.9)
splitter.encoder = mock_encoder splitter.encoder = mock_encoder
# Define some documents # Define some documents
...@@ -163,7 +163,7 @@ def test_consecutive_similarity_splitter_single_doc(): ...@@ -163,7 +163,7 @@ def test_consecutive_similarity_splitter_single_doc():
# Assuming any return value since it should not reach the point of using the encoder # Assuming any return value since it should not reach the point of using the encoder
mock_encoder.return_value = np.array([[0.5, 0]]) mock_encoder.return_value = np.array([[0.5, 0]])
splitter = ConsecutiveSimSplitter(encoder=mock_encoder, similarity_threshold=0.5) splitter = ConsecutiveSimSplitter(encoder=mock_encoder, score_threshold=0.5)
docs = ["doc1"] docs = ["doc1"]
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
...@@ -176,7 +176,7 @@ def test_cumulative_similarity_splitter_single_doc(): ...@@ -176,7 +176,7 @@ def test_cumulative_similarity_splitter_single_doc():
# Assuming any return value since it should not reach the point of using the encoder # Assuming any return value since it should not reach the point of using the encoder
mock_encoder.return_value = np.array([[0.5, 0]]) mock_encoder.return_value = np.array([[0.5, 0]])
splitter = CumulativeSimSplitter(encoder=mock_encoder, similarity_threshold=0.5) splitter = CumulativeSimSplitter(encoder=mock_encoder, score_threshold=0.5)
docs = ["doc1"] docs = ["doc1"]
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment