Unverified commit b33e49e1, authored by Yushi Bai, committed by GitHub

Merge pull request #51 from FaustLyu/main

Update retrieval/
parents a80fd111 0323a8fb
@@ -37,7 +37,7 @@ def process_jsonl_file(input_file, output_folder, chunk_size=100, filename='Unkn
         for i, chunk in enumerate(chunks):
             output_datum = {
                 'id': data['_id'] + '_' + str(i),
-                'text': chunk,
+                'text': chunk.strip(),
                 'title': ''
             }
             output_data.append(output_datum)
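For context, the hunk above sits inside a chunking loop that splits each source document into passages before they are indexed for retrieval. A minimal sketch of that pattern (the character-based split_into_chunks below is an illustrative assumption, not the repository's actual process_jsonl_file):

def split_into_chunks(text, chunk_size=100):
    # Illustrative only: fixed-size character slices; the real chunker may split by words or tokens.
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

chunks = split_into_chunks("The quick brown fox jumps over the lazy dog.", chunk_size=10)
# Fixed-size slices can start or end mid-whitespace; the .strip() added by this
# commit removes that stray leading/trailing whitespace before each passage is written out.
output_data = [{'id': 'doc0_' + str(i), 'text': chunk.strip(), 'title': ''}
               for i, chunk in enumerate(chunks)]
print(output_data)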
import os
import json
import argparse

# Which prediction folder to score: {model}_pred_{data}/
args = argparse.ArgumentParser()
args.add_argument("--data", type=str, default="C200_7")
args.add_argument("--model", type=str, default="chatglm2-6b")
args = args.parse_args()

# metrics.py is expected one directory up
import sys
sys.path.append("..")
from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

# Map each dataset to its evaluation metric
dataset2metric = {
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "nq": qa_f1_score,
    "triviaqa": qa_f1_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def scorer(dataset, predictions, answers, all_classes):
    # For each example, keep the best score over all ground-truth answers,
    # then average over the dataset and report a percentage
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score
    return round(100 * total_score / len(predictions), 2)

if __name__ == '__main__':
    scores = dict()
    all_files = os.listdir(f"{args.model}_pred_{args.data}")
    for filename in all_files:
        predictions, answers = [], []
        dataset = filename.split('.')[0]
        with open(f"{args.model}_pred_{args.data}/{filename}", "r", encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                predictions.append(data["pred"])
                answers.append(data["answers"])
                all_classes = data["all_classes"]
        score = scorer(dataset, predictions, answers, all_classes)
        scores[dataset] = score
    os.makedirs(f"result_{args.model}", exist_ok=True)
    with open(f"result_{args.model}/{args.data}.json", "w", encoding='utf-8') as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)
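The script expects one JSON-lines file per dataset under {model}_pred_{data}/, each line carrying the fields read above ("pred", "answers", "all_classes"), and writes the aggregated scores to result_{model}/{data}.json. A toy example of preparing such a file (the file name, prediction text, and directory below are hypothetical; real prediction files come from the prediction script, not from hand-written JSON):

import json, os

# Hypothetical prediction file for the passage_count task
os.makedirs("chatglm2-6b_pred_C200_7", exist_ok=True)
with open("chatglm2-6b_pred_C200_7/passage_count.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps({"pred": "There are 3 unique paragraphs.",
                        "answers": ["3"],
                        "all_classes": None}) + "\n")

# Running the evaluation script above with --model chatglm2-6b --data C200_7
# (with metrics.py importable from the parent directory) would then write
# result_chatglm2-6b/C200_7.json containing an entry for "passage_count".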
@@ -64,9 +64,9 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
         prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
         if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompt on these tasks
             prompt = build_chat(tokenizer, prompt, model_name)
-        context_length = input.input_ids.shape[-1]
         input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
+        context_length = input.input_ids.shape[-1]
         if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue
             output = model.generate(
                 **input,
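The moved line matters because build_chat wraps the prompt in the model's chat template, which changes the token count, so context_length has to be measured on the re-tokenized, wrapped input that is actually passed to generate (downstream it is typically used to separate the prompt tokens from the generated continuation). A small illustration with a generic tokenizer and a stand-in template (neither is the repository's actual build_chat):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer shows the effect
prompt = "Summarize the following meeting transcript: ..."
wrapped = "[Round 1]\n\nQuestion: " + prompt + "\n\nAnswer: "  # stand-in chat template

print(len(tokenizer(prompt).input_ids))   # token count of the bare prompt
print(len(tokenizer(wrapped).input_ids))  # larger once the chat wrapper is added
# Reading context_length before re-tokenizing (the removed line) undercounts the
# prompt; the added line measures it on the final input fed to the model.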