# pred.py

import os
from datasets import load_dataset
import torch
import json
from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from tqdm import tqdm
import numpy as np
import random
import argparse
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
import torch.distributed as dist
import torch.multiprocessing as mp


def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "chatglm3-6b-32k", "vicuna-v1.5-7b-16k"])
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)
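
# Example invocation (illustrative; assumes the config/*.json files read below and the model
# weights referenced in config/model2path.json are available locally):
#   CUDA_VISIBLE_DEVICES=0 python pred.py --model llama2-7b-chat-4k
#   CUDA_VISIBLE_DEVICES=0,1 python pred.py --model chatglm3-6b-32k --e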
    

# Build the customized chat prompt expected by each chat model
def build_chat(tokenizer, prompt, model_name):
    if "chatglm3" in model_name:
        prompt = tokenizer.build_chat_input(prompt)
    elif "chatglm" in model_name:
        prompt = tokenizer.build_prompt(prompt)
    elif "longchat" in model_name or "vicuna" in model_name:
        from fastchat.model import get_conversation_template
        conv = get_conversation_template("vicuna")
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    elif "llama2" in model_name:
        prompt = f"[INST]{prompt}[/INST]"
    elif "xgen" in model_name:
        header = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        )
        prompt = header + f" ### Human: {prompt}\n###"
    elif "internlm" in model_name:
        prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
    return prompt
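
# For illustration (examples only, derived from the branches above), a raw prompt q is wrapped as:
#   llama2:   "[INST]" + q + "[/INST]"
#   internlm: "<|User|>:" + q + "<eoh>\n<|Bot|>:"
#   chatglm2: tokenizer.build_prompt(q), i.e. the tokenizer's own chat template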
    
def post_process(response, model_name):
    if "xgen" in model_name:
        response = response.strip().replace("Assistant:", "")
    elif "internlm" in model_name:
        response = response.split("<eoa>")[0]
    return response
    
    
def get_pred(rank, world_size, data, max_length, max_gen, prompt_format, dataset, device, model_name, model2path, out_path):
    device = torch.device(f'cuda:{rank}')
    model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device)
    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)
        # truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
        tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if "chatglm3" in model_name:
            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
        if len(tokenized_prompt) > max_length:
            half = int(max_length/2)
            prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
            prompt = build_chat(tokenizer, prompt, model_name)
        if "chatglm3" in model_name:
            if dataset in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
                input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
            else:
                input = prompt.to(device)
        else:
            input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = input.input_ids.shape[-1]
        if dataset == "samsum": # prevent illegal output on samsum (the model endlessly repeats "\nDialogue"), might be a prompting issue
            output = model.generate(
                **input,
                max_new_tokens=max_gen,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                min_length=context_length+1,
                eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
            )[0]
        else:
            output = model.generate(
                **input,
                max_new_tokens=max_gen,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
            )[0]
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
        pred = post_process(pred, model_name)
        with open(out_path, "a", encoding="utf-8") as f:
            json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}, f, ensure_ascii=False)
            f.write('\n')
    # only tear down the distributed process group if one was actually initialized
    if dist.is_initialized():
        dist.destroy_process_group()
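
# For reference, each line appended to out_path is one JSON record; an illustrative example
# (field values are made up; "all_classes" is null for non-classification tasks):
#   {"pred": "model output", "answers": ["gold answer"], "all_classes": null, "length": 1234}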
    

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
    

def load_model_and_tokenizer(path, model_name, device):
    if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name:
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    elif "llama2" in model_name:
        replace_llama_attn_with_flash_attn()
        tokenizer = LlamaTokenizer.from_pretrained(path)
        model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device)
    elif "longchat" in model_name or "vicuna" in model_name:
        from fastchat.model import load_model
        replace_llama_attn_with_flash_attn()
        model, _ = load_model(
            path,
            device='cpu',
            num_gpus=0,
            load_8bit=False,
            cpu_offloading=False,
            debug=False,
        )
        model = model.to(device)
        model = model.bfloat16()
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
    model = model.eval()
    return model, tokenizer
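
# Illustrative use (the Hugging Face path here is an example; actual paths come from config/model2path.json):
#   model, tokenizer = load_model_and_tokenizer("meta-llama/Llama-2-7b-chat-hf", "llama2-7b-chat-4k", torch.device("cuda:0"))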
    

if __name__ == '__main__':
    seed_everything(42)
    args = parse_args()
    world_size = torch.cuda.device_count()
    mp.set_start_method('spawn', force=True)
    
    
    model2path = json.load(open("config/model2path.json", "r"))
    model2maxlen = json.load(open("config/model2maxlen.json", "r"))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # the model to evaluate and the maximum input length it supports (from the config files above)
    model_name = args.model
    max_length = model2maxlen[model_name]
    if args.e:
        datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \
            "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
    else:
        datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
                    "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
                    "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
    
    # we design a specific prompt format and max generation length for each task; feel free to modify them to optimize model output
    dataset2prompt = json.load(open("config/dataset2prompt.json", "r"))
    dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r"))
    # predict on each dataset
    if not os.path.exists("pred"):
        os.makedirs("pred")
    if not os.path.exists("pred_e"):
        os.makedirs("pred_e")
    for dataset in datasets:
        if args.e:
            data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
            if not os.path.exists(f"pred_e/{model_name}"):
                os.makedirs(f"pred_e/{model_name}")
            out_path = f"pred_e/{model_name}/{dataset}.jsonl"
        else:
            data = load_dataset('THUDM/LongBench', dataset, split='test')
            if not os.path.exists(f"pred/{model_name}"):
                os.makedirs(f"pred/{model_name}")
            out_path = f"pred/{model_name}/{dataset}.jsonl"
    
        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]
        # split the samples into one strided shard per GPU and run get_pred in a separate process for each shard
        data_all = [data_sample for data_sample in data]
        data_subsets = [data_all[i::world_size] for i in range(world_size)]
        processes = []
        for rank in range(world_size):
            p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_length, \
                        max_gen, prompt_format, dataset, device, model_name, model2path, out_path))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
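
# Note (a consequence of the logic above, not additional behavior): every spawned worker loads its
# own copy of the model on its GPU and appends its shard's predictions to the same {dataset}.jsonl
# file, so the order of records in the output file is not deterministic across runs.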