diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py index 6651a22a2b7cb1bcc693cd63b87bc4ce783a7244..e0d6b324aee89478cb0fa67f7af4cea8f138ab76 100644 --- a/tests/test_finetuning.py +++ b/tests/test_finetuning.py @@ -13,6 +13,15 @@ from torch.utils.data.sampler import BatchSampler from llama_recipes.finetuning import main from llama_recipes.data.sampler import LengthBasedBatchSampler + +def get_fake_dataset(): + return [{ + "input_ids":[1], + "attention_mask":[1], + "labels":[1], + }] + + @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') @@ -22,7 +31,7 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): kwargs = {"run_validation": False} - get_dataset.return_value = [[1]] + get_dataset.return_value = get_fake_dataset() main(**kwargs) @@ -46,7 +55,8 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge @patch('llama_recipes.finetuning.StepLR') def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): kwargs = {"run_validation": True} - get_dataset.return_value = [[1]] + + get_dataset.return_value = get_fake_dataset() main(**kwargs) @@ -72,7 +82,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train): kwargs = {"use_peft": True} - get_dataset.return_value = [[1]] + get_dataset.return_value = get_fake_dataset() main(**kwargs) @@ -89,7 +99,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train): kwargs = {"weight_decay": 0.01} - get_dataset.return_value = [[1]] + get_dataset.return_value = get_fake_dataset() get_peft_model.return_value = Linear(1,1) get_peft_model.return_value.print_trainable_parameters=lambda:None @@ -113,9 +123,12 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train): - kwargs = {"batching_strategy": "packing"} + kwargs = { + "batching_strategy": "packing", + "use_peft": False, + } - get_dataset.return_value = [[1]] + get_dataset.return_value = get_fake_dataset() main(**kwargs)