diff --git a/s2s_pipeline.py b/s2s_pipeline.py index ec72d6e485bb4e754cd2041c016db3a293527740..1fd0c48ec814a7af51d81cda8c554a593360ad84 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -470,7 +470,7 @@ class LanguageModelHandlerArguments: } ) init_chat_role: str = field( - default="system", + default=None, metadata={ "help": "Initial role for setting up the chat context. Default is 'system'." } @@ -503,7 +503,7 @@ class LanguageModelHandler(BaseHandler): torch_dtype="float16", gen_kwargs={}, user_role="user", - init_chat_role="system", + init_chat_role=None, init_chat_prompt="You are a helpful AI assistant.", ): self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -522,9 +522,13 @@ class LanguageModelHandler(BaseHandler): skip_prompt=True, skip_special_tokens=True, ) - self.chat = [ - {"role": init_chat_role, "content": init_chat_prompt} - ] + self.chat = [] + if init_chat_role: + if not init_chat_prompt: + raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.") + self.chat.append( + {"role": init_chat_role, "content": init_chat_prompt} + ) self.gen_kwargs = { "streamer": self.streamer, "return_full_text": False,