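    """Run LongBench (or LongBench-E, with --e) prediction: load the chosen model,
    build a task-specific prompt for each sample, generate with greedy decoding, and
    write the results as JSONL files under pred/<model_name>/ or pred_e/<model_name>/.
    """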
    import os
    from datasets import load_dataset
    import torch
    import json
    from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
    from tqdm import tqdm
    import numpy as np
    import random
    import argparse
    from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
    
    
    def parse_args(args=None):
        parser = argparse.ArgumentParser()
        parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "vicuna-v1.5-7b-16k"])
        parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
        return parser.parse_args(args)
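
    # Example usage (the model name must also be a key in config/model2path.json
    # and config/model2maxlen.json):
    #   python pred.py --model llama2-7b-chat-4k
    #   python pred.py --model chatglm2-6b-32k --e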
    
    
    # Customized prompt construction for chat models
    def build_chat(tokenizer, prompt, model_name):
        if "chatglm3" in model_name:
            prompt = tokenizer.build_chat_input(prompt).input_ids[0]
            prompt = tokenizer.decode(prompt[2:], skip_special_tokens=True)
        elif "chatglm" in model_name:
            prompt = tokenizer.build_prompt(prompt)
        elif "longchat" in model_name or "vicuna" in model_name:
            from fastchat.model import get_conversation_template
            conv = get_conversation_template("vicuna")
            conv.append_message(conv.roles[0], prompt)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
        elif "llama2" in model_name:
            prompt = f"[INST]{prompt}[/INST]"
        elif "xgen" in model_name:
            header = (
                "A chat between a curious human and an artificial intelligence assistant. "
                "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
            )
            prompt = header + f" ### Human: {prompt}\n###"
        elif "internlm" in model_name:
            prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
        return prompt
    
    def post_process(response, model_name):
        if "xgen" in model_name:
            response = response.strip().replace("Assistant:", "")
        elif "internlm" in model_name:
            response = response.split("<eoa>")[0]
        return response
    
    def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name):
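        """Generate a prediction for every sample in `data`.

        Each prompt is filled from `prompt_format`, truncated in the middle to at most
        `max_length` tokens, wrapped with the model's chat template where appropriate,
        and answered with greedy decoding of up to `max_gen` new tokens.
        """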
        preds = []
        for json_obj in tqdm(data):
            prompt = prompt_format.format(**json_obj)
            # truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
            if "chatglm3" in model_name:
                tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
            if len(tokenized_prompt) > max_length:
                half = int(max_length/2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without the chat prompt template on these tasks
                prompt = build_chat(tokenizer, prompt, model_name)
            input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
            context_length = input.input_ids.shape[-1]
            if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue
                output = model.generate(
                    **input,
                    max_new_tokens=max_gen,
                    num_beams=1,
                    do_sample=False,
                    temperature=1.0,
                    min_length=context_length+1,
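                    # also treat a plain newline as an end-of-sequence token so generation stops after one line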
                    eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
                )[0]
            else:
                output = model.generate(
                    **input,
                    max_new_tokens=max_gen,
                    num_beams=1,
                    do_sample=False,
                    temperature=1.0,
                )[0]
            pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
            pred = post_process(pred, model_name)
            preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]})
        return preds
    
    
    def seed_everything(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(seed)
    
    def load_model_and_tokenizer(path, model_name, device):
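        """Load the model in bfloat16 together with its tokenizer.

        The flash-attention monkey patch is applied for llama2 / longchat / vicuna
        models; chatglm, internlm and xgen are loaded with trust_remote_code=True.
        """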
        if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name:
            tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
        elif "llama2" in model_name:
            replace_llama_attn_with_flash_attn()
            tokenizer = LlamaTokenizer.from_pretrained(path)
            model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
        elif "longchat" in model_name or "vicuna" in model_name:
            from fastchat.model import load_model
            replace_llama_attn_with_flash_attn()
            model, _ = load_model(
                path,
                device='cpu',
                num_gpus=0,
                load_8bit=False,
                cpu_offloading=False,
                debug=False,
            )
            model = model.to(device)
            model = model.bfloat16()
            tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
        model = model.eval()
        return model, tokenizer
    
    
    if __name__ == '__main__':
        seed_everything(42)
        args = parse_args()
        model2path = json.load(open("config/model2path.json", "r"))
        model2maxlen = json.load(open("config/model2maxlen.json", "r"))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_name = args.model
        # define your model
        model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device)
        max_length = model2maxlen[model_name]
        if args.e:
            datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \
                "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
        else:
            datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
                        "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
                        "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
        # we design a specific prompt format and maximum generation length for each task; feel free to modify them to optimize model output
        dataset2prompt = json.load(open("config/dataset2prompt.json", "r"))
        dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r"))
        # predict on each dataset
        if not os.path.exists("pred"):
            os.makedirs("pred")
        if not os.path.exists("pred_e"):
            os.makedirs("pred_e")
        for dataset in datasets:
            if args.e:
                data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
                if not os.path.exists(f"pred_e/{model_name}"):
                    os.makedirs(f"pred_e/{model_name}")
                out_path = f"pred_e/{model_name}/{dataset}.jsonl"
            else:
                data = load_dataset('THUDM/LongBench', dataset, split='test')
                if not os.path.exists(f"pred/{model_name}"):
                    os.makedirs(f"pred/{model_name}")
                out_path = f"pred/{model_name}/{dataset}.jsonl"
            prompt_format = dataset2prompt[dataset]
            max_gen = dataset2maxlen[dataset]
            preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name)
            with open(out_path, "w", encoding="utf-8") as f:
                for pred in preds:
                    json.dump(pred, f, ensure_ascii=False)
                    f.write('\n')