Commit aa5dee24 authored by Matthias Reso

Fix unit test to reflect batch packing

parent 8620ab8a
@@ -13,6 +13,15 @@ from torch.utils.data.sampler import BatchSampler
 from llama_recipes.finetuning import main
 from llama_recipes.data.sampler import LengthBasedBatchSampler
 
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+    }]
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -22,7 +31,7 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
@@ -46,7 +55,8 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
@@ -72,7 +82,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
@@ -89,7 +99,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
@@ -113,9 +123,12 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+        "use_peft": False,
+    }
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
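
With batch packing, the fine-tuning path concatenates tokenized samples into fixed-length chunks, so each mocked dataset item has to be a dict of token lists ("input_ids", "attention_mask", "labels") as get_fake_dataset() now returns, rather than a bare nested list like [[1]]. The snippet below is a minimal, self-contained sketch of that packing idea; pack_samples and its chunk_size default are invented for illustration and are not the llama_recipes implementation.

from itertools import chain


def pack_samples(samples, chunk_size=4):
    # Hypothetical helper for illustration only: concatenate the token lists of
    # all samples per key, then slice them into fixed-size chunks. This mirrors
    # the shape requirement the test mocks rely on, not the library's own code.
    buffer = {
        key: list(chain.from_iterable(s[key] for s in samples))
        for key in ("input_ids", "attention_mask", "labels")
    }
    return [
        {key: values[start:start + chunk_size] for key, values in buffer.items()}
        for start in range(0, len(buffer["input_ids"]), chunk_size)
    ]


if __name__ == "__main__":
    fake = [{"input_ids": [1], "attention_mask": [1], "labels": [1]}] * 8
    print(pack_samples(fake, chunk_size=4))  # two packed chunks of length 4 per key

This also explains the failure mode the commit fixes: a sample shaped [[1]] has no "input_ids" key to concatenate, so any packing-style collation over the mocked dataset would break.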