Commit 06e1bb55 authored by FaustLyu

Reorder 2 lines

parent c9b0007f
@@ -64,9 +64,9 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
             prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
         if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompt on these tasks
             prompt = build_chat(tokenizer, prompt, model_name)
-        context_length = input.input_ids.shape[-1]
         input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
+        context_length = input.input_ids.shape[-1]
         if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue
             output = model.generate(
                 **input,
......
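Why the reorder matters: before this commit, context_length = input.input_ids.shape[-1] executed one line before input was assigned from the tokenizer. Assuming these lines sit inside get_pred's loop over data (the signature in the hunk header suggests so), the first pass would hit Python's built-in input function and raise AttributeError, and any later pass would silently measure the previous prompt rather than the current one. Below is a minimal, self-contained sketch of that failure mode; fake_tokenize is a hypothetical stand-in for the real tokenizer call and is not part of the repo.

from types import SimpleNamespace

def fake_tokenize(prompt):
    # Hypothetical stand-in for tokenizer(prompt, return_tensors="pt");
    # it only mimics the .input_ids.shape attribute the bug touches.
    n_tokens = len(prompt.split())
    return SimpleNamespace(input_ids=SimpleNamespace(shape=(1, n_tokens)))

for prompt in ["three token prompt", "a noticeably longer prompt with seven tokens"]:
    try:
        context_length = input.input_ids.shape[-1]  # pre-commit order: read first
    except AttributeError:
        context_length = None  # first pass: input is still the builtin function
    input = fake_tokenize(prompt)  # ...assigned only afterwards
    actual = input.input_ids.shape[-1]
    print(f"measured={context_length} actual={actual}")
    # prints measured=None actual=3, then measured=3 actual=7 (stale)

With the post-commit order (tokenize first, then measure), context_length always describes the tensor actually passed to model.generate, presumably so that downstream code can strip the prompt tokens from the generated output.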