From 3196799ec483096fbf08f30a05279f7f79ffe982 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 19:25:52 +0000
Subject: [PATCH] chat_size

---
 s2s_pipeline.py | 43 +++++++++++++++++++++++++++++++++++--------
 1 file changed, 35 insertions(+), 8 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index eebb2c6..1b744ff 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -9,6 +9,7 @@ import os
 from pathlib import Path
 from dataclasses import dataclass, field
 from copy import copy
+from collections import deque
 
 import numpy as np
 import torch
@@ -490,7 +491,7 @@ class LanguageModelHandlerArguments:
         }
     )
     lm_gen_max_new_tokens: int = field(
-        default=128,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
+        default=64,
+        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 64."}
     )
     lm_gen_temperature: float = field(
@@ -501,6 +502,30 @@
         default=False,
         metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
     )
+    chat_size: int = field(
+        default=3,
+        metadata={"help": "Number of messages to keep in the chat history. Set to None for an unbounded history."}
+    )
+
+
+class Chat:
+    """Chat history that pins the init message and keeps only the most recent messages in a fixed-size deque."""
+
+    def __init__(self, size):
+        self.init_chat_message = None
+        self.buffer = deque(maxlen=size)
+
+    def append(self, item):
+        self.buffer.append(item)
+
+    def init_chat(self, init_chat_message):
+        self.init_chat_message = init_chat_message
+
+    def to_list(self):
+        if self.init_chat_message:
+            return [self.init_chat_message] + list(self.buffer)
+        else:
+            return list(self.buffer)
 
 
 class LanguageModelHandler(BaseHandler):
@@ -509,6 +534,7 @@
         model_name="microsoft/Phi-3-mini-4k-instruct",
         device="cuda",
         torch_dtype="float16",
+        chat_size=3,
         gen_kwargs={},
         user_role="user",
         init_chat_role=None,
@@ -532,19 +558,20 @@
             skip_prompt=True,
             skip_special_tokens=True,
         )
-        self.chat = []
+        self.chat = Chat(chat_size)
         if init_chat_role:
             if not init_chat_prompt:
-                raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.")
-            self.chat.append(
+                raise ValueError("An initial prompt needs to be specified when setting init_chat_role.")
+            self.chat.init_chat(
                 {"role": init_chat_role, "content": init_chat_prompt}
             )
+
         self.gen_kwargs = {
             "streamer": self.streamer,
             "return_full_text": False,
             **gen_kwargs
         }
         self.user_role = user_role
         self.warmup()
 
     def warmup(self):
@@ -578,7 +605,7 @@
         self.chat.append(
             {"role": self.user_role, "content": prompt}
         )
-        thread = Thread(target=self.pipe, args=(self.chat,), kwargs=self.gen_kwargs)
+        thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs)
         thread.start()
         generated_text, printable_text = "", ""
         logger.debug("infering language model...")
@@ -623,7 +650,7 @@ class ParlerTTSHandlerArguments:
         }
     )
     tts_gen_min_new_tokens: int = field(
-        default=10,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"}
+        default=None,
+        metadata={"help": "Minimum number of new tokens to generate in a single completion. Default is None, which enforces no minimum."}
     )
     tts_gen_max_new_tokens: int = field(
--
GitLab
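
Note for reviewers, outside the patch itself: a quick, self-contained sketch of the windowing behavior the new Chat class introduces. The class body is copied (lightly commented) from the hunk above; Chat(size=3) mirrors the new chat_size default, and the demo messages are invented.

from collections import deque

class Chat:
    def __init__(self, size):
        self.init_chat_message = None
        self.buffer = deque(maxlen=size)  # oldest entries are evicted automatically

    def append(self, item):
        self.buffer.append(item)

    def init_chat(self, init_chat_message):
        self.init_chat_message = init_chat_message

    def to_list(self):
        if self.init_chat_message:
            return [self.init_chat_message] + list(self.buffer)
        else:
            return list(self.buffer)

chat = Chat(size=3)
chat.init_chat({"role": "system", "content": "You are a helpful assistant."})
for i in range(5):
    chat.append({"role": "user", "content": f"message {i}"})

# The pinned system message survives; only the 3 newest messages remain.
print([m["content"] for m in chat.to_list()])
# ['You are a helpful assistant.', 'message 2', 'message 3', 'message 4']

One behavior worth flagging: the deque evicts individual messages, not user/assistant pairs, so with an odd chat_size the window can open on an assistant reply whose user prompt has already been dropped; chat templates that enforce strict role alternation may reject such a history.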