Unverified commit 2533f68b authored by JackKuo666, committed by GitHub

Add support for chatglm3

parent 7738f1dc
@@ -11,7 +11,7 @@ from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
 def parse_args(args=None):
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "vicuna-v1.5-7b-16k"])
+    parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "chatglm3-6b-32k", "vicuna-v1.5-7b-16k"])
     parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
     return parser.parse_args(args)
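
The new "chatglm3-6b-32k" choice only registers the CLI name; the harness still has to resolve it to a checkpoint path and a context window elsewhere. A minimal sketch of loading that checkpoint, assuming the public THUDM/chatglm3-6b-32k weights and a CUDA device (the path-resolution config is not shown in this diff):

    import torch
    from transformers import AutoModel, AutoTokenizer

    # ChatGLM3 ships custom modeling code, so trust_remote_code is required.
    path = "THUDM/chatglm3-6b-32k"  # assumption: resolved from the harness config
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = AutoModel.from_pretrained(path, trust_remote_code=True,
                                      torch_dtype=torch.bfloat16).to("cuda").eval()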
@@ -60,7 +60,7 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
         if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
             prompt = build_chat(tokenizer, prompt, model_name)
         if "chatglm3" in model_name:
-            input = prompt
+            input = prompt.to(device)
         else:
             input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
         context_length = input.input_ids.shape[-1]
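
The new .to(device) call only type-checks if prompt has already stopped being a plain string by this point, i.e. if build_chat tokenized it. A sketch of the matching build_chat branch, assuming it uses ChatGLM3's tokenizer.build_chat_input helper (the chatglm2 branch is shown only for contrast and is not part of this diff):

    def build_chat(tokenizer, prompt, model_name):
        if "chatglm3" in model_name:
            # build_chat_input wraps the prompt in ChatGLM3's chat template and
            # returns an already-tokenized BatchEncoding, so the caller can move
            # it to the GPU with .to(device) instead of re-tokenizing.
            prompt = tokenizer.build_chat_input(prompt)
        elif "chatglm" in model_name:
            prompt = tokenizer.build_prompt(prompt)
        return prompt

One caveat: the datasets in the skip list above bypass build_chat, so for those tasks prompt is still a raw string and prompt.to(device) would raise AttributeError; the chatglm3 branch would then need its own tokenizer(...) call.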
@@ -163,4 +163,4 @@ if __name__ == '__main__':
     with open(out_path, "w", encoding="utf-8") as f:
         for pred in preds:
             json.dump(pred, f, ensure_ascii=False)
-            f.write('\n')
\ No newline at end of file
+            f.write('\n')