From 06a686ea60862716f3b06c2fd98bdd08691b85c7 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Fri, 2 Feb 2024 02:47:46 +0400
Subject: [PATCH] Linted, Added Pytests, and Fixed Existing Pytests

Changes to the code required updating some existing pytests.

New pytests added for the Conversation splitter methods, including the new single-document guard in both splitters.

Also fixed a bug in cumulative_sim.py where curr_split_docs was not limited to the documents after the most recent split.
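
A minimal sketch of the fixed slicing, with illustrative values (not taken
from the test suite); before the fix, documents from earlier splits leaked
into later similarity comparisons:

    docs = ["a1", "a2", "b1", "b2"]
    curr_split_start_idx = 2  # a split was triggered after "a2"
    idx = 2
    old = "\n".join(docs[0 : idx + 1])                     # "a1\na2\nb1"
    new = "\n".join(docs[curr_split_start_idx : idx + 1])  # "b1"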
---
 semantic_router/encoders/cohere.py           |  14 +-
 semantic_router/schema.py                    |   2 +-
 semantic_router/splitters/base.py            |   1 +
 semantic_router/splitters/consecutive_sim.py |  17 +-
 semantic_router/splitters/cumulative_sim.py  |  26 +--
 semantic_router/text.py                      |  50 +++---
 tests/unit/test_splitters.py                 | 161 +++++++++++++++----
 7 files changed, 195 insertions(+), 76 deletions(-)

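Note: a short usage sketch of the splitter API exercised by the new tests.
The encoder variable below is assumed to be any already-constructed
BaseEncoder subclass (e.g. CohereEncoder); its construction is elided here.

    from semantic_router.schema import Message
    from semantic_router.text import Conversation

    conversation = Conversation(
        messages=[
            Message(role="User", content="What is the latest news?"),
            Message(role="Bot", content="How is the weather today?"),
        ]
    )
    conversation.configure_splitter(
        encoder=encoder, threshold=0.5, split_method="consecutive_similarity"
    )
    topics, new_topics = conversation.split_by_topic()
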
diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py
index e145e6cd..ceeece17 100644
--- a/semantic_router/encoders/cohere.py
+++ b/semantic_router/encoders/cohere.py
@@ -15,12 +15,16 @@ class CohereEncoder(BaseEncoder):
         self,
         name: Optional[str] = None,
         cohere_api_key: Optional[str] = None,
-        score_threshold: Optional[float] = 0.3,
-        input_type: Optional[str] = "search_query"
+        score_threshold: float = 0.3,
+        input_type: Optional[str] = "search_query",
     ):
         if name is None:
             name = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0")
-        super().__init__(name=name, score_threshold=score_threshold, input_type=input_type)
+        super().__init__(
+            name=name,
+            score_threshold=score_threshold,
+            input_type=input_type,  # type: ignore
+        )
         self.input_type = input_type
         cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
         if cohere_api_key is None:
@@ -36,7 +40,9 @@ class CohereEncoder(BaseEncoder):
         if self.client is None:
             raise ValueError("Cohere client is not initialized.")
         try:
-            embeds = self.client.embed(docs, input_type=self.input_type, model=self.name)
+            embeds = self.client.embed(
+                docs, input_type=self.input_type, model=self.name
+            )
             return embeds.embeddings
         except Exception as e:
             raise ValueError(f"Cohere API call failed. Error: {e}") from e
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 6fcca01f..f8dcd8d1 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -69,4 +69,4 @@ class Message(BaseModel):
 class DocumentSplit(BaseModel):
     docs: List[str]
     is_triggered: bool = False
-    triggered_score: Optional[float] = None
\ No newline at end of file
+    triggered_score: Optional[float] = None
diff --git a/semantic_router/splitters/base.py b/semantic_router/splitters/base.py
index 38867a25..ccd4e6f6 100644
--- a/semantic_router/splitters/base.py
+++ b/semantic_router/splitters/base.py
@@ -3,6 +3,7 @@ from typing import List
 from pydantic.v1 import BaseModel
 from semantic_router.encoders import BaseEncoder
 
+
 class BaseSplitter(BaseModel):
     name: str
     encoder: BaseEncoder
diff --git a/semantic_router/splitters/consecutive_sim.py b/semantic_router/splitters/consecutive_sim.py
index a9038750..388ffb2f 100644
--- a/semantic_router/splitters/consecutive_sim.py
+++ b/semantic_router/splitters/consecutive_sim.py
@@ -4,8 +4,9 @@ from semantic_router.encoders import BaseEncoder
 import numpy as np
 from semantic_router.schema import DocumentSplit
 
+
 class ConsecutiveSimSplitter(BaseSplitter):
-    
+
     """
     Called "consecutive sim splitter" because we check the similarities of consecutive document embeddings (compare ith to i+1th document embedding).
     """
@@ -14,15 +15,17 @@ class ConsecutiveSimSplitter(BaseSplitter):
         self,
         encoder: BaseEncoder,
         name: str = "consecutive_similarity_splitter",
-        similarity_threshold: float = 0.45
+        similarity_threshold: float = 0.45,
     ):
         super().__init__(
-            name=name, 
-            similarity_threshold=similarity_threshold,
-            encoder=encoder
-            )
+            name=name, similarity_threshold=similarity_threshold, encoder=encoder
+        )
 
     def __call__(self, docs: List[str]):
+        # Check if there's only a single document
+        if len(docs) == 1:
+            raise ValueError("There is only one document provided; at least two are required to determine topics based on similarity.")
+
         doc_embeds = self.encoder(docs)
         norm_embeds = doc_embeds / np.linalg.norm(doc_embeds, axis=1, keepdims=True)
         sim_matrix = np.matmul(norm_embeds, norm_embeds.T)
@@ -44,4 +47,4 @@ class ConsecutiveSimSplitter(BaseSplitter):
                 curr_split_start_idx = idx
                 curr_split_num += 1
         splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
-        return splits
\ No newline at end of file
+        return splits
diff --git a/semantic_router/splitters/cumulative_sim.py b/semantic_router/splitters/cumulative_sim.py
index 016bca8d..e3c4c645 100644
--- a/semantic_router/splitters/cumulative_sim.py
+++ b/semantic_router/splitters/cumulative_sim.py
@@ -4,8 +4,9 @@ import numpy as np
 from semantic_router.schema import DocumentSplit
 from semantic_router.encoders import BaseEncoder
 
+
 class CumulativeSimSplitter(BaseSplitter):
-    
+
     """
     Called "cumulative sim" because we check the similarities of the embeddings of cumulative concatenated documents with the next document.
     """
@@ -17,13 +18,14 @@ class CumulativeSimSplitter(BaseSplitter):
         similarity_threshold: float = 0.45,
     ):
         super().__init__(
-            name=name, 
-            similarity_threshold=similarity_threshold,
-            encoder=encoder
-            )
+            name=name, similarity_threshold=similarity_threshold, encoder=encoder
+        )
 
     def __call__(self, docs: List[str]):
         total_docs = len(docs)
+        # Check if there's only a single document
+        if total_docs == 1:
+            raise ValueError("There is only one document provided; at least two are required to determine topics based on similarity.")
         splits = []
         curr_split_start_idx = 0
 
@@ -34,29 +36,31 @@ class CumulativeSimSplitter(BaseSplitter):
                     curr_split_docs = docs[idx]
                 else:
                     # For subsequent iterations, compare cumulative documents up to the current one with the next.
-                    curr_split_docs = "\n".join(docs[0: idx + 1])
+                    curr_split_docs = "\n".join(docs[curr_split_start_idx : idx + 1])
                 next_doc = docs[idx + 1]
 
                 # Embedding and similarity calculation remains the same.
                 curr_split_docs_embed = self.encoder([curr_split_docs])[0]
                 next_doc_embed = self.encoder([next_doc])[0]
                 curr_sim_score = np.dot(curr_split_docs_embed, next_doc_embed) / (
-                    np.linalg.norm(curr_split_docs_embed) * np.linalg.norm(next_doc_embed)
+                    np.linalg.norm(curr_split_docs_embed)
+                    * np.linalg.norm(next_doc_embed)
                 )
-
                 # Decision to split based on similarity score.
                 if curr_sim_score < self.similarity_threshold:
                     splits.append(
                         DocumentSplit(
-                            docs=list(docs[curr_split_start_idx: idx + 1]),
+                            docs=list(docs[curr_split_start_idx : idx + 1]),
                             is_triggered=True,
                             triggered_score=curr_sim_score,
                         )
                     )
-                    curr_split_start_idx = idx + 1  # Update the start index for the next segment.
+                    curr_split_start_idx = (
+                        idx + 1
+                    )  # Update the start index for the next segment.
 
         # Add the last segment after the loop.
         if curr_split_start_idx < total_docs:
             splits.append(DocumentSplit(docs=list(docs[curr_split_start_idx:])))
 
-        return splits
\ No newline at end of file
+        return splits
diff --git a/semantic_router/text.py b/semantic_router/text.py
index 5b7fd6c5..2717f99a 100644
--- a/semantic_router/text.py
+++ b/semantic_router/text.py
@@ -9,8 +9,11 @@ from semantic_router.schema import DocumentSplit
 # Define a type alias for the splitter to simplify the annotation
 SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None]
 
+
 class Conversation(BaseModel):
-    messages: List[Message] = Field(default_factory=list) # Ensure this is initialized as an empty list
+    messages: List[Message] = Field(
+        default_factory=list
+    )  # Ensure this is initialized as an empty list
     topics: List[Tuple[int, str]] = []
     splitter: SplitterType = None
 
@@ -28,8 +31,8 @@ class Conversation(BaseModel):
         current_topic_id = None
         for topic_id, message in self.topics:
             if topic_id != current_topic_id:
-                if current_topic_id is not None: 
-                    print("\n", end="")  
+                if current_topic_id is not None:
+                    print("\n", end="")
                 print(f"Topic {topic_id + 1}:")
                 current_topic_id = topic_id
             print(f" - {message}")
@@ -42,7 +45,6 @@ class Conversation(BaseModel):
             "consecutive_similarity", "cumulative_similarity"
         ] = "consecutive_similarity",
     ):
-        
         """
         Configures the splitter for the conversation based on the specified method.
 
@@ -58,15 +60,17 @@ class Conversation(BaseModel):
         """
 
         if split_method == "consecutive_similarity":
-            self.splitter = ConsecutiveSimSplitter(encoder=encoder, similarity_threshold=threshold)
+            self.splitter = ConsecutiveSimSplitter(
+                encoder=encoder, similarity_threshold=threshold
+            )
         elif split_method == "cumulative_similarity":
-            self.splitter = CumulativeSimSplitter(encoder=encoder, similarity_threshold=threshold)
+            self.splitter = CumulativeSimSplitter(
+                encoder=encoder, similarity_threshold=threshold
+            )
         else:
             raise ValueError(f"Invalid split method: {split_method}")
-    
 
     def get_last_message_and_topic_id(self):
-
         """
         Retrieves the last message and its corresponding topic ID from the list of topics.
 
@@ -82,11 +86,11 @@ class Conversation(BaseModel):
             return self.topics[-1]
         else:
             return None, None
-    
+
     def determine_topic_start_index(self, new_topics, last_topic_id, last_message):
         """
         Determines the starting index for new topics based on existing topics and the last message.
-        
+
         :param new_topics: The list of new topics generated by the splitter.
         :type new_topics: List[DocumentSplit]
         :param last_topic_id: The topic ID of the last message from the previous splitting.
@@ -98,10 +102,14 @@ class Conversation(BaseModel):
         """
         if not self.topics or not new_topics:
             return 1
-        if last_topic_id is not None and last_message and last_message in new_topics[0].docs:
+        if (
+            last_topic_id is not None
+            and last_message
+            and last_message in new_topics[0].docs
+        ):
             return last_topic_id
         return self.topics[-1][0] + 1
-        
+
     def append_new_topics(self, new_topics, start) -> None:
         """
         Appends new topics to the list of topics with unique IDs.
@@ -118,24 +126,25 @@ class Conversation(BaseModel):
                 self.topics.append((i, message))
 
     def split_by_topic(self) -> Tuple[List[Tuple[int, str]], List[DocumentSplit]]:
-
         """
         Splits the messages into topics based on their semantic similarity.
 
         This method processes unclustered messages, splits them into topics using the configured splitter, and appends the new topics to the existing list of topics with unique IDs. It ensures that messages belonging to the same topic are grouped together, even if they were not processed in the same batch.
 
         :raises ValueError: If the splitter is not configured before calling this method.
-        
+
         :return: A tuple containing the updated list of topics and the list of new topics generated in this call.
         :rtype: tuple[list[tuple[int, str]], list[DocumentSplit]]
         """
 
         if self.splitter is None:
-            raise ValueError("Splitter is not configured. Please call configure_splitter first.")
-        new_topics = []
+            raise ValueError(
+                "Splitter is not configured. Please call configure_splitter first."
+            )
+        new_topics: List[DocumentSplit] = []
 
         # Get unclustered messages.
-        unclustered_messages = self.messages[len(self.topics):]
+        unclustered_messages = self.messages[len(self.topics) :]
         if not unclustered_messages:
             print("No unclustered messages to process.")
             return self.topics, new_topics
@@ -155,14 +164,15 @@ class Conversation(BaseModel):
             return self.topics, []
 
         # If last_message and the first new message are assigned the same topic ID, then the new message should take last_message's original topic ID.
-        start = self.determine_topic_start_index(new_topics, last_topic_id, last_message)
+        start = self.determine_topic_start_index(
+            new_topics, last_topic_id, last_message
+        )
 
         # If the last message from the previous splitting is found in the first new topic, remove it
         if self.topics and new_topics[0].docs[0] == self.topics[-1][1]:
             new_topics[0].docs.pop(0)
 
         self.append_new_topics(new_topics, start)
-        
+
         # TODO: Instead of self.topics as list of tuples should it also be a list of DocumentSplit objects?
         return self.topics, new_topics
-   
\ No newline at end of file
diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index f0e8e3f3..4fb46712 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -1,58 +1,104 @@
-from unittest.mock import Mock
+from unittest.mock import Mock, create_autospec
 
+import numpy as np
 import pytest
 
-from semantic_router.schema import Conversation, Message
-from semantic_router.utils.splitters import semantic_splitter
+from semantic_router.text import Conversation
+from semantic_router.schema import Message
+from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
+from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
+from semantic_router.encoders.base import BaseEncoder
+from semantic_router.encoders.cohere import CohereEncoder
 
-
-def test_semantic_splitter_consecutive_similarity_drop():
-    # Mock the BaseEncoder
+def test_consecutive_sim_splitter():
+    # Create a Mock object for the encoder
     mock_encoder = Mock()
-    mock_encoder.return_value = [[0.5, 0], [0.5, 0], [0.5, 0], [0, 0.5], [0, 0.5]]
+    mock_encoder.return_value = np.array([[1, 0], [1, 0.1], [0, 1]])
 
-    docs = ["doc1", "doc2", "doc3", "doc4", "doc5"]
-    threshold = 0.5
-    split_method = "consecutive_similarity_drop"
+    cohere_encoder = CohereEncoder(
+        name="",
+        cohere_api_key="",
+        input_type="",
+    )
+    # Instantiate the splitter with a real encoder, then swap in the mock encoder
+    splitter = ConsecutiveSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9)
+    splitter.encoder = mock_encoder
 
-    result = semantic_splitter(mock_encoder, docs, threshold, split_method)
+    # Define some documents
+    docs = ["doc1", "doc2", "doc3"]
 
-    assert result[0].docs == ["doc1", "doc2", "doc3"]
-    assert result[1].docs == ["doc4", "doc5"]
+    # Use the splitter to split the documents
+    splits = splitter(docs)
 
+    # Verify the splits
+    assert len(splits) == 2, "Expected two splits based on the similarity threshold"
+    assert splits[0].docs == ["doc1", "doc2"], "First split does not match expected documents"
+    assert splits[1].docs == ["doc3"], "Second split does not match expected documents"
 
-def test_semantic_splitter_cumulative_similarity_drop():
+def test_cumulative_sim_splitter():
     # Mock the BaseEncoder
     mock_encoder = Mock()
+    # Define a side_effect that simulates the encoder's behavior for cumulative
+    # document comparisons, using simple two-dimensional embeddings for
+    # demonstration purposes.
     mock_encoder.side_effect = (
-        lambda x: [[0.5, 0]] if "doc1" in x or "doc1\ndoc2" in x else [[0, 0.5]]
+        lambda x: [[0.5, 0]] if "doc1" in x or "doc1\ndoc2" in x or "doc2" in x else [[0, 0.5]]
     )
 
+    # Instantiate the splitter with a real encoder, then swap in the mock encoder
+    cohere_encoder = CohereEncoder(
+        name="",
+        cohere_api_key="",
+        input_type="",
+    )
+    splitter = CumulativeSimSplitter(encoder=cohere_encoder, similarity_threshold=0.9)
+    splitter.encoder = mock_encoder
+
+    # Define some documents
     docs = ["doc1", "doc2", "doc3", "doc4", "doc5"]
-    threshold = 0.5
-    split_method = "cumulative_similarity_drop"
 
-    result = semantic_splitter(mock_encoder, docs, threshold, split_method)
+    # Use the splitter to split the documents
+    splits = splitter(docs)
 
-    assert result[0].docs == ["doc1", "doc2"]
-    assert result[1].docs == ["doc3", "doc4", "doc5"]
+    # Verify the splits
+    # The expected outcome must match the logic defined in the mock_encoder's side_effect.
+    assert len(splits) == 2, f"Expected 2 splits, got {len(splits)}"
+    assert splits[0].docs == ["doc1", "doc2"], "First split does not match expected documents"
+    assert splits[1].docs == ["doc3", "doc4", "doc5"], "Second split does not match expected documents"
 
 
-def test_semantic_splitter_invalid_method():
-    # Mock the BaseEncoder
+def test_split_by_topic_consecutive_similarity():
+
     mock_encoder = Mock()
+    mock_encoder.return_value = [[0.5, 0], [0, 0.5]]
 
-    docs = ["doc1", "doc2", "doc3", "doc4", "doc5"]
-    threshold = 0.5
-    split_method = "invalid_method"
+    messages = [
+        Message(role="User", content="What is the latest news?"),
+        Message(role="Bot", content="How is the weather today?"),
+    ]
+    conversation = Conversation(messages=messages)
+
+    cohere_encoder = CohereEncoder(
+        name="",
+        cohere_api_key="",
+        input_type="",
+    )
 
-    with pytest.raises(ValueError):
-        semantic_splitter(mock_encoder, docs, threshold, split_method)
+    conversation.configure_splitter(encoder=cohere_encoder, threshold=0.5, split_method="consecutive_similarity")
+    conversation.splitter.encoder = mock_encoder
 
+    topics, new_topics = conversation.split_by_topic()
 
-def test_split_by_topic():
+    assert len(new_topics) == 2
+    assert new_topics[0].docs == ["User: What is the latest news?"]
+    assert new_topics[1].docs == ["Bot: How is the weather today?"]
+
+def test_split_by_topic_cumulative_similarity():
+
     mock_encoder = Mock()
-    mock_encoder.return_value = [[0.5, 0], [0, 0.5]]
+    mock_encoder.side_effect = (
+        lambda x: [[0.5, 0]] if "User: What is the latest news?" in x else [[0, 0.5]]
+    )
 
     messages = [
         Message(role="User", content="What is the latest news?"),
@@ -60,9 +106,58 @@ def test_split_by_topic():
     ]
     conversation = Conversation(messages=messages)
 
-    result = conversation.split_by_topic(
-        encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity_drop"
+    cohere_encoder = CohereEncoder(
+        name="",
+        cohere_api_key="",
+        input_type="",
     )
 
-    assert result[0].docs == ["User: What is the latest news?"]
-    assert result[1].docs == ["Bot: How is the weather today?"]
+    conversation.configure_splitter(encoder=cohere_encoder, threshold=0.5, split_method="cumulative_similarity")
+    conversation.splitter.encoder = mock_encoder
+
+    topics, new_topics = conversation.split_by_topic()
+
+    # With the side_effect above, the two messages are dissimilar, so we expect two topics.
+    assert len(new_topics) == 2
+
+
+def test_split_by_topic_no_messages():
+    mock_encoder = create_autospec(BaseEncoder)
+    conversation = Conversation()
+    conversation.configure_splitter(encoder=mock_encoder, threshold=0.5, split_method="consecutive_similarity")
+
+    topics, new_topics = conversation.split_by_topic()
+
+    assert len(new_topics) == 0
+    assert len(topics) == 0
+
+def test_split_by_topic_without_configuring_splitter():
+    conversation = Conversation(messages=[Message(role="User", content="Hello")])
+
+    with pytest.raises(ValueError):
+        conversation.split_by_topic()
+
+
+def test_consecutive_similarity_splitter_single_doc():
+    mock_encoder = create_autospec(BaseEncoder)
+    # The return value is arbitrary; the splitter should raise before the encoder is used.
+    mock_encoder.return_value = np.array([[0.5, 0]])
+
+    splitter = ConsecutiveSimSplitter(encoder=mock_encoder, similarity_threshold=0.5)
+
+    docs = ["doc1"]
+    with pytest.raises(ValueError) as excinfo:
+        splitter(docs)
+    assert "at least two are required" in str(excinfo.value)
+
+def test_cumulative_similarity_splitter_single_doc():
+    mock_encoder = create_autospec(BaseEncoder)
+    # The return value is arbitrary; the splitter should raise before the encoder is used.
+    mock_encoder.return_value = np.array([[0.5, 0]])
+
+    splitter = CumulativeSimSplitter(encoder=mock_encoder, similarity_threshold=0.5)
+
+    docs = ["doc1"]
+    with pytest.raises(ValueError) as excinfo:
+        splitter(docs)
+    assert "at least two are required" in str(excinfo.value)
\ No newline at end of file
-- 
GitLab