diff --git a/tests/unit/test_text.py b/tests/unit/test_text.py index 0685f89ee25a0faecf63378614b7cfc500a686b5..749c9b92bf4ddf672bd734d561414880ec1fe847 100644 --- a/tests/unit/test_text.py +++ b/tests/unit/test_text.py @@ -6,6 +6,7 @@ from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter from semantic_router.encoders.cohere import ( CohereEncoder, ) # Adjust this import based on your project structure +from semantic_router.schema import DocumentSplit @pytest.fixture @@ -74,3 +75,125 @@ def test_split_by_topic_with_no_unclustered_messages( conversation_instance.split_by_topic() captured = capsys.readouterr() assert "No unclustered messages to process." in captured.out + + +def test_print_topics_empty(conversation_instance, capsys): + # Test printing topics when there are no topics + conversation_instance.print_topics() + captured = capsys.readouterr() + assert "No topics to display." in captured.out + + +def test_print_topics_with_data(conversation_instance, capsys): + # Add some topics to the conversation instance + conversation_instance.topics.append((0, "Hello, how are you?")) + conversation_instance.topics.append((0, "I'm fine, thanks!")) + conversation_instance.topics.append((1, "What's the weather like?")) + conversation_instance.topics.append((2, "It's sunny.")) + + # Test printing topics with data + conversation_instance.print_topics() + captured = capsys.readouterr() + + # Expected output based on the topics added + expected_output = ( + "Topics:\n" + "Topic 1:\n" + " - Hello, how are you?\n" + " - I'm fine, thanks!\n" + "\n" + "Topic 2:\n" + " - What's the weather like?\n" + "\n" + "Topic 3:\n" + " - It's sunny." + ) + + # Normalize newlines for Windows compatibility + normalized_output = captured.out.replace("\r\n", "\n") + assert normalized_output.strip() == expected_output + + +def test_get_last_message_and_topic_id_with_no_topics(conversation_instance): + # Test the method when there are no topics in the conversation + last_topic_id, last_message = conversation_instance.get_last_message_and_topic_id() + assert ( + last_topic_id is None and last_message is None + ), "Expected None for both topic ID and message when there are no topics" + + +def test_get_last_message_and_topic_id_with_topics(conversation_instance): + # Add some topics to the conversation instance + conversation_instance.topics.append((0, "First message")) + conversation_instance.topics.append((1, "Second message")) + conversation_instance.topics.append((2, "Third message")) + + # Test the method when there are topics in the conversation + last_topic_id, last_message = conversation_instance.get_last_message_and_topic_id() + assert ( + last_topic_id == 2 and last_message == "Third message" + ), "Expected last topic ID and message to match the last topic added" + + +def test_determine_topic_start_index_no_existing_topics(conversation_instance): + # Scenario where there are no existing topics + new_topics = [ + DocumentSplit(docs=["User: Hello!"], is_triggered=True, triggered_score=0.4) + ] + start_index = conversation_instance.determine_topic_start_index( + new_topics, None, None + ) + assert ( + start_index == 1 + ), "Expected start index to be 1 when there are no existing topics" + + +def test_determine_topic_start_index_with_existing_topics_not_including_last_message( + conversation_instance, +): + # Scenario where existing topics do not include the last message + conversation_instance.topics.append((0, "First message")) + new_topics = [ + DocumentSplit(docs=["User: Hello!"], is_triggered=True, triggered_score=0.4) + ] + start_index = conversation_instance.determine_topic_start_index( + new_topics, 0, "Non-existent last message" + ) + assert ( + start_index == 1 + ), "Expected start index to increment when last message is not in new topics" + + +def test_determine_topic_start_index_with_existing_topics_including_last_message( + conversation_instance, +): + # Scenario where the first new topic includes the last message + conversation_instance.topics.append((0, "First message")) + new_topics = [ + DocumentSplit( + docs=["First message", "Another message"], + is_triggered=True, + triggered_score=0.4, + ) + ] + start_index = conversation_instance.determine_topic_start_index( + new_topics, 0, "First message" + ) + assert ( + start_index == 0 + ), "Expected start index to be the same as last topic ID when last message is included in new topics" + + +def test_determine_topic_start_index_increment_from_last_topic_id( + conversation_instance, +): + # Scenario to test increment from the last topic ID when last message is not in new topics + conversation_instance.topics.append((1, "First message")) + conversation_instance.topics.append((2, "Second message")) + new_topics = [ + DocumentSplit(docs=["User: Hello!"], is_triggered=True, triggered_score=0.4) + ] + start_index = conversation_instance.determine_topic_start_index( + new_topics, 2, "Non-existent last message" + ) + assert start_index == 3, "Expected start index to be last topic ID + 1"