Commit ff5882f0 authored by bys0318

fix chatglm3 build_chat

parent 2310b3bc
@@ -18,7 +18,8 @@ def parse_args(args=None):
 # This is the customized building prompt for chat models
 def build_chat(tokenizer, prompt, model_name):
     if "chatglm3" in model_name:
-        prompt = tokenizer.build_chat_input(prompt)
+        prompt = tokenizer.build_chat_input(prompt).input_ids[0]
+        prompt = tokenizer.decode(prompt[2:], skip_special_tokens=True)
     elif "chatglm" in model_name:
         prompt = tokenizer.build_prompt(prompt)
     elif "longchat" in model_name or "vicuna" in model_name:
@@ -26,7 +27,7 @@ def build_chat(tokenizer, prompt, model_name):
         conv = get_conversation_template("vicuna")
         conv.append_message(conv.roles[0], prompt)
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
     elif "llama2" in model_name:
         prompt = f"[INST]{prompt}[/INST]"
     elif "xgen" in model_name:
@@ -52,6 +53,8 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
         prompt = prompt_format.format(**json_obj)
         # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
         tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
+        if "chatglm3" in model:
+            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
         if len(tokenized_prompt) > max_length:
             half = int(max_length/2)
             prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
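For context, a minimal sketch of what the patched chatglm3 branch of build_chat now does, assuming the THUDM/chatglm3-6b tokenizer loaded with trust_remote_code=True (its build_chat_input method wraps the query in the chat template and prepends the [gMASK] and sop special tokens); the slice [2:] drops those two leading tokens before decoding back to plain text. This sketch is not part of the commit itself.

# Sketch of the fixed chatglm3 prompt handling (assumes the THUDM/chatglm3-6b tokenizer).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)

query = "Summarize the following article: ..."
# build_chat_input returns token ids for the chat-formatted query,
# including the leading [gMASK] and sop special tokens.
input_ids = tokenizer.build_chat_input(query).input_ids[0]
# Drop the two leading special tokens and decode back to text, so callers can keep
# treating the prompt as a plain string, as build_chat does for the other models.
prompt = tokenizer.decode(input_ids[2:], skip_special_tokens=True)
print(prompt)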
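The second part of the change touches the truncation step, which keeps the head and tail of an over-long prompt rather than only one end, since instructions may sit at either side; for chatglm3 the prompt is re-tokenized with add_special_tokens=False, presumably so tokenizer-added special tokens are not counted against max_length. A minimal, model-agnostic sketch of the middle truncation (the gpt2 tokenizer and max_length value are only examples, not from the commit):

# Middle-truncation sketch: keep the first and last max_length/2 tokens so that
# instructions at both ends of the prompt survive.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
max_length = 16

prompt = "word " * 100  # any over-long prompt
tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
if len(tokenized_prompt) > max_length:
    half = int(max_length / 2)
    prompt = (tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)
              + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True))
print(prompt)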