From e53237254d5fa78816eb5db53de2aeb37ef19bc5 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Fri, 2 Feb 2024 03:24:17 +0400
Subject: [PATCH] New PyTests for text.py

And more linting.
---
 tests/unit/test_splitters.py | 15 +++++++
 tests/unit/test_text.py      | 76 ++++++++++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+)
 create mode 100644 tests/unit/test_text.py

diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index 0ed3a01a..218323b5 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -9,6 +9,7 @@ from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
 from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.cohere import CohereEncoder
+from semantic_router.splitters.base import BaseSplitter
 
 
 def test_consecutive_sim_splitter():
@@ -182,3 +183,17 @@ def test_cumulative_similarity_splitter_single_doc():
     with pytest.raises(ValueError) as excinfo:
         splitter(docs)
     assert "at least two are required" in str(excinfo.value)
+
+
+@pytest.fixture
+def base_splitter_instance():
+    # Now MockEncoder includes default values for required fields
+    mock_encoder = Mock(spec=BaseEncoder)
+    mock_encoder.name = "mock_encoder"
+    mock_encoder.score_threshold = 0.5
+    return BaseSplitter(name="test_splitter", encoder=mock_encoder, score_threshold=0.5)
+
+
+def test_base_splitter_call_not_implemented(base_splitter_instance):
+    with pytest.raises(NotImplementedError):
+        base_splitter_instance(["document"])
diff --git a/tests/unit/test_text.py b/tests/unit/test_text.py
new file mode 100644
index 00000000..0685f89e
--- /dev/null
+++ b/tests/unit/test_text.py
@@ -0,0 +1,76 @@
+import pytest
+from unittest.mock import Mock
+from semantic_router.text import Conversation, Message
+from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
+from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
+from semantic_router.encoders.cohere import (
+    CohereEncoder,
+)  # Adjust this import based on your project structure
+
+
+@pytest.fixture
+def conversation_instance():
+    return Conversation()
+
+
+@pytest.fixture
+def cohere_encoder():
+    # Initialize CohereEncoder with necessary arguments
+    encoder = CohereEncoder(
+        name="cohere_encoder", cohere_api_key="dummy_key", input_type="text"
+    )
+    return encoder
+
+
+def test_add_new_messages(conversation_instance):
+    initial_len = len(conversation_instance.messages)
+    conversation_instance.add_new_messages([Message(role="user", content="Hello")])
+    assert len(conversation_instance.messages) == initial_len + 1
+
+
+def test_remove_topics(conversation_instance):
+    conversation_instance.topics.append((1, "Sample Topic"))
+    conversation_instance.remove_topics()
+    assert len(conversation_instance.topics) == 0
+
+
+def test_configure_splitter_consecutive_similarity(
+    conversation_instance, cohere_encoder
+):
+    conversation_instance.configure_splitter(
+        encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity"
+    )
+    assert isinstance(conversation_instance.splitter, ConsecutiveSimSplitter)
+
+
+def test_configure_splitter_cumulative_similarity(
+    conversation_instance, cohere_encoder
+):
+    conversation_instance.configure_splitter(
+        encoder=cohere_encoder, threshold=0.5, split_method="cumulative_similarity"
+    )
+    assert isinstance(conversation_instance.splitter, CumulativeSimSplitter)
+
+
+def test_configure_splitter_invalid_method(conversation_instance, cohere_encoder):
+    with pytest.raises(ValueError):
+        conversation_instance.configure_splitter(
+            encoder=cohere_encoder, threshold=0.5, split_method="invalid_method"
+        )
+
+
+def test_split_by_topic_without_configuring_splitter(conversation_instance):
+    with pytest.raises(ValueError):
+        conversation_instance.split_by_topic()
+
+
+def test_split_by_topic_with_no_unclustered_messages(
+    conversation_instance, cohere_encoder, capsys
+):
+    conversation_instance.configure_splitter(
+        encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity"
+    )
+    conversation_instance.splitter = Mock()
+    conversation_instance.split_by_topic()
+    captured = capsys.readouterr()
+    assert "No unclustered messages to process." in captured.out
-- 
GitLab