From 545a4c8bcdbb7a860f8a08f5c5be3964e26f04c4 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Mon, 5 Feb 2024 12:52:16 +0100
Subject: [PATCH] modified print for conversations

---
 poetry.lock                  |  2 +-
 pyproject.toml               |  1 +
 semantic_router/schema.py    |  3 +++
 semantic_router/text.py      | 46 +++++++++++++++++++++++++-----------
 tests/unit/test_splitters.py |  6 ++---
 5 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index e7c7c595..2498a9a6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3388,4 +3388,4 @@ local = ["llama-cpp-python", "torch", "transformers"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "49fd469a4cf8a0a31d2e4df9e9a2b40d1f6fefeba87ec26d76d1cb716f4f51ca"
+content-hash = "10453196b0249ab854bdcd965ce8631ad9e0db33ae4423d35b250cb8b07b9898"
diff --git a/pyproject.toml b/pyproject.toml
index 8971df26..b396601d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ torch = {version = "^2.1.0", optional = true}
 transformers = {version = "^4.36.2", optional = true}
 llama-cpp-python = {version = "^0.2.28", optional = true}
 black = "^23.12.1"
+colorama = "^0.4.6"
 
 [tool.poetry.extras]
 hybrid = ["pinecone-text"]
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index f8dcd8d1..0c8376ce 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -64,6 +64,9 @@ class Message(BaseModel):
 
     def to_llamacpp(self):
         return {"role": self.role, "content": self.content}
+    
+    def __str__(self):
+        return f"{self.role}: {self.content}"
 
 
 class DocumentSplit(BaseModel):
diff --git a/semantic_router/text.py b/semantic_router/text.py
index 003f8c36..dfa0ecf7 100644
--- a/semantic_router/text.py
+++ b/semantic_router/text.py
@@ -1,3 +1,6 @@
+from colorama import Fore
+from colorama import Style
+
 from pydantic.v1 import BaseModel, Field
 from typing import Union, List, Literal, Tuple
 from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
@@ -9,6 +12,16 @@ from semantic_router.schema import DocumentSplit
 # Define a type alias for the splitter to simplify the annotation
 SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None]
 
+colors = [
+    Fore.WHITE,
+    Fore.RED,
+    Fore.GREEN,
+    Fore.YELLOW,
+    Fore.BLUE,
+    Fore.MAGENTA,
+    Fore.CYAN
+]
+
 
 class Conversation(BaseModel):
     messages: List[Message] = Field(
@@ -17,26 +30,31 @@ class Conversation(BaseModel):
     topics: List[Tuple[int, str]] = []
     splitter: SplitterType = None
 
+    def __str__(self):
+        if not self.messages:
+            return ""
+        if not self.topics:
+            return "\n".join([str(message) for message in self.messages])
+        else:
+            # we print each topic a different color
+            return_str_list = []
+            current_topic_id = None
+            color_idx = 0
+            for topic_id, message in self.topics:
+                if topic_id != current_topic_id:
+                    # change color
+                    color_idx = (color_idx + 1) % len(colors)
+                    current_topic_id = topic_id
+                return_str_list.append(f"{colors[color_idx]}{message}{Style.RESET_ALL}")
+            return "\n".join(return_str_list)
+
+
     def add_new_messages(self, new_messages: List[Message]):
         self.messages.extend(new_messages)
 
     def remove_topics(self):
         self.topics = []
 
-    def print_topics(self):
-        if not self.topics:
-            print("No topics to display.")
-            return
-        print("Topics:")
-        current_topic_id = None
-        for topic_id, message in self.topics:
-            if topic_id != current_topic_id:
-                if current_topic_id is not None:
-                    print("\n", end="")
-                print(f"Topic {topic_id + 1}:")
-                current_topic_id = topic_id
-            print(f" - {message}")
-
     def configure_splitter(
         self,
         encoder: BaseEncoder,
diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index 218323b5..165fc2dd 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -88,7 +88,7 @@ def test_split_by_topic_consecutive_similarity():
 
     messages = [
         Message(role="User", content="What is the latest news?"),
-        Message(role="Bot", content="How is the weather today?"),
+        Message(role="Assistant", content="How is the weather today?"),
     ]
     conversation = Conversation(messages=messages)
 
@@ -107,7 +107,7 @@ def test_split_by_topic_consecutive_similarity():
 
     assert len(new_topics) == 2
     assert new_topics[0].docs == ["User: What is the latest news?"]
-    assert new_topics[1].docs == ["Bot: How is the weather today?"]
+    assert new_topics[1].docs == ["Assistant: How is the weather today?"]
 
 
 def test_split_by_topic_cumulative_similarity():
@@ -118,7 +118,7 @@ def test_split_by_topic_cumulative_similarity():
 
     messages = [
         Message(role="User", content="What is the latest news?"),
-        Message(role="Bot", content="How is the weather today?"),
+        Message(role="Assistant", content="How is the weather today?"),
     ]
     conversation = Conversation(messages=messages)
 
-- 
GitLab