Skip to content
Snippets Groups Projects
Commit c25d6af8 authored by Ismail Ashraq
Browse files

fix test errors

parent 741c57c8
Branches
Tags
No related merge requests found
......@@ -71,12 +71,12 @@ def semantic_splitter(
curr_split_docs_embed = encoder([curr_split_docs])[0]
next_doc_embed = encoder([next_doc])[0]
similarity = np.dot(curr_split_docs_embed, next_doc_embed) / (
curr_sim_score = np.dot(curr_split_docs_embed, next_doc_embed) / (
np.linalg.norm(curr_split_docs_embed)
* np.linalg.norm(next_doc_embed)
)
if similarity < threshold:
if curr_sim_score < threshold:
splits.append(
DocumentSplit(
docs=docs[curr_split_start_idx : idx + 1],
......
......@@ -17,7 +17,8 @@ def test_semantic_splitter_consecutive_similarity_drop():
result = semantic_splitter(mock_encoder, docs, threshold, split_method)
assert result == {"split 1": ["doc1", "doc2", "doc3"], "split 2": ["doc4", "doc5"]}
assert result[0].docs == ["doc1", "doc2", "doc3"]
assert result[1].docs == ["doc4", "doc5"]
def test_semantic_splitter_cumulative_similarity_drop():
......@@ -33,7 +34,8 @@ def test_semantic_splitter_cumulative_similarity_drop():
result = semantic_splitter(mock_encoder, docs, threshold, split_method)
assert result == {"split 1": ["doc1", "doc2"], "split 2": ["doc3", "doc4", "doc5"]}
assert result[0].docs == ["doc1", "doc2"]
assert result[1].docs == ["doc3", "doc4", "doc5"]
def test_semantic_splitter_invalid_method():
......@@ -62,7 +64,5 @@ def test_split_by_topic():
encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop"
)
assert result == {
"split 1": ["User: What is the latest news?"],
"split 2": ["Bot: How is the weather today?"],
}
assert result[0].docs == ["User: What is the latest news?"]
assert result[1].docs == ["Bot: How is the weather today?"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment