From 545a4c8bcdbb7a860f8a08f5c5be3964e26f04c4 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Mon, 5 Feb 2024 12:52:16 +0100 Subject: [PATCH] modified print for conversations --- poetry.lock | 2 +- pyproject.toml | 1 + semantic_router/schema.py | 3 +++ semantic_router/text.py | 46 +++++++++++++++++++++++++----------- tests/unit/test_splitters.py | 6 ++--- 5 files changed, 40 insertions(+), 18 deletions(-) diff --git a/poetry.lock b/poetry.lock index e7c7c595..2498a9a6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3388,4 +3388,4 @@ local = ["llama-cpp-python", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "49fd469a4cf8a0a31d2e4df9e9a2b40d1f6fefeba87ec26d76d1cb716f4f51ca" +content-hash = "10453196b0249ab854bdcd965ce8631ad9e0db33ae4423d35b250cb8b07b9898" diff --git a/pyproject.toml b/pyproject.toml index 8971df26..b396601d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ torch = {version = "^2.1.0", optional = true} transformers = {version = "^4.36.2", optional = true} llama-cpp-python = {version = "^0.2.28", optional = true} black = "^23.12.1" +colorama = "^0.4.6" [tool.poetry.extras] hybrid = ["pinecone-text"] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index f8dcd8d1..0c8376ce 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -64,6 +64,9 @@ class Message(BaseModel): def to_llamacpp(self): return {"role": self.role, "content": self.content} + + def __str__(self): + return f"{self.role}: {self.content}" class DocumentSplit(BaseModel): diff --git a/semantic_router/text.py b/semantic_router/text.py index 003f8c36..dfa0ecf7 100644 --- a/semantic_router/text.py +++ b/semantic_router/text.py @@ -1,3 +1,6 @@ +from colorama import Fore +from colorama import Style + from pydantic.v1 import BaseModel, Field from typing import Union, List, Literal, Tuple from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter @@ -9,6 +12,16 @@ from semantic_router.schema import DocumentSplit # Define a type alias for the splitter to simplify the annotation SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None] +colors = [ + Fore.WHITE, + Fore.RED, + Fore.GREEN, + Fore.YELLOW, + Fore.BLUE, + Fore.MAGENTA, + Fore.CYAN +] + class Conversation(BaseModel): messages: List[Message] = Field( @@ -17,26 +30,31 @@ class Conversation(BaseModel): topics: List[Tuple[int, str]] = [] splitter: SplitterType = None + def __str__(self): + if not self.messages: + return "" + if not self.topics: + return "\n".join([str(message) for message in self.messages]) + else: + # we print each topic a different color + return_str_list = [] + current_topic_id = None + color_idx = 0 + for topic_id, message in self.topics: + if topic_id != current_topic_id: + # change color + color_idx = (color_idx + 1) % len(colors) + current_topic_id = topic_id + return_str_list.append(f"{colors[color_idx]}{message}{Style.RESET_ALL}") + return "\n".join(return_str_list) + + def add_new_messages(self, new_messages: List[Message]): self.messages.extend(new_messages) def remove_topics(self): self.topics = [] - def print_topics(self): - if not self.topics: - print("No topics to display.") - return - print("Topics:") - current_topic_id = None - for topic_id, message in self.topics: - if topic_id != current_topic_id: - if current_topic_id is not None: - print("\n", end="") - print(f"Topic {topic_id + 1}:") - current_topic_id = topic_id - print(f" - {message}") - def configure_splitter( self, encoder: BaseEncoder, diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index 218323b5..165fc2dd 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -88,7 +88,7 @@ def test_split_by_topic_consecutive_similarity(): messages = [ Message(role="User", content="What is the latest news?"), - Message(role="Bot", content="How is the weather today?"), + Message(role="Assistant", content="How is the weather today?"), ] conversation = Conversation(messages=messages) @@ -107,7 +107,7 @@ def test_split_by_topic_consecutive_similarity(): assert len(new_topics) == 2 assert new_topics[0].docs == ["User: What is the latest news?"] - assert new_topics[1].docs == ["Bot: How is the weather today?"] + assert new_topics[1].docs == ["Assistant: How is the weather today?"] def test_split_by_topic_cumulative_similarity(): @@ -118,7 +118,7 @@ def test_split_by_topic_cumulative_similarity(): messages = [ Message(role="User", content="What is the latest news?"), - Message(role="Bot", content="How is the weather today?"), + Message(role="Assistant", content="How is the weather today?"), ] conversation = Conversation(messages=messages) -- GitLab