Code owners
Assign users and groups as approvers for specific file changes. Learn more.
eval.py 2.29 KiB
import os
import json
import argparse
args = argparse.ArgumentParser()
args.add_argument("--data", type=str, default="C200_7")
args.add_argument("--model", type=str, default="chatglm2-6b")
args = args.parse_args()
import sys
sys.path.append("..")
from metrics import (
qa_f1_score,
rouge_zh_score,
qa_f1_zh_score,
rouge_score,
classification_score,
retrieval_score,
retrieval_zh_score,
count_score,
code_sim_score,
)
dataset2metric = {
"hotpotqa": qa_f1_score,
"2wikimqa": qa_f1_score,
"musique": qa_f1_score,
"dureader": rouge_zh_score,
"narrativeqa": qa_f1_score,
"qasper": qa_f1_score,
"multifieldqa_en": qa_f1_score,
"multifieldqa_zh": qa_f1_zh_score,
"gov_report": rouge_score,
"qmsum": rouge_score,
"vcsum": rouge_zh_score,
"trec": classification_score,
"nq": qa_f1_score,
"triviaqa": qa_f1_score,
"lsht": classification_score,
"passage_retrieval_en": retrieval_score,
"passage_count": count_score,
"passage_retrieval_zh": retrieval_zh_score,
"lcc": code_sim_score,
"repobench-p": code_sim_score,
}
def scorer(dataset, predictions, answers, all_classes):
total_score = 0.
for (prediction, ground_truths) in zip(predictions, answers):
score = 0.
for ground_truth in ground_truths:
score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
total_score += score
return round(100 * total_score / len(predictions), 2)
if __name__ == '__main__':
scores = dict()
all_files = os.listdir(f"{args.model}_pred_{args.data}")
for filename in all_files:
predictions, answers = [], []
dataset = filename.split('.')[0]
with open(f"{args.model}_pred_{args.data}/{filename}", "r", encoding='utf-8') as f:
for line in f:
data = json.loads(line)
predictions.append(data["pred"])
answers.append(data["answers"])
all_classes = data["all_classes"]
score = scorer(dataset, predictions, answers, all_classes)
scores[dataset] = score
os.makedirs(f"result_{args.model}", exist_ok=True)
with open(f"result_{args.model}/{args.data}.json", "w", encoding='utf-8') as f:
json.dump(scores, f, ensure_ascii=False, indent=4)