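"""Evaluate LongBench predictions.

For each <dataset>.jsonl file under pred/<model>/ (or pred_e/<model>/ when
--e is set), score every prediction with the dataset's metric from
metrics.py and write the aggregated results to result.json in the same
directory.

Usage:
    python eval.py --model <model_name>
    python eval.py --model <model_name> --e   # LongBench-E length buckets
"""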
import os
import json
import argparse
import numpy as np

from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

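# Map each dataset to its scoring function; "_zh" datasets use Chinese-specific metrics.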
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None, help="Name of the model whose predictions to evaluate (directory under pred/ or pred_e/)")
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)

def scorer_e(dataset, predictions, answers, lengths, all_classes):
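    """Score predictions per input-length bucket (0-4k, 4-8k, 8k+) for LongBench-E."""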
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.
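        # Few-shot tasks: only the first line of the output is treated as the answer.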
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
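    # Average each bucket and report it as a percentage (an empty bucket yields nan).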
    for key in scores:
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores

def scorer(dataset, predictions, answers, all_classes):
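    """Return the average best-match score over all predictions, as a percentage."""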
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
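        # As in scorer_e, only the first output line counts for few-shot tasks.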
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score
    return round(100 * total_score / len(predictions), 2)

if __name__ == '__main__':
    args = parse_args()
    scores = dict()
    if args.e:
        path = f"pred_e/{args.model}/"
    else:
        path = f"pred/{args.model}/"
    all_files = os.listdir(path)
    print("Evaluating on:", all_files)
    for filename in all_files:
        if not filename.endswith("jsonl"):
            continue
        predictions, answers, lengths = [], [], []
        dataset = filename.split('.')[0]
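        # Each line carries "pred", "answers", "all_classes" (the label set for
        # classification tasks; assumed constant within a file), and optionally "length".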
        with open(f"{path}{filename}", "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                predictions.append(data["pred"])
                answers.append(data["answers"])
                all_classes = data["all_classes"]
                if "length" in data:
                    lengths.append(data["length"])
        if args.e:
            score = scorer_e(dataset, predictions, answers, lengths, all_classes)
        else:
            score = scorer(dataset, predictions, answers, all_classes)
        scores[dataset] = score
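    # Write the aggregated per-dataset scores next to the prediction files.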
    if args.e:
        out_path = f"pred_e/{args.model}/result.json"
    else:
        out_path = f"pred/{args.model}/result.json"
    with open(out_path, "w") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)