diff --git a/pred.py b/pred.py index 467d38b90b16bdf872596145b02b6777bf02c2e7..8163891d348aac9a388a23e11b8f44606d836331 100644 --- a/pred.py +++ b/pred.py @@ -17,7 +17,9 @@ def parse_args(args=None): # This is the customized building prompt for chat models def build_chat(tokenizer, prompt, model_name): - if "chatglm" in model_name: + if "chatglm3" in model_name: + prompt = tokenizer.build_chat_input(prompt) + elif "chatglm" in model_name: prompt = tokenizer.build_prompt(prompt) elif "longchat" in model_name or "vicuna" in model_name: from fastchat.model import get_conversation_template