From 0b2fa40dba83fd625bc0f1119b0172c0575b5a57 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Thu, 14 Sep 2023 14:56:59 +0000
Subject: [PATCH] Add unit test for weight decay

---
 tests/test_finetuning.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index abaa4bd8..12985468 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -1,9 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from pytest import approx
 from unittest.mock import patch
-import importlib
 
+from torch.nn import Linear
+from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 
 from llama_recipes.finetuning import main
@@ -72,4 +74,31 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     main(**kwargs)
     
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
\ No newline at end of file
+    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.get_peft_model')
+@patch('llama_recipes.finetuning.StepLR')
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
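+    # Mocks arrive in bottom-up decorator order: the last @patch (StepLR) maps to the first argument (step_lr).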
+    kwargs = {"weight_decay": 0.01}
+    
+    get_dataset.return_value = [1]
+    
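+    # Stand in a real (tiny) module for the PEFT model so an optimizer can be built over actual parameters.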
+    get_peft_model.return_value = Linear(1, 1)
+    get_peft_model.return_value.print_trainable_parameters = lambda: None
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
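+    # The optimizer is expected to be the fifth positional argument passed to train().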
+    args, _ = train.call_args
+    optimizer = args[4]
+
+    assert isinstance(optimizer, AdamW)
+    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-- 
GitLab