Skip to content
Snippets Groups Projects
Unverified Commit e5323725 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

New PyTests for text.py

And more linting.
parent c336d2b1
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.encoders.base import BaseEncoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.splitters.base import BaseSplitter
def test_consecutive_sim_splitter():
......@@ -182,3 +183,17 @@ def test_cumulative_similarity_splitter_single_doc():
with pytest.raises(ValueError) as excinfo:
splitter(docs)
assert "at least two are required" in str(excinfo.value)
@pytest.fixture
def base_splitter_instance():
# Now MockEncoder includes default values for required fields
mock_encoder = Mock(spec=BaseEncoder)
mock_encoder.name = "mock_encoder"
mock_encoder.score_threshold = 0.5
return BaseSplitter(name="test_splitter", encoder=mock_encoder, score_threshold=0.5)
def test_base_splitter_call_not_implemented(base_splitter_instance):
with pytest.raises(NotImplementedError):
base_splitter_instance(["document"])
import pytest
from unittest.mock import Mock
from semantic_router.text import Conversation, Message
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.encoders.cohere import (
CohereEncoder,
) # Adjust this import based on your project structure
@pytest.fixture
def conversation_instance():
return Conversation()
@pytest.fixture
def cohere_encoder():
# Initialize CohereEncoder with necessary arguments
encoder = CohereEncoder(
name="cohere_encoder", cohere_api_key="dummy_key", input_type="text"
)
return encoder
def test_add_new_messages(conversation_instance):
initial_len = len(conversation_instance.messages)
conversation_instance.add_new_messages([Message(role="user", content="Hello")])
assert len(conversation_instance.messages) == initial_len + 1
def test_remove_topics(conversation_instance):
conversation_instance.topics.append((1, "Sample Topic"))
conversation_instance.remove_topics()
assert len(conversation_instance.topics) == 0
def test_configure_splitter_consecutive_similarity(
conversation_instance, cohere_encoder
):
conversation_instance.configure_splitter(
encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity"
)
assert isinstance(conversation_instance.splitter, ConsecutiveSimSplitter)
def test_configure_splitter_cumulative_similarity(
conversation_instance, cohere_encoder
):
conversation_instance.configure_splitter(
encoder=cohere_encoder, threshold=0.5, split_method="cumulative_similarity"
)
assert isinstance(conversation_instance.splitter, CumulativeSimSplitter)
def test_configure_splitter_invalid_method(conversation_instance, cohere_encoder):
with pytest.raises(ValueError):
conversation_instance.configure_splitter(
encoder=cohere_encoder, threshold=0.5, split_method="invalid_method"
)
def test_split_by_topic_without_configuring_splitter(conversation_instance):
with pytest.raises(ValueError):
conversation_instance.split_by_topic()
def test_split_by_topic_with_no_unclustered_messages(
conversation_instance, cohere_encoder, capsys
):
conversation_instance.configure_splitter(
encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity"
)
conversation_instance.splitter = Mock()
conversation_instance.split_by_topic()
captured = capsys.readouterr()
assert "No unclustered messages to process." in captured.out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment