Commit e5837929 authored by bys0318

fix typo

parent ff5882f0
@@ -53,7 +53,7 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
     prompt = prompt_format.format(**json_obj)
     # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
     tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
-    if "chatglm3" in model:
+    if "chatglm3" in model_name:
         tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
     if len(tokenized_prompt) > max_length:
         half = int(max_length/2)
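The context lines in this hunk implement middle truncation: when the tokenized prompt exceeds `max_length`, the script keeps the first and last `max_length/2` tokens and drops the middle, on the reasoning that crucial instructions tend to sit at the ends of the prompt (in the full script, the two kept halves are decoded back to text and concatenated). A minimal sketch of that logic on plain token-id lists, where `truncate_middle` is a hypothetical helper and not part of the repository:

```python
def truncate_middle(token_ids, max_length):
    # Keep the first and last max_length // 2 tokens, dropping the middle,
    # mirroring the half = int(max_length/2) split in get_pred.
    if len(token_ids) <= max_length:
        return token_ids
    half = max_length // 2
    return token_ids[:half] + token_ids[-half:]

ids = list(range(10))
print(truncate_middle(ids, 4))  # first 2 and last 2 ids survive: [0, 1, 8, 9]
```

Note that truncating mid-sequence can split a token pair mid-word once decoded; the original code accepts this trade-off to preserve instructions at both ends.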