Skip to content
Snippets Groups Projects
Commit c25d6af8 authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

fix test errors

parent 741c57c8
No related branches found
No related tags found
No related merge requests found
...@@ -71,12 +71,12 @@ def semantic_splitter( ...@@ -71,12 +71,12 @@ def semantic_splitter(
curr_split_docs_embed = encoder([curr_split_docs])[0] curr_split_docs_embed = encoder([curr_split_docs])[0]
next_doc_embed = encoder([next_doc])[0] next_doc_embed = encoder([next_doc])[0]
similarity = np.dot(curr_split_docs_embed, next_doc_embed) / ( curr_sim_score = np.dot(curr_split_docs_embed, next_doc_embed) / (
np.linalg.norm(curr_split_docs_embed) np.linalg.norm(curr_split_docs_embed)
* np.linalg.norm(next_doc_embed) * np.linalg.norm(next_doc_embed)
) )
if similarity < threshold: if curr_sim_score < threshold:
splits.append( splits.append(
DocumentSplit( DocumentSplit(
docs=docs[curr_split_start_idx : idx + 1], docs=docs[curr_split_start_idx : idx + 1],
......
...@@ -17,7 +17,8 @@ def test_semantic_splitter_consecutive_similarity_drop(): ...@@ -17,7 +17,8 @@ def test_semantic_splitter_consecutive_similarity_drop():
result = semantic_splitter(mock_encoder, docs, threshold, split_method) result = semantic_splitter(mock_encoder, docs, threshold, split_method)
assert result == {"split 1": ["doc1", "doc2", "doc3"], "split 2": ["doc4", "doc5"]} assert result[0].docs == ["doc1", "doc2", "doc3"]
assert result[1].docs == ["doc4", "doc5"]
def test_semantic_splitter_cumulative_similarity_drop(): def test_semantic_splitter_cumulative_similarity_drop():
...@@ -33,7 +34,8 @@ def test_semantic_splitter_cumulative_similarity_drop(): ...@@ -33,7 +34,8 @@ def test_semantic_splitter_cumulative_similarity_drop():
result = semantic_splitter(mock_encoder, docs, threshold, split_method) result = semantic_splitter(mock_encoder, docs, threshold, split_method)
assert result == {"split 1": ["doc1", "doc2"], "split 2": ["doc3", "doc4", "doc5"]} assert result[0].docs == ["doc1", "doc2"]
assert result[1].docs == ["doc3", "doc4", "doc5"]
def test_semantic_splitter_invalid_method(): def test_semantic_splitter_invalid_method():
...@@ -62,7 +64,5 @@ def test_split_by_topic(): ...@@ -62,7 +64,5 @@ def test_split_by_topic():
encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop" encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop"
) )
assert result == { assert result[0].docs == ["User: What is the latest news?"]
"split 1": ["User: What is the latest news?"], assert result[1].docs == ["Bot: How is the weather today?"]
"split 2": ["Bot: How is the weather today?"],
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment