Skip to content
Snippets Groups Projects
Unverified commit 091b256e, authored by Siraj R Aizlewood
Browse files

similarity_threshold -> score_threshold

parent 1226b96c
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ from semantic_router.encoders import BaseEncoder
class BaseSplitter(BaseModel):
name: str
encoder: BaseEncoder
similarity_threshold: float
score_threshold: float
def __call__(self, docs: List[str]) -> List[List[float]]:
raise NotImplementedError("Subclasses must implement this method")
......@@ -15,10 +15,10 @@ class ConsecutiveSimSplitter(BaseSplitter):
self,
encoder: BaseEncoder,
name: str = "consecutive_similarity_splitter",
similarity_threshold: float = 0.45,
score_threshold: float = 0.45,
):
super().__init__(
name=name, similarity_threshold=similarity_threshold, encoder=encoder
name=name, score_threshold=score_threshold, encoder=encoder
)
def __call__(self, docs: List[str]):
......@@ -38,7 +38,7 @@ class ConsecutiveSimSplitter(BaseSplitter):
for idx in range(1, total_docs):
curr_sim_score = sim_matrix[idx - 1][idx]
if idx < len(sim_matrix) and curr_sim_score < self.similarity_threshold:
if idx < len(sim_matrix) and curr_sim_score < self.score_threshold:
splits.append(
DocumentSplit(
docs=list(docs[curr_split_start_idx:idx]),
......
......@@ -15,10 +15,10 @@ class CumulativeSimSplitter(BaseSplitter):
self,
encoder: BaseEncoder,
name: str = "cumulative_similarity_splitter",
similarity_threshold: float = 0.45,
score_threshold: float = 0.45,
):
super().__init__(
name=name, similarity_threshold=similarity_threshold, encoder=encoder
name=name, score_threshold=score_threshold, encoder=encoder
)
def __call__(self, docs: List[str]):
......@@ -49,7 +49,7 @@ class CumulativeSimSplitter(BaseSplitter):
* np.linalg.norm(next_doc_embed)
)
# Decision to split based on similarity score.
if curr_sim_score < self.similarity_threshold:
if curr_sim_score < self.score_threshold:
splits.append(
DocumentSplit(
docs=list(docs[curr_split_start_idx : idx + 1]),
......
......@@ -61,11 +61,11 @@ class Conversation(BaseModel):
if split_method == "consecutive_similarity":
self.splitter = ConsecutiveSimSplitter(
encoder=encoder, similarity_threshold=threshold
encoder=encoder, score_threshold=threshold
)
elif split_method == "cumulative_similarity":
self.splitter = CumulativeSimSplitter(
encoder=encoder, similarity_threshold=threshold
encoder=encoder, score_threshold=threshold
)
else:
raise ValueError(f"Invalid split method: {split_method}")
......
......@@ -22,7 +22,7 @@ def test_consecutive_sim_splitter():
input_type="",
)
# Instantiate the ConsecutiveSimSplitter with the mock encoder
splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9)
splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, score_threshold=0.9)
splitter.encoder = mock_encoder
# Define some documents
......@@ -58,7 +58,7 @@ def test_cumulative_sim_splitter():
cohere_api_key="",
input_type="",
)
splitter = CumulativeSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9)
splitter = CumulativeSimSplitter(encoder=cohere_encoder, score_threshold=0.9)
splitter.encoder = mock_encoder
# Define some documents
......@@ -163,7 +163,7 @@ def test_consecutive_similarity_splitter_single_doc():
# Assuming any return value since it should not reach the point of using the encoder
mock_encoder.return_value = np.array([[0.5, 0]])
splitter = ConsecutiveSimSplitter(encoder=mock_encoder, similarity_threshold=0.5)
splitter = ConsecutiveSimSplitter(encoder=mock_encoder, score_threshold=0.5)
docs = ["doc1"]
with pytest.raises(ValueError) as excinfo:
......@@ -176,7 +176,7 @@ def test_cumulative_similarity_splitter_single_doc():
# Assuming any return value since it should not reach the point of using the encoder
mock_encoder.return_value = np.array([[0.5, 0]])
splitter = CumulativeSimSplitter(encoder=mock_encoder, similarity_threshold=0.5)
splitter = CumulativeSimSplitter(encoder=mock_encoder, score_threshold=0.5)
docs = ["doc1"]
with pytest.raises(ValueError) as excinfo:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment