Commit 0b2fa40d authored by Matthias Reso

Add unit test for weight decay

parent 91e2573a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from pytest import approx
from unittest.mock import patch
import importlib
from torch.nn import Linear
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from llama_recipes.finetuning import main
@@ -72,4 +74,31 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
    main(**kwargs)
    assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
\ No newline at end of file
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
    kwargs = {"weight_decay": 0.01}
    get_dataset.return_value = [1]
    get_peft_model.return_value = Linear(1, 1)
    get_peft_model.return_value.print_trainable_parameters = lambda: None
    main(**kwargs)
    assert train.call_count == 1
    args, kwargs = train.call_args
    optimizer = args[4]
    print(optimizer.state_dict())
    assert isinstance(optimizer, AdamW)
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
\ No newline at end of file
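
For context, here is a minimal sketch of the code path this test exercises. It assumes that llama_recipes.finetuning.main builds an AdamW optimizer from the configured weight_decay and passes it to train as the fifth positional argument (which is what the args[4] lookup above relies on); the names model, lr, and weight_decay below are illustrative placeholders, not the library's actual variables.

# Hypothetical sketch (not the library's actual implementation) of the optimizer
# construction that test_finetuning_weight_decay asserts against.
from torch.nn import Linear
from torch.optim import AdamW

model = Linear(1, 1)            # stands in for the (mocked) peft model
weight_decay = 0.01             # value forwarded from kwargs in the test

optimizer = AdamW(
    model.parameters(),
    lr=1e-4,                    # placeholder learning rate
    weight_decay=weight_decay,  # the setting the test's final assertion checks
)

# The configured value is visible in the optimizer's state dict, which is
# exactly where the test reads it back:
assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == weight_decay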