Skip to content
Snippets Groups Projects
Unverified Commit 545a4c8b authored by James Briggs's avatar James Briggs
Browse files

modified print for conversations

parent 3e650f89
Branches
Tags
No related merge requests found
......@@ -3388,4 +3388,4 @@ local = ["llama-cpp-python", "torch", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "49fd469a4cf8a0a31d2e4df9e9a2b40d1f6fefeba87ec26d76d1cb716f4f51ca"
content-hash = "10453196b0249ab854bdcd965ce8631ad9e0db33ae4423d35b250cb8b07b9898"
......@@ -28,6 +28,7 @@ torch = {version = "^2.1.0", optional = true}
transformers = {version = "^4.36.2", optional = true}
llama-cpp-python = {version = "^0.2.28", optional = true}
black = "^23.12.1"
colorama = "^0.4.6"
[tool.poetry.extras]
hybrid = ["pinecone-text"]
......
......@@ -64,6 +64,9 @@ class Message(BaseModel):
def to_llamacpp(self):
return {"role": self.role, "content": self.content}
def __str__(self):
return f"{self.role}: {self.content}"
class DocumentSplit(BaseModel):
......
from colorama import Fore
from colorama import Style
from pydantic.v1 import BaseModel, Field
from typing import Union, List, Literal, Tuple
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
......@@ -9,6 +12,16 @@ from semantic_router.schema import DocumentSplit
# Define a type alias for the splitter to simplify the annotation
SplitterType = Union[ConsecutiveSimSplitter, CumulativeSimSplitter, None]
colors = [
Fore.WHITE,
Fore.RED,
Fore.GREEN,
Fore.YELLOW,
Fore.BLUE,
Fore.MAGENTA,
Fore.CYAN
]
class Conversation(BaseModel):
messages: List[Message] = Field(
......@@ -17,26 +30,31 @@ class Conversation(BaseModel):
topics: List[Tuple[int, str]] = []
splitter: SplitterType = None
def __str__(self):
if not self.messages:
return ""
if not self.topics:
return "\n".join([str(message) for message in self.messages])
else:
# we print each topic a different color
return_str_list = []
current_topic_id = None
color_idx = 0
for topic_id, message in self.topics:
if topic_id != current_topic_id:
# change color
color_idx = (color_idx + 1) % len(colors)
current_topic_id = topic_id
return_str_list.append(f"{colors[color_idx]}{message}{Style.RESET_ALL}")
return "\n".join(return_str_list)
def add_new_messages(self, new_messages: List[Message]):
self.messages.extend(new_messages)
def remove_topics(self):
self.topics = []
def print_topics(self):
if not self.topics:
print("No topics to display.")
return
print("Topics:")
current_topic_id = None
for topic_id, message in self.topics:
if topic_id != current_topic_id:
if current_topic_id is not None:
print("\n", end="")
print(f"Topic {topic_id + 1}:")
current_topic_id = topic_id
print(f" - {message}")
def configure_splitter(
self,
encoder: BaseEncoder,
......
......@@ -88,7 +88,7 @@ def test_split_by_topic_consecutive_similarity():
messages = [
Message(role="User", content="What is the latest news?"),
Message(role="Bot", content="How is the weather today?"),
Message(role="Assistant", content="How is the weather today?"),
]
conversation = Conversation(messages=messages)
......@@ -107,7 +107,7 @@ def test_split_by_topic_consecutive_similarity():
assert len(new_topics) == 2
assert new_topics[0].docs == ["User: What is the latest news?"]
assert new_topics[1].docs == ["Bot: How is the weather today?"]
assert new_topics[1].docs == ["Assistant: How is the weather today?"]
def test_split_by_topic_cumulative_similarity():
......@@ -118,7 +118,7 @@ def test_split_by_topic_cumulative_similarity():
messages = [
Message(role="User", content="What is the latest news?"),
Message(role="Bot", content="How is the weather today?"),
Message(role="Assistant", content="How is the weather today?"),
]
conversation = Conversation(messages=messages)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment