diff --git a/pred.py b/pred.py
index 8163891d348aac9a388a23e11b8f44606d836331..e3ff006286e0323e93d02f20422a48312d09c0b2 100644
--- a/pred.py
+++ b/pred.py
@@ -18,7 +18,8 @@ def parse_args(args=None):
 # This is the customized prompt-building function for chat models
 def build_chat(tokenizer, prompt, model_name):
     if "chatglm3" in model_name:
-        prompt = tokenizer.build_chat_input(prompt)
+        prompt = tokenizer.build_chat_input(prompt).input_ids[0]
+        prompt = tokenizer.decode(prompt[2:], skip_special_tokens=True)
     elif "chatglm" in model_name:
         prompt = tokenizer.build_prompt(prompt)
     elif "longchat" in model_name or "vicuna" in model_name:
@@ -26,7 +27,7 @@ def build_chat(tokenizer, prompt, model_name):
         conv = get_conversation_template("vicuna")
         conv.append_message(conv.roles[0], prompt)
         conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
+        prompt = conv.get_prompt()
     elif "llama2" in model_name:
         prompt = f"[INST]{prompt}[/INST]"
     elif "xgen" in model_name:
@@ -52,6 +53,8 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
         prompt = prompt_format.format(**json_obj)
         # truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
         tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
+        if "chatglm3" in model_name:
+            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
         if len(tokenized_prompt) > max_length:
             half = int(max_length/2)
             prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
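For context, the last hunk interacts with pred.py's middle-truncation strategy: the prompt is tokenized without truncation, and if it exceeds `max_length`, only the head and tail halves are kept, on the assumption that the crucial instructions sit at either end. The sketch below illustrates that logic together with the chatglm3 re-tokenization from the patch; it is a minimal illustration, not the repository's code. `truncate_middle`, the `is_chatglm3` flag, the model name, and the `max_length` value are all assumptions introduced here for the example.

```python
# Minimal sketch of the middle-truncation strategy touched by the patch.
# truncate_middle is an illustrative helper, not part of pred.py.
from transformers import AutoTokenizer

def truncate_middle(prompt: str, tokenizer, max_length: int, is_chatglm3: bool = False) -> str:
    if is_chatglm3:
        # As in the added hunk: re-tokenize without special tokens, presumably so the
        # chat special tokens added later by build_chat are not counted twice.
        input_ids = tokenizer(prompt, truncation=False, return_tensors="pt",
                              add_special_tokens=False).input_ids[0]
    else:
        input_ids = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
    if len(input_ids) <= max_length:
        return prompt
    # Keep the first and last max_length // 2 tokens and drop the middle, since the
    # instructions usually appear at the start and end of the prompt.
    half = max_length // 2
    head = tokenizer.decode(input_ids[:half], skip_special_tokens=True)
    tail = tokenizer.decode(input_ids[-half:], skip_special_tokens=True)
    return head + tail

# Example usage (model name and max_length are illustrative):
# tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
# prompt = truncate_middle(prompt, tokenizer, max_length=31500, is_chatglm3=True)
```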