From e53237254d5fa78816eb5db53de2aeb37ef19bc5 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood <siraj@aurelio.ai> Date: Fri, 2 Feb 2024 03:24:17 +0400 Subject: [PATCH] New PyTests for text.py And more linting. --- tests/unit/test_splitters.py | 15 +++++++ tests/unit/test_text.py | 76 ++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 tests/unit/test_text.py diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index 0ed3a01a..218323b5 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -9,6 +9,7 @@ from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.cohere import CohereEncoder +from semantic_router.splitters.base import BaseSplitter def test_consecutive_sim_splitter(): @@ -182,3 +183,17 @@ def test_cumulative_similarity_splitter_single_doc(): with pytest.raises(ValueError) as excinfo: splitter(docs) assert "at least two are required" in str(excinfo.value) + + +@pytest.fixture +def base_splitter_instance(): + # Now MockEncoder includes default values for required fields + mock_encoder = Mock(spec=BaseEncoder) + mock_encoder.name = "mock_encoder" + mock_encoder.score_threshold = 0.5 + return BaseSplitter(name="test_splitter", encoder=mock_encoder, score_threshold=0.5) + + +def test_base_splitter_call_not_implemented(base_splitter_instance): + with pytest.raises(NotImplementedError): + base_splitter_instance(["document"]) diff --git a/tests/unit/test_text.py b/tests/unit/test_text.py new file mode 100644 index 00000000..0685f89e --- /dev/null +++ b/tests/unit/test_text.py @@ -0,0 +1,76 @@ +import pytest +from unittest.mock import Mock +from semantic_router.text import Conversation, Message +from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter +from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter +from semantic_router.encoders.cohere import ( + CohereEncoder, +) # Adjust this import based on your project structure + + +@pytest.fixture +def conversation_instance(): + return Conversation() + + +@pytest.fixture +def cohere_encoder(): + # Initialize CohereEncoder with necessary arguments + encoder = CohereEncoder( + name="cohere_encoder", cohere_api_key="dummy_key", input_type="text" + ) + return encoder + + +def test_add_new_messages(conversation_instance): + initial_len = len(conversation_instance.messages) + conversation_instance.add_new_messages([Message(role="user", content="Hello")]) + assert len(conversation_instance.messages) == initial_len + 1 + + +def test_remove_topics(conversation_instance): + conversation_instance.topics.append((1, "Sample Topic")) + conversation_instance.remove_topics() + assert len(conversation_instance.topics) == 0 + + +def test_configure_splitter_consecutive_similarity( + conversation_instance, cohere_encoder +): + conversation_instance.configure_splitter( + encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity" + ) + assert isinstance(conversation_instance.splitter, ConsecutiveSimSplitter) + + +def test_configure_splitter_cumulative_similarity( + conversation_instance, cohere_encoder +): + conversation_instance.configure_splitter( + encoder=cohere_encoder, threshold=0.5, split_method="cumulative_similarity" + ) + assert isinstance(conversation_instance.splitter, CumulativeSimSplitter) + + +def test_configure_splitter_invalid_method(conversation_instance, cohere_encoder): + with pytest.raises(ValueError): + conversation_instance.configure_splitter( + encoder=cohere_encoder, threshold=0.5, split_method="invalid_method" + ) + + +def test_split_by_topic_without_configuring_splitter(conversation_instance): + with pytest.raises(ValueError): + conversation_instance.split_by_topic() + + +def test_split_by_topic_with_no_unclustered_messages( + conversation_instance, cohere_encoder, capsys +): + conversation_instance.configure_splitter( + encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity" + ) + conversation_instance.splitter = Mock() + conversation_instance.split_by_topic() + captured = capsys.readouterr() + assert "No unclustered messages to process." in captured.out -- GitLab