diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index 5d1e04901ca5e8f31423bd9c7a0c26b5e7979780..bcd8f62bea52b4d481f614757e0e2f9e63ae2453 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock
 from semantic_router.utils.splitters import semantic_splitter
+from semantic_router.schema import Conversation, Message
 
 
 def test_semantic_splitter_consecutive_similarity_drop():
@@ -43,3 +44,23 @@ def test_semantic_splitter_invalid_method():
 
     with pytest.raises(ValueError):
         semantic_splitter(mock_encoder, docs, threshold, split_method)
+
+
+def test_split_by_topic():
+    mock_encoder = Mock()
+    mock_encoder.return_value = [[0.5, 0], [0, 0.5]]
+
+    messages = [
+        Message(role="User", content="What is the latest news?"),
+        Message(role="Bot", content="How is the weather today?"),
+    ]
+    conversation = Conversation(messages=messages)
+
+    result = conversation.split_by_topic(
+        encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop"
+    )
+
+    assert result == {
+        "split 1": ["User: What is the latest news?"],
+        "split 2": ["Bot: How is the weather today?"],
+    }