diff --git a/poetry.lock b/poetry.lock index e7c7c5957164751b22db0c863e49ac67a390bd99..2498a9a6c5327741198c1a5e1ac6ed2986abe8a9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3388,4 +3388,4 @@ local = ["llama-cpp-python", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "49fd469a4cf8a0a31d2e4df9e9a2b40d1f6fefeba87ec26d76d1cb716f4f51ca" +content-hash = "10453196b0249ab854bdcd965ce8631ad9e0db33ae4423d35b250cb8b07b9898" diff --git a/pyproject.toml b/pyproject.toml index 8971df262bb3da60163460444cc38b491c55f2a5..b396601d8ca30dd802c5116f5493b0571a585f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ torch = {version = "^2.1.0", optional = true} transformers = {version = "^4.36.2", optional = true} llama-cpp-python = {version = "^0.2.28", optional = true} black = "^23.12.1" +colorama = "^0.4.6" [tool.poetry.extras] hybrid = ["pinecone-text"] diff --git a/semantic_router/schema.py b/semantic_router/schema.py index f8dcd8d1a5f827fcc8ebca32b286ad23cb250f4e..0c8376ce2f428224a0c0720ef260686fc4656208 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -64,6 +64,9 @@ class Message(BaseModel): def to_llamacpp(self): return {"role": self.role, "content": self.content} + + def __str__(self): + return f"{self.role}: {self.content}" class DocumentSplit(BaseModel): diff --git a/semantic_router/text.py b/semantic_router/text.py index 003f8c368447e9b857faa1c3a0fb9fb98565fd60..dfa0ecf7b2d99cfec6508f5331be4e79d8fd1dea 100644 --- a/semantic_router/text.py +++ b/semantic_router/text.py @@ -1,3 +1,6 @@ +from colorama import Fore +from colorama import Style + from pydantic.v1 import BaseModel, Field from typing import Union, List, Literal, Tuple from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter @@ -9,6 +12,16 @@ from semantic_router.schema import DocumentSplit # Define a type alias for the splitter to simplify the annotation SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None] +colors = [ + Fore.WHITE, + Fore.RED, + Fore.GREEN, + Fore.YELLOW, + Fore.BLUE, + Fore.MAGENTA, + Fore.CYAN +] + class Conversation(BaseModel): messages: List[Message] = Field( @@ -17,26 +30,31 @@ class Conversation(BaseModel): topics: List[Tuple[int, str]] = [] splitter: SplitterType = None + def __str__(self): + if not self.messages: + return "" + if not self.topics: + return "\n".join([str(message) for message in self.messages]) + else: + # we print each topic a different color + return_str_list = [] + current_topic_id = None + color_idx = 0 + for topic_id, message in self.topics: + if topic_id != current_topic_id: + # change color + color_idx = (color_idx + 1) % len(colors) + current_topic_id = topic_id + return_str_list.append(f"{colors[color_idx]}{message}{Style.RESET_ALL}") + return "\n".join(return_str_list) + + def add_new_messages(self, new_messages: List[Message]): self.messages.extend(new_messages) def remove_topics(self): self.topics = [] - def print_topics(self): - if not self.topics: - print("No topics to display.") - return - print("Topics:") - current_topic_id = None - for topic_id, message in self.topics: - if topic_id != current_topic_id: - if current_topic_id is not None: - print("\n", end="") - print(f"Topic {topic_id + 1}:") - current_topic_id = topic_id - print(f" - {message}") - def configure_splitter( self, encoder: BaseEncoder, diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index 218323b51d1f4ed406128f2b2d87b198377ec9fb..165fc2dd4327320f12ed3ea3a68c440b650af1dd 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -88,7 +88,7 @@ def test_split_by_topic_consecutive_similarity(): messages = [ Message(role="User", content="What is the latest news?"), - Message(role="Bot", content="How is the weather today?"), + Message(role="Assistant", content="How is the weather today?"), ] conversation = Conversation(messages=messages) @@ -107,7 +107,7 @@ def test_split_by_topic_consecutive_similarity(): assert len(new_topics) == 2 assert new_topics[0].docs == ["User: What is the latest news?"] - assert new_topics[1].docs == ["Bot: How is the weather today?"] + assert new_topics[1].docs == ["Assistant: How is the weather today?"] def test_split_by_topic_cumulative_similarity(): @@ -118,7 +118,7 @@ def test_split_by_topic_cumulative_similarity(): messages = [ Message(role="User", content="What is the latest news?"), - Message(role="Bot", content="How is the weather today?"), + Message(role="Assistant", content="How is the weather today?"), ] conversation = Conversation(messages=messages)