From aa5dee241a4473323a47f0c46379827932f85e51 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 2 Oct 2023 13:06:54 -0700
Subject: [PATCH] Fix unit test to reflect batch packing

---
 tests/test_finetuning.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index 6651a22a..e0d6b324 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -13,6 +13,15 @@ from torch.utils.data.sampler import BatchSampler
 from llama_recipes.finetuning import main
 from llama_recipes.data.sampler import LengthBasedBatchSampler
 
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+    }]
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -22,7 +31,7 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -46,7 +55,8 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [[1]]
+
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -72,7 +82,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -89,7 +99,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
@@ -113,9 +123,12 @@ def test_finetuning_weight_decay(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+        "use_peft": False,
+    }
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
-- 
GitLab