Commit 3196799e authored by Eustache Le Bihan

chat_size

parent 0574f840
@@ -9,6 +9,7 @@ import os
 from pathlib import Path
 from dataclasses import dataclass, field
 from copy import copy
+from collections import deque
 import numpy as np
 import torch
@@ -490,7 +491,7 @@ class LanguageModelHandlerArguments:
         }
     )
     lm_gen_max_new_tokens: int = field(
-        default=128,
+        default=64,
         metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
     )
     lm_gen_temperature: float = field(
@@ -501,6 +502,28 @@ class LanguageModelHandlerArguments:
         default=False,
         metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
     )
+    chat_size: int = field(
+        default=3,
+        metadata={"help": "Number of messages of the messages to keep for the chat. None for no limitations."}
+    )
+
+
+class Chat:
+    def __init__(self, size):
+        self.init_chat_message = None
+        self.buffer = deque(maxlen=size)
+
+    def append(self, item):
+        self.buffer.append(item)
+
+    def init_chat(self, init_chat_message):
+        self.init_chat_message = init_chat_message
+
+    def to_list(self):
+        if self.init_chat_message:
+            return [self.init_chat_message] + list(self.buffer)
+        else:
+            return list(self.buffer)
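To make the effect of the new chat_size cap concrete, here is a small standalone sketch. The Chat class is repeated verbatim from the hunk above; the message contents and the system prompt are placeholders for illustration.

from collections import deque

# Chat as added in this commit: a deque-backed history with an optional
# pinned init (e.g. system) message that is never evicted.
class Chat:
    def __init__(self, size):
        self.init_chat_message = None
        self.buffer = deque(maxlen=size)

    def append(self, item):
        self.buffer.append(item)

    def init_chat(self, init_chat_message):
        self.init_chat_message = init_chat_message

    def to_list(self):
        if self.init_chat_message:
            return [self.init_chat_message] + list(self.buffer)
        else:
            return list(self.buffer)

chat = Chat(size=3)
chat.init_chat({"role": "system", "content": "You are a helpful assistant."})
for i in range(5):  # append more turns than the cap allows
    chat.append({"role": "user", "content": f"message {i}"})

# Only the last 3 turns survive, and the init message is still first
# (abridged): [system, 'message 2', 'message 3', 'message 4']
print(chat.to_list())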
@@ -509,6 +532,7 @@ class LanguageModelHandler(BaseHandler):
         model_name="microsoft/Phi-3-mini-4k-instruct",
         device="cuda",
         torch_dtype="float16",
+        chat_size=3,
         gen_kwargs={},
         user_role="user",
         init_chat_role=None,
@@ -532,19 +556,24 @@ class LanguageModelHandler(BaseHandler):
             skip_prompt=True,
             skip_special_tokens=True,
         )
-        self.chat = []
+        self.chat = Chat(chat_size)
         if init_chat_role:
             if not init_chat_prompt:
                 raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.")
-            self.chat.append(
+            self.chat.init_chat(
                 {"role": init_chat_role, "content": init_chat_prompt}
             )
         self.gen_kwargs = {
             "streamer": self.streamer,
             "return_full_text": False,
             **gen_kwargs
         }
         self.user_role = user_role
         self.warmup()

     def warmup(self):
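Note the design choice here: the init message supplied via init_chat_role / init_chat_prompt is stored on the Chat object itself rather than pushed into the deque, so the system prompt is never evicted; only the rolling user/assistant turns are dropped once more than chat_size of them have accumulated.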
@@ -578,7 +607,7 @@ class LanguageModelHandler(BaseHandler):
         self.chat.append(
             {"role": self.user_role, "content": prompt}
         )
-        thread = Thread(target=self.pipe, args=(self.chat,), kwargs=self.gen_kwargs)
+        thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs)
         thread.start()
         generated_text, printable_text = "", ""
         logger.debug("infering language model...")
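For reference, the pattern in this hunk — hand the flattened chat to the pipeline in a background thread and read tokens off the streamer — can be reproduced outside the handler. A minimal sketch, assuming a recent transformers release whose text-generation pipeline accepts chat-format message lists; the model choice matches the handler's default, while the prompt contents and generation settings are illustrative.

from threading import Thread

import torch
from transformers import pipeline, TextIteratorStreamer

pipe = pipeline(
    "text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    device="cuda",
    torch_dtype=torch.float16,
)
streamer = TextIteratorStreamer(
    pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
)

# This is the shape Chat.to_list() produces: the pinned init message (if any)
# followed by the rolling buffer of recent turns.
messages = [
    {"role": "system", "content": "You are a concise voice assistant."},
    {"role": "user", "content": "What is the weather like on Mars?"},
]

gen_kwargs = {
    "streamer": streamer,
    "return_full_text": False,
    "max_new_tokens": 64,  # matches the new lm_gen_max_new_tokens default
}

# Generate in a background thread and consume tokens as they arrive,
# mirroring what LanguageModelHandler does with self.chat.to_list().
thread = Thread(target=pipe, args=(messages,), kwargs=gen_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print(new_text, end="", flush=True)
thread.join()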
@@ -623,7 +652,7 @@ class ParlerTTSHandlerArguments:
         }
     )
     tts_gen_min_new_tokens: int = field(
-        default=10,
+        default=None,
         metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"}
     )
     tts_gen_max_new_tokens: int = field(
...