From ff5882f00665fedcecacb8da13ead4d73079d9e4 Mon Sep 17 00:00:00 2001
From: bys0318 <bys0318@126.com>
Date: Thu, 2 Nov 2023 20:24:36 +0800
Subject: [PATCH] fix chatglm3 build_chat

---
 pred.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/pred.py b/pred.py
index 8163891..e3ff006 100644
--- a/pred.py
+++ b/pred.py
@@ -18,7 +18,8 @@ def parse_args(args=None):
 # This is the customized building prompt for chat models
 def build_chat(tokenizer, prompt, model_name):
     if "chatglm3" in model_name:
-        prompt = tokenizer.build_chat_input(prompt)
+        prompt = tokenizer.build_chat_input(prompt).input_ids[0]
+        prompt = tokenizer.decode(prompt[2:], skip_special_tokens=True)
     elif "chatglm" in model_name:
         prompt = tokenizer.build_prompt(prompt)
     elif "longchat" in model_name or "vicuna" in model_name:
@@ -26,7 +27,7 @@ def build_chat(tokenizer, prompt, model_name):
         conv = get_conversation_template("vicuna")
         conv.append_message(conv.roles[0], prompt)
         conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
+        prompt = conv.get_prompt()
     elif "llama2" in model_name:
         prompt = f"[INST]{prompt}[/INST]"
     elif "xgen" in model_name:
@@ -52,6 +53,8 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
         prompt = prompt_format.format(**json_obj)
         # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
         tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
+        if "chatglm3" in model:
+            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
         if len(tokenized_prompt) > max_length:
             half = int(max_length/2)
             prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
--
GitLab
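
For reference, the rewritten chatglm3 branch can be exercised on its own roughly as follows. This is a minimal sketch, not part of the patch: it assumes the THUDM/chatglm3-6b tokenizer loaded with trust_remote_code (which exposes build_chat_input), and the query string is illustrative. The idea is that build_chat_input wraps the query in ChatGLM3's chat template and returns token ids; dropping the first two ids (the template's prefix tokens) and decoding yields a plain-text prompt that the existing string-based truncation in get_pred can still handle.

from transformers import AutoTokenizer

# Assumption: the ChatGLM3 checkpoint/tokenizer is reachable (Hub or local path).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)

query = "Summarize the passage above in one sentence."  # illustrative query
# build_chat_input returns token ids for the chat-formatted query,
# including the template's prefix tokens at positions 0 and 1.
ids = tokenizer.build_chat_input(query).input_ids[0]
# Drop the two prefix tokens and decode back to text, as the patched build_chat does.
chat_prompt = tokenizer.decode(ids[2:], skip_special_tokens=True)
print(chat_prompt)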