From 91e2573aa848fa1570d24a22dd3a5a341d946cd6 Mon Sep 17 00:00:00 2001
From: Shijie Wu <sjwu@hey.com>
Date: Thu, 14 Sep 2023 10:24:48 -0400
Subject: [PATCH] pass weight_decay into optimizer

---
 src/llama_recipes/finetuning.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 5f7f8b24..79f971a0 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -226,12 +226,13 @@ def main(**kwargs):
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
+            weight_decay=train_config.weight_decay,
         )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
-            weight_decay=0.0,
+            weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
--
GitLab
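
For reference, a minimal sketch of how the non-anyprecision branch reads once this patch is applied; train_config is assumed to expose lr, weight_decay, and gamma fields (as in llama-recipes' training config), and model is the model being fine-tuned:

    import torch.optim as optim
    from torch.optim.lr_scheduler import StepLR

    # weight_decay now comes from the training config instead of a hard-coded 0.0
    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_config.lr,
        weight_decay=train_config.weight_decay,
    )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)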