import os
from datasets import load_dataset
import torch
import json
from transformers import AutoTokenizer, AutoModel, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from tqdm import tqdm
import argparse
# DEBUG
# os.chdir(os.path.dirname(os.path.abspath(__file__)))
def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="chatglm2-6b")
    parser.add_argument("--top_k", type=int, default=3)
    parser.add_argument("--data", type=str, default="B500")
    return parser.parse_args(args)

# Customized chat-prompt construction for chat models; ChatGLM2 is one example
def build_chat(tokenizer, prompt, model_name):
if "chatglm" in model_name:
prompt = tokenizer.build_prompt(prompt)
elif "longchat" in model_name or "vicuna" in model_name:
from fastchat.model import get_conversation_template
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
elif "llama2" in model_name:
prompt = f"[INST]{prompt}[/INST]"
elif "xgen" in model_name:
header = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
)
prompt = header + f" ### Human: {prompt}\n###"
elif "internlm" in model_name:
prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
return prompt
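
# Illustrative sketch (not part of the original script): what build_chat returns for two of the
# model families handled above. The model names are placeholders that merely match the substring
# checks, and `tok` stands for the corresponding tokenizer.
#   build_chat(tok, "What is the capital of France?", "llama2-7b-chat")
#       -> "[INST]What is the capital of France?[/INST]"
#   build_chat(tok, "What is the capital of France?", "internlm-chat-7b")
#       -> "<|User|>:What is the capital of France?<eoh>\n<|Bot|>:"
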
def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name, args):
    preds = [{} for _ in range(len(data))]
    # resume from an existing prediction file if one is present
    if os.path.exists(f"{args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl"):
        with open(f"{args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl", "r", encoding="utf-8") as f:
            for index, item in enumerate(f):
                preds[index] = json.loads(item)
    for index, json_obj in enumerate(tqdm(data, desc=f"{dataset}")):
        if preds[index] != {}:
            continue
        if args.top_k != 0:
            json_obj['context'] = "".join(json_obj['retrieved'][:args.top_k])
        prompt = prompt_format.format(**json_obj)
        if "chatgpt" in model_name:
            output = openai.ChatCompletion.create(model="gpt-3.5-turbo-16k",
                                                  messages=[{"role": "user", "content": prompt}],
                                                  max_tokens=max_gen, temperature=1.0)
            pred = output['choices'][0]['message']['content']
            context_length = output['usage']['prompt_tokens']
        else:
            # truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
            if len(tokenized_prompt) > max_length:
                half = int(max_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:  # chat models are better off without the chat prompt template on these tasks
                prompt = build_chat(tokenizer, prompt, model_name)
            input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
            context_length = input.input_ids.shape[-1]  # number of prompt tokens actually fed to the model
if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue
output = model.generate(
**input,
max_new_tokens=max_gen,
num_beams=1,
do_sample=False,
temperature=1.0,
min_length=context_length+1,
eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
)[0]
else:
output = model.generate(
**input,
max_new_tokens=max_gen,
num_beams=1,
do_sample=False,
temperature=1.0,
)[0]
pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
pred = post_process(pred, model_name)
preds[index] = {"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"],
"context_length": context_length}
with open(f"{args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl", "w", encoding="utf-8") as f:
for pred in preds:
json.dump(pred, f, ensure_ascii=False)
f.write('\n')
return preds
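
# Illustrative sketch (not from the original script): each line of the
# {args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl file written above is one JSON object
# shaped like the dict built in get_pred; the concrete values here are placeholders.
#   {"pred": "generated answer ...", "answers": ["gold answer"], "all_classes": [...], "context_length": 1234}
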
def post_process(response, model_name):
if "xgen" in model_name:
response = response.strip().replace("Assistant:", "")
elif "internlm" in model_name:
response = response.split("<eoa>")[0]
return response
def load_model_and_tokenizer(model2path, model_name, device):
if "chatgpt" in model_name:
return model_name, model_name
else:
if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name:
tokenizer = AutoTokenizer.from_pretrained(model2path[model_name], trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model2path[model_name], trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
elif "llama2" in model_name:
tokenizer = LlamaTokenizer.from_pretrained(model2path[model_name])
model = LlamaForCausalLM.from_pretrained(model2path[model_name], torch_dtype=torch.bfloat16).to(device)
elif "longchat" in model_name or "vicuna" in model_name:
from fastchat.model import load_model
model, _ = load_model(
model2path[model_name],
device='cpu',
num_gpus=0,
load_8bit=False,
cpu_offloading=False,
debug=False,
)
model = model.to(device)
model = model.bfloat16()
tokenizer = AutoTokenizer.from_pretrained(model2path[model_name], trust_remote_code=True, use_fast=False)
model = model.eval()
return model, tokenizer
if __name__ == '__main__':
    args = parse_args()
    model_name = args.model
    if "chatgpt" in model_name:
        import openai
        # openai.api_base = ""
        openai.api_key = "YOUR_KEY"
    # Retrieval is suited to these datasets
    datasets = ["multifieldqa_en", "qasper", "2wikimqa", "dureader",
                "hotpotqa", "narrativeqa", "musique", "multifieldqa_zh"]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load configs
    model2path = json.load(open("../config/model2path.json", "r"))
    model2maxlen = json.load(open("../config/model2maxlen.json", "r"))
    # we design a specific prompt format and max generation length for each task; feel free to modify them to optimize model output
    dataset2prompt = json.load(open("../config/dataset2prompt.json", "r"))
    dataset2maxlen = json.load(open("../config/dataset2maxlen.json", "r"))
    # define your model
    model, tokenizer = load_model_and_tokenizer(model2path, model_name, device)
    max_length = model2maxlen[model_name]
    # predict on each dataset
    os.makedirs(f"{args.model}_pred_{args.data}_{args.top_k}", exist_ok=True)
    for dataset in datasets:
        data = load_dataset(f'../LongBench/{args.data}/LongBench.py', dataset, split='test',
                            download_mode='force_redownload')  # force loading from the local directory
        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]
        preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name, args)
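
# Illustrative usage sketch (not part of the original script). Assuming this file is saved as
# pred.py next to the ../config and ../LongBench directories referenced above, and that the chosen
# model name is a key in model2path.json and model2maxlen.json, a run with the default retrieval
# depth would look like:
#   python pred.py --model chatglm2-6b --data B500 --top_k 3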