Skip to content
Snippets Groups Projects
Unverified Commit a0bd51eb authored by Yushi Bai's avatar Yushi Bai Committed by GitHub
Browse files

Update pred.py

parent b33e49e1
No related branches found
No related tags found
No related merge requests found
......@@ -63,7 +63,10 @@ def get_pred(rank, world_size, data, max_length, max_gen, prompt_format, dataset
if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
prompt = build_chat(tokenizer, prompt, model_name)
if "chatglm3" in model_name:
input = prompt.to(device)
if dataset in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
else:
input = prompt.to(device)
else:
input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
context_length = input.input_ids.shape[-1]
......@@ -175,4 +178,4 @@ if __name__ == '__main__':
p.start()
processes.append(p)
for p in processes:
p.join()
\ No newline at end of file
p.join()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment