diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py index 10b957086cadb562aaf72151e893b033d75d442d..7521aa3878cbbbd757cfa567843384db772da2cc 100644 --- a/semantic_router/utils/splitters.py +++ b/semantic_router/utils/splitters.py @@ -71,12 +71,12 @@ def semantic_splitter( curr_split_docs_embed = encoder([curr_split_docs])[0] next_doc_embed = encoder([next_doc])[0] - similarity = np.dot(curr_split_docs_embed, next_doc_embed) / ( + curr_sim_score = np.dot(curr_split_docs_embed, next_doc_embed) / ( np.linalg.norm(curr_split_docs_embed) * np.linalg.norm(next_doc_embed) ) - if similarity < threshold: + if curr_sim_score < threshold: splits.append( DocumentSplit( docs=docs[curr_split_start_idx : idx + 1], diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index ac9c037c7985bbce54144de4c6e4e7096162dc19..f0e8e3f32734f988e2d54a878c26f7f5f1994228 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -17,7 +17,8 @@ def test_semantic_splitter_consecutive_similarity_drop(): result = semantic_splitter(mock_encoder, docs, threshold, split_method) - assert result == {"split 1": ["doc1", "doc2", "doc3"], "split 2": ["doc4", "doc5"]} + assert result[0].docs == ["doc1", "doc2", "doc3"] + assert result[1].docs == ["doc4", "doc5"] def test_semantic_splitter_cumulative_similarity_drop(): @@ -33,7 +34,8 @@ def test_semantic_splitter_cumulative_similarity_drop(): result = semantic_splitter(mock_encoder, docs, threshold, split_method) - assert result == {"split 1": ["doc1", "doc2"], "split 2": ["doc3", "doc4", "doc5"]} + assert result[0].docs == ["doc1", "doc2"] + assert result[1].docs == ["doc3", "doc4", "doc5"] def test_semantic_splitter_invalid_method(): @@ -62,7 +64,5 @@ def test_split_by_topic(): encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop" ) - assert result == { - "split 1": ["User: What is the latest news?"], - "split 2": ["Bot: How is the weather today?"], - } + assert result[0].docs == ["User: What is the latest news?"] + assert result[1].docs == ["Bot: How is the weather today?"]