Skip to content
Snippets Groups Projects
Commit 2bfb6a8a authored by Eustache Le Bihan's avatar Eustache Le Bihan
Browse files

handle diff chat templates

parent 618b5c23
No related branches found
No related tags found
No related merge requests found
...@@ -470,7 +470,7 @@ class LanguageModelHandlerArguments: ...@@ -470,7 +470,7 @@ class LanguageModelHandlerArguments:
} }
) )
init_chat_role: str = field( init_chat_role: str = field(
default="system", default=None,
metadata={ metadata={
"help": "Initial role for setting up the chat context. Default is 'system'." "help": "Initial role for setting up the chat context. Default is 'system'."
} }
...@@ -503,7 +503,7 @@ class LanguageModelHandler(BaseHandler): ...@@ -503,7 +503,7 @@ class LanguageModelHandler(BaseHandler):
torch_dtype="float16", torch_dtype="float16",
gen_kwargs={}, gen_kwargs={},
user_role="user", user_role="user",
init_chat_role="system", init_chat_role=None,
init_chat_prompt="You are a helpful AI assistant.", init_chat_prompt="You are a helpful AI assistant.",
): ):
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
...@@ -522,9 +522,13 @@ class LanguageModelHandler(BaseHandler): ...@@ -522,9 +522,13 @@ class LanguageModelHandler(BaseHandler):
skip_prompt=True, skip_prompt=True,
skip_special_tokens=True, skip_special_tokens=True,
) )
self.chat = [ self.chat = []
{"role": init_chat_role, "content": init_chat_prompt} if init_chat_role:
] if not init_chat_prompt:
raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.")
self.chat.append(
{"role": init_chat_role, "content": init_chat_prompt}
)
self.gen_kwargs = { self.gen_kwargs = {
"streamer": self.streamer, "streamer": self.streamer,
"return_full_text": False, "return_full_text": False,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment