Unverified commit 7ebd68c2, authored by Kai Wu, committed by GitHub

Add llama 3.2 mmlu, math, gpqa evals to meta_eval harness (#801)

Parents: 2f72bcec, ab1b1450
Showing 222 additions and 45 deletions
# Calculating Meta 3.x Evaluation Metrics Using LM-Evaluation-Harness
As Llama models gain popularity, evaluating these models has become increasingly important. We have released all the evaluation details for Llama 3.x models on Hugging Face as datasets in the [3.1 evals collection](https://huggingface.co/collections/meta-llama/llama-31-evals-66a2c5a14c2093e58298ac7f) and the [3.2 evals collection](https://huggingface.co/collections/meta-llama/llama-32-evals-66f44b3d2df1c7b136d821f0). This recipe demonstrates how to calculate the Llama 3.x reported benchmark numbers using the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) library and our prompts from the 3.x evals datasets on selected tasks.
## Disclaimer
1. **This recipe is not the official implementation** of Llama evaluation. Since our internal eval repo isn't public, we want to provide this recipe as an aid for anyone who wants to use the datasets we released. It is based on public third-party libraries, and because this implementation does not mirror the internal Llama evaluation exactly, it may produce minor differences in the reported numbers.
2. **Model Compatibility**: This tutorial is specifically for Llama 3 based models, as our prompts include Llama 3 special tokens, e.g. `<|start_header_id|>user<|end_header_id|>`. It will not work with models that are not based on Llama 3.
......@@ -38,14 +37,20 @@ To access our [3.1 evals Hugging Face collection](https://huggingface.co/collect
- Log in to the Hugging Face website, open the 3.1 evals dataset pages, and agree to the terms.
- Follow the [Hugging Face authentication instructions](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication) to gain read access for your machine.
The same process can be followed to access the [3.2 evals Hugging Face collection](https://huggingface.co/collections/meta-llama/llama-32-evals-66f44b3d2df1c7b136d821f0).
It is recommended to read the dataset card to understand the meaning of each column, and to use the dataset viewer on Hugging Face to inspect the data. It is important to have a basic understanding of our dataset format and content before proceeding.
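For example, here is a minimal sketch (assuming the `huggingface_hub` and `datasets` packages are installed) of authenticating from Python and peeking at one of the evals datasets used later in this recipe:

from datasets import load_dataset
from huggingface_hub import login

# Paste a read token created on the Hugging Face website; this only works after
# the dataset terms have been accepted there.
login()

ds = load_dataset(
    "meta-llama/Llama-3.1-8B-Instruct-evals",
    name="Llama-3.1-8B-Instruct-evals__gpqa__details",
    split="latest",
)
print(ds.column_names)                  # input_question, input_final_prompts, is_correct, ...
print(ds[0]["input_final_prompts"][0])  # the exact prompt string this recipe reuses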
### Task Selection
Given the extensive number of tasks available (12 for pretrained models and 30 for instruct models), a subset of tasks is chosen:
- **Tasks for 3.1 pretrained models**: BBH and MMLU-Pro
- **Tasks for 3.1 instruct models**: Math-Hard, IFeval, GPQA, and MMLU-Pro
- **Tasks for 3.2 pretrained models**: MMLU
- **Tasks for 3.2 instruct models**: MMLU, GPQA
These tasks are common evaluations, many of which overlap with the Hugging Face [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard).
Here, we aim to get the benchmark numbers on the aforementioned tasks using the Hugging Face [leaderboard implementation](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks/leaderboard). Please follow the instructions below to make the necessary modifications to use our eval prompts and get additional eval metrics.
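For orientation, this is a rough sketch of what the final run can look like from Python once the preparation script shown later in this diff has generated the task configs under `./work_dir`; it assumes the lm-evaluation-harness 0.4.x Python API (`simple_evaluate` and `TaskManager`) and a working vLLM install, and the `model_args` values are placeholders taken from the config file discussed next.

# Hypothetical sketch, not the recipe's exact invocation.
from lm_eval import simple_evaluate
from lm_eval.tasks import TaskManager

task_manager = TaskManager(include_path="./work_dir")  # pick up the prepared meta_* task yamls
results = simple_evaluate(
    model="vllm",
    model_args="pretrained=meta-llama/Llama-3.1-8B-Instruct,dtype=auto,add_bos_token=True,seed=42",
    tasks=["meta_instruct"],
    batch_size="auto",
    task_manager=task_manager,
)
print(results["results"])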
......@@ -58,14 +63,16 @@ Here, we aim to get the benchmark numbers on the aforementioned tasks using Hugg
model_name: "meta-llama/Llama-3.1-8B-Instruct" # The name of the model to evaluate. This must be a valid Llama 3 based model name in the HuggingFace model hub.
evals_dataset: "meta-llama/Llama-3.1-8B-Instruct-evals" # The name of the evals dataset to evaluate; please make sure this eval dataset corresponds to the model loaded. This must be a valid dataset name in the Llama 3.x Evals collection.
# Must be one of the following ["meta-llama/Llama-3.1-8B-Instruct-evals","meta-llama/Llama-3.1-70B-Instruct-evals","meta-llama/Llama-3.1-405B-Instruct-evals","meta-llama/Llama-3.1-8B-evals","meta-llama/Llama-3.1-70B-evals","meta-llama/Llama-3.1-405B-evals","meta-llama/Llama-3.2-1B-evals","meta-llama/Llama-3.2-3B-evals","meta-llama/Llama-3.2-1B-Instruct-evals","meta-llama/Llama-3.2-3B-Instruct-evals"]
tasks: "meta_instruct" # Available tasks for instruct model: "meta_math_hard", "meta_gpqa", "meta_mmlu_pro_instruct", "meta_ifeval"; or just use "meta_instruct" to run all of them.
# Available tasks for pretrain model: "meta_bbh", "meta_mmlu_pro_pretrain"; or just use "meta_pretrain" to run all of them.
tasks: "meta_instruct" # Available tasks for 3.1 instruct model: "meta_math_hard", "meta_gpqa_cot", "meta_mmlu_pro_instruct", "meta_ifeval"; or just use "meta_instruct" to run all of them.
# Available tasks for 3.1 pretrain model: "meta_bbh", "meta_mmlu_pro_pretrain"; or just use "meta_pretrain" to run all of them.
# Available tasks for 3.2 instruct model: "meta_mmlu", "meta_math", "meta_gpqa"; or just use "meta_instruct" to run all of them.
# Available tasks for 3.2 pretrain model: "meta_mmlu"; or just use "meta_pretrain" to run all of them
tensor_parallel_size: 1 # The VLLM argument that specifies the tensor parallel size for the model, e.g. how many GPUs to use for a model copy.
data_parallel_size: 4 # The VLLM argument that specifies the data parallel size for the model, e.g. how many copies of the model will be used.
...
......
model_name: "meta-llama/Llama-3.1-8B-Instruct" # The name of the model to evaluate. This must be a valid Meta Llama 3 based model name in the HuggingFace model hub.
evals_dataset: "meta-llama/Llama-3.1-8B-Instruct-evals" # The name of the evals dataset to evaluate; please make sure this eval dataset corresponds to the model loaded. This must be a valid dataset name in the Llama 3.x Evals collection.
# Must be one of the following ["meta-llama/Llama-3.1-8B-Instruct-evals","meta-llama/Llama-3.1-70B-Instruct-evals","meta-llama/Llama-3.1-405B-Instruct-evals","meta-llama/Llama-3.1-8B-evals","meta-llama/Llama-3.1-70B-evals","meta-llama/Llama-3.1-405B-evals","meta-llama/Llama-3.2-1B-evals","meta-llama/Llama-3.2-3B-evals","meta-llama/Llama-3.2-1B-Instruct-evals","meta-llama/Llama-3.2-3B-Instruct-evals"]
tasks: "meta_instruct" # Available tasks for instruct model: "meta_math_hard", "meta_gpqa", "meta_mmlu_pro_instruct", "meta_ifeval"; or just use "meta_instruct" to run all of them.
# Available tasks for pretrain model: "meta_bbh", "meta_mmlu_pro_pretrain"; or just use "meta_pretrain" to run all of them.
tasks: "meta_instruct" # Available tasks for 3.1 instruct model: "meta_math_hard", "meta_gpqa_cot", "meta_mmlu_pro_instruct", "meta_ifeval"; or just use "meta_instruct" to run all of them.
# Available tasks for 3.1 pretrain model: "meta_bbh", "meta_mmlu_pro_pretrain"; or just use "meta_pretrain" to run all of them.
# Available tasks for 3.2 instruct model: "meta_mmlu", "meta_math", "meta_gpqa"; or just use "meta_instruct" to run all of them.
# Available tasks for 3.2 pretrain model: "meta_mmlu"; or just use "meta_pretrain" to run all of them
tensor_parallel_size: 1 # The VLLM argument that specifies the tensor parallel size for the model, e.g. how many GPUs to use for a model copy.
data_parallel_size: 4 # The VLLM argument that specifies the data parallel size for the model, e.g. how many copies of the model will be used.
gpu_memory_utilization: 0.9 # The VLLM argument that specifies GPU memory utilization; the rest will be reserved for KV cache.
max_model_len: 8192 # The VLLM argument that specifies the model max length; decrease this value only if a GPU memory issue is encountered. Please make sure the max_gen_toks in the yaml does not exceed this length.
batch_size: "auto" # Batch size, can be 'auto', 'auto:N', or an integer. It is strongly recommend to use 'auto' for vllm to speed up the inference
......
dataset_path: meta-llama/Llama-3.1-8B-Instruct-evals
dataset_name: Llama-3.1-8B-Instruct-evals__gpqa__details
task: meta_gpqa
output_type: generate_until
process_docs: !function utils.process_docs
test_split: latest
doc_to_text: !function utils.doc_to_text
doc_to_target: gold
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        group_select: -1
        regex_pattern: ' ([A-Z])'
      - function: "take_first"
generation_kwargs:
  until: []
  do_sample: false
  temperature: 0
  max_gen_toks: 2048
num_fewshot: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
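For intuition, the `strict-match` filter above extracts the last space-preceded capital letter from the generated chain of thought before scoring it with `exact_match`; a rough stand-alone approximation (not the harness implementation) looks like this:

import re


def strict_match(generation: str):
    # regex_pattern ' ([A-Z])' finds every " <letter>" occurrence,
    # group_select -1 keeps the last one, and take_first returns it.
    matches = re.findall(r" ([A-Z])", generation)
    return matches[-1] if matches else None


# e.g. a chain-of-thought answer ending in "The best answer is D" yields "D".
print(strict_match("...reasoning... The best answer is D"))  # -> "D"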
import random
import re

import datasets


def doc_to_text(doc: dict) -> str:
    return doc["input_final_prompts"][0]


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc: dict) -> dict:
        out_doc = {
            "problem": doc["input_question"],
            "gold": doc["input_correct_responses"][0],
        }
        return out_doc

    dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct", "input_question_hash", "input_choice_list", "output_prediction_text"])
    dataset = dataset.rename_column("is_correct", "previously_is_correct")
    return dataset.map(_process_doc)
dataset_path: meta-llama/Llama-3.1-8B-Instruct-evals
dataset_name: Llama-3.1-8B-Instruct-evals__gpqa__details
task: meta_gpqa_cot
output_type: generate_until
process_docs: !function utils.process_docs
test_split: latest
......
dataset_path: parquet
dataset_kwargs:
  data_files: ./work_dir/joined_math.parquet
task: meta_math
process_docs: !function utils.process_docs
output_type: generate_until
test_split: train
doc_to_text: !function utils.doc_to_text
process_results: !function utils.process_results
doc_to_target: answer
generation_kwargs:
  until: []
  do_sample: false
  temperature: 0
  max_gen_toks: 512
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
dataset_path: parquet
dataset_kwargs:
  data_files: ./work_dir/joined_math_hard.parquet
task: meta_math_hard
process_docs: !function utils.process_docs
output_type: generate_until
......
task: meta_mmlu
dataset_path: meta-llama/Llama-3.1-8B-evals
dataset_name: Llama-3.1-8B-evals__mmlu__details
test_split: latest
output_type: multiple_choice
process_docs: !function utils.process_docs
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
doc_to_choice: ["A", "B", "C", "D"]
# 5-shot prompts are already included in the dataset
# So no need to generate
num_fewshot: 0
metadata:
  version: 1.0
import string

import datasets


def doc_to_text(doc: dict) -> str:
    # Strip out the last two characters, which are a space and the answer letter
    # E.g., "Answer: B" -> "Answer:"
    return doc["input_final_prompts"][0][:-2]


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc: dict) -> dict:
        # input_correct_responses is in the format: "Answer: B"
        answer = doc["input_correct_responses"][0]
        # Indexes are always A: 0, B: 1, C: 2, D: 3
        answer_index = string.ascii_uppercase.index(answer[-1])
        out_doc = {
            "problem": doc["input_question"],
            # The answer is the index of the correct response (0-indexed)
            "gold": answer_index,
        }
        return out_doc

    dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct", "input_question_hash", "input_choice_list"])
    dataset = dataset.rename_column("is_correct", "previously_is_correct")
    return dataset.map(_process_doc)


def doc_to_target(doc: dict) -> int:
    return doc["gold"]
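As a toy illustration of the helpers above (the field values below are fabricated; real rows come from the gated evals dataset), the following could be appended to the same module:

if __name__ == "__main__":
    # Fabricated example document, for illustration only.
    fake_doc = {
        "input_question": "Example question?",
        "input_correct_responses": ["Answer: B"],
        "input_final_prompts": ["<5-shot prompt> ... The best answer is: B"],
    }
    # doc_to_text drops the trailing " B", so the model has to produce the letter itself.
    assert doc_to_text(fake_doc) == "<5-shot prompt> ... The best answer is:"
    # process_docs maps "Answer: B" to the 0-indexed gold choice 1, i.e. "B".
    assert string.ascii_uppercase.index(fake_doc["input_correct_responses"][0][-1]) == 1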
......@@ -11,6 +11,24 @@ import nltk
import yaml
from datasets import Dataset, load_dataset
LLAMA_3_1_INSTRUCT_EVALS = [
    "meta-llama/Llama-3.1-8B-Instruct-evals",
    "meta-llama/Llama-3.1-70B-Instruct-evals",
    "meta-llama/Llama-3.1-405B-Instruct-evals",
]
LLAMA_3_1_PRETRAIN_EVALS = [
    "meta-llama/Llama-3.1-8B-evals",
    "meta-llama/Llama-3.1-70B-evals",
    "meta-llama/Llama-3.1-405B-evals",
]
LLAMA_3_2_INSTRUCT_EVALS = [
    "meta-llama/Llama-3.2-1B-Instruct-evals",
    "meta-llama/Llama-3.2-3B-Instruct-evals",
]
LLAMA_3_2_PRETRAIN_EVALS = [
    "meta-llama/Llama-3.2-1B-evals",
    "meta-llama/Llama-3.2-3B-evals",
]
# get the ifeval from the evals dataset and join it with the original ifeval datasets
def get_ifeval_data(model_name, output_dir):
......@@ -56,8 +74,8 @@ def get_ifeval_data(model_name, output_dir):
# get the math_hard data from the evals dataset and join it with the original math_hard dataset
def get_math_hard_data(model_name, output_dir):
    print(f"preparing the math hard data using {model_name}'s evals dataset")
    if model_name not in [
        "Llama-3.1-8B-Instruct",
        "Llama-3.1-70B-Instruct",
......@@ -74,6 +92,30 @@ def get_math_data(model_name, output_dir):
split="latest",
)
math_data = load_dataset(original_dataset_name, split="test")
joined = join_meta_and_original_math_data(meta_data, math_data)
joined.to_parquet(output_dir + "/joined_math_hard.parquet")
def get_math_data(model_name, output_dir):
    print(f"preparing the math data using {model_name}'s evals dataset")
    if model_name not in [
        "Llama-3.2-1B-Instruct",
        "Llama-3.2-3B-Instruct",
    ]:
        raise ValueError(
            "Only Llama-3.2-1B-Instruct and Llama-3.2-3B-Instruct models are supported for MATH"
        )
    original_dataset_name = "lighteval/MATH"
    meta_dataset_name = f"meta-llama/{model_name}-evals"
    meta_data = load_dataset(
        meta_dataset_name,
        name=f"{model_name}-evals__math__details",
        split="latest",
    )
    math_data = load_dataset(original_dataset_name, split="test")
    joined = join_meta_and_original_math_data(meta_data, math_data)
    joined.to_parquet(output_dir + "/joined_math.parquet")
def join_meta_and_original_math_data(meta_data, math_data):
    meta_df = meta_data.to_pandas()
    math_df = math_data.to_pandas()
    math_df = math_df.rename(columns={"problem": "input_question"})
......@@ -94,9 +136,7 @@ def get_math_data(model_name, output_dir):
    joined = joined.rename_column(
        "output_prediction_text", "previous_output_prediction_text"
    )
    return joined
# get the question from the ifeval dataset
def get_question(example):
......@@ -134,18 +174,33 @@ def change_yaml(args, base_name):
"WORK_DIR", str(yaml_dir)
)
)
    # 3.2 evals dataset has a different set of tasks from 3.1
    # Update tasks in meta_pretrain.yaml
    with open(args.template_dir + "/meta_pretrain.yaml", "r") as yaml_file:
        meta_pretrain = yaml.safe_load(yaml_file)
    if args.evals_dataset in LLAMA_3_1_PRETRAIN_EVALS:
        meta_pretrain["task"] = ["meta_bbh", "meta_mmlu_pro_pretrain"]
    elif args.evals_dataset in LLAMA_3_2_PRETRAIN_EVALS:
        meta_pretrain["task"] = ["meta_mmlu"]
    with open(args.work_dir + "/meta_pretrain.yaml", "w") as yaml_file:
        yaml.dump(meta_pretrain, yaml_file)

    # Update tasks in meta_instruct.yaml
    with open(args.template_dir + "/meta_instruct.yaml", "r") as yaml_file:
        meta_instruct = yaml.safe_load(yaml_file)
    if args.evals_dataset in LLAMA_3_1_INSTRUCT_EVALS:
        meta_instruct["task"] = ["meta_ifeval", "meta_math_hard", "meta_gpqa_cot", "meta_mmlu_pro_instruct"]
    elif args.evals_dataset in LLAMA_3_2_INSTRUCT_EVALS:
        meta_instruct["task"] = ["meta_mmlu", "meta_math", "meta_gpqa"]
    with open(args.work_dir + "/meta_instruct.yaml", "w") as yaml_file:
        yaml.dump(meta_instruct, yaml_file)
# copy the files and change the yaml file to use the correct model name
def copy_and_prepare(args):
    # nltk punkt_tab package is needed
    nltk.download('punkt_tab')
    # Copy all the files, including yaml files and python files, from the template folder to the work folder
    copy_dir(args.template_dir, args.work_dir)
    # Use the template yaml to get the correct model name in work_dir yaml
    base_name = (
        args.evals_dataset.split("/")[-1].replace("-evals", "").replace("-Instruct", "")
......@@ -169,21 +224,22 @@ def prepare_datasets(args):
    # model_name is derived from the evals_dataset name
    task_list = args.tasks.split(",")
    model_name = args.evals_dataset.split("/")[-1].replace("-evals", "")
    if "meta_instruct" in task_list and args.evals_dataset in LLAMA_3_1_INSTRUCT_EVALS:
        get_ifeval_data(model_name, args.work_dir)
        get_math_hard_data(model_name, args.work_dir)
    elif "meta_instruct" in task_list and args.evals_dataset in LLAMA_3_2_INSTRUCT_EVALS:
        get_math_data(model_name, args.work_dir)
    else:
        if "meta_ifeval" in task_list:
            get_ifeval_data(model_name, args.work_dir)
        if "meta_math_hard" in task_list:
            get_math_hard_data(model_name, args.work_dir)
# copy the files from src to dst
def copy_dir(src, dst):
    try:
        # dirs_exist_ok (Python 3.8+) lets repeated runs reuse an existing work_dir
        shutil.copytree(src, dst, dirs_exist_ok=True)
    except OSError as exc:  # python >2.5
        if exc.errno in (errno.ENOTDIR, errno.EINVAL):
            shutil.copy(src, dst)
......@@ -207,16 +263,14 @@ if __name__ == "__main__":
args.__setattr__(k, v)
    if not os.path.exists(args.template_dir):
        raise ValueError("The template_dir does not exist, please check the path")
    if args.evals_dataset not in (
        LLAMA_3_1_INSTRUCT_EVALS +
        LLAMA_3_1_PRETRAIN_EVALS +
        LLAMA_3_2_INSTRUCT_EVALS +
        LLAMA_3_2_PRETRAIN_EVALS
    ):
        raise ValueError(
            "The evals dataset is not valid, please double check the name, must use the name in the Llama 3.1 or 3.2 Evals collection."
        )
    args.model_args = f"pretrained={args.model_name},tensor_parallel_size={args.tensor_parallel_size},dtype=auto,gpu_memory_utilization={args.gpu_memory_utilization},data_parallel_size={args.data_parallel_size},max_model_len={args.max_model_len},add_bos_token=True,seed=42"
    # Copy all the files from the template folder to the work folder
......