diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index abaa4bd813e87540d7608e919e8c42a6983bd9f6..129854680e83d262de3e802d9a992f104a73dac1 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -1,9 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from pytest import approx
 from unittest.mock import patch
-import importlib
 
+from torch.nn import Linear
+from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 
 from llama_recipes.finetuning import main
@@ -72,4 +74,31 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     main(**kwargs)
 
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
\ No newline at end of file
+    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.get_peft_model')
+@patch('llama_recipes.finetuning.StepLR')
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+    kwargs = {"weight_decay": 0.01}
+
+    get_dataset.return_value = [1]
+
+    get_peft_model.return_value = Linear(1,1)
+    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    optimizer = args[4]
+
+    print(optimizer.state_dict())
+
+    assert isinstance(optimizer, AdamW)
+    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
+
\ No newline at end of file
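
Note (not part of the patch): the new test pulls the optimizer out of `train.call_args` and reads `weight_decay` back through `optimizer.state_dict()`. A minimal standalone sketch of why that lookup works, independent of `llama_recipes` (the `lr=1e-4` value is arbitrary and only for illustration):

```python
# Sketch: torch.optim optimizers record their hyperparameters per parameter
# group, so weight_decay can be read back from state_dict() after construction.
from pytest import approx
from torch.nn import Linear
from torch.optim import AdamW

model = Linear(1, 1)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# weight_decay appears both in the optimizer defaults and in each param group.
assert optimizer.defaults["weight_decay"] == approx(0.01)
assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
```

The added test can be run on its own with `pytest tests/test_finetuning.py::test_finetuning_weight_decay`.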