"""Tests for ``Conversation`` message handling and splitter configuration."""

from unittest.mock import Mock

import pytest

from semantic_router.encoders.cohere import (
    CohereEncoder,
)  # Adjust this import based on your project structure
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.text import Conversation, Message


@pytest.fixture
def conversation_instance():
    """Provide a freshly constructed ``Conversation`` for each test."""
    return Conversation()


@pytest.fixture
def cohere_encoder():
    """Provide a ``CohereEncoder`` built with a placeholder API key."""
    # NOTE(review): presumably constructing the encoder with a dummy key does
    # not contact the Cohere service -- confirm against CohereEncoder.__init__.
    return CohereEncoder(
        name="cohere_encoder",
        cohere_api_key="dummy_key",
        input_type="text",
    )


def test_add_new_messages(conversation_instance):
    """Adding one message grows ``messages`` by exactly one entry."""
    initial_len = len(conversation_instance.messages)

    conversation_instance.add_new_messages([Message(role="user", content="Hello")])

    assert len(conversation_instance.messages) == initial_len + 1


def test_remove_topics(conversation_instance):
    """``remove_topics`` leaves the ``topics`` collection empty."""
    conversation_instance.topics.append((1, "Sample Topic"))

    conversation_instance.remove_topics()

    assert len(conversation_instance.topics) == 0


def test_configure_splitter_consecutive_similarity(
    conversation_instance, cohere_encoder
):
    """``"consecutive_similarity"`` installs a ``ConsecutiveSimSplitter``."""
    conversation_instance.configure_splitter(
        encoder=cohere_encoder,
        threshold=0.5,
        split_method="consecutive_similarity",
    )

    assert isinstance(conversation_instance.splitter, ConsecutiveSimSplitter)


def test_configure_splitter_cumulative_similarity(
    conversation_instance, cohere_encoder
):
    """``"cumulative_similarity"`` installs a ``CumulativeSimSplitter``."""
    conversation_instance.configure_splitter(
        encoder=cohere_encoder,
        threshold=0.5,
        split_method="cumulative_similarity",
    )

    assert isinstance(conversation_instance.splitter, CumulativeSimSplitter)


def test_configure_splitter_invalid_method(conversation_instance, cohere_encoder):
    """An unrecognised ``split_method`` is expected to raise ``ValueError``."""
    with pytest.raises(ValueError):
        conversation_instance.configure_splitter(
            encoder=cohere_encoder,
            threshold=0.5,
            split_method="invalid_method",
        )


def test_split_by_topic_without_configuring_splitter(conversation_instance):
    """``split_by_topic`` is expected to raise ``ValueError`` with no splitter set."""
    with pytest.raises(ValueError):
        conversation_instance.split_by_topic()


def test_split_by_topic_with_no_unclustered_messages(
    conversation_instance, cohere_encoder, capsys
):
    """``split_by_topic`` prints a notice when no messages have been added."""
    conversation_instance.configure_splitter(
        encoder=cohere_encoder,
        threshold=0.5,
        split_method="consecutive_similarity",
    )
    # Replace the configured splitter with a Mock; the assertion below only
    # inspects what was written to stdout.
    conversation_instance.splitter = Mock()

    conversation_instance.split_by_topic()

    captured = capsys.readouterr()
    assert "No unclustered messages to process." in captured.out