From 51738c7955a09c6d574c5d092cc896ff4f27200f Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan <eulebihan@gmail.com> Date: Tue, 13 Aug 2024 22:31:49 +0000 Subject: [PATCH] fix chat --- s2s_pipeline.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/s2s_pipeline.py b/s2s_pipeline.py index b7186fc..c784105 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -550,31 +550,36 @@ class LanguageModelHandlerArguments: metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."} ) chat_size: int = field( - default=3, - metadata={"help": "Number of messages of the messages to keep for the chat. None for no limitations."} + default=1, + metadata={"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."} ) class Chat: """ - Handles the chat using a circular buffer to avoid OOM issues. + Handles the chat using to avoid OOM issues. """ def __init__(self, size): + self.size = size self.init_chat_message = None - self.buffer = deque(maxlen=size) + # maxlen is necessary pair, since a each new step we add an prompt and assitant answer + self.buffer = [] def append(self, item): self.buffer.append(item) + if len(self.buffer) == 2 * (self.size + 1): + self.buffer.pop(0) + self.buffer.pop(0) def init_chat(self, init_chat_message): self.init_chat_message = init_chat_message def to_list(self): if self.init_chat_message: - return [self.init_chat_message] + list(self.buffer) + return [self.init_chat_message] + self.buffer else: - return list(self.buffer) + return self.buffer class LanguageModelHandler(BaseHandler): @@ -589,7 +594,7 @@ class LanguageModelHandler(BaseHandler): torch_dtype="float16", gen_kwargs={}, user_role="user", - chat_size=3, + chat_size=1, init_chat_role=None, init_chat_prompt="You are a helpful AI assistant.", ): @@ -663,7 +668,7 @@ class LanguageModelHandler(BaseHandler): self.chat.append( {"role": self.user_role, "content": prompt} ) - thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs) +x thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs) thread.start() generated_text, printable_text = "", "" -- GitLab