From 91e2573aa848fa1570d24a22dd3a5a341d946cd6 Mon Sep 17 00:00:00 2001
From: Shijie Wu <sjwu@hey.com>
Date: Thu, 14 Sep 2023 10:24:48 -0400
Subject: [PATCH] pass weight_decay into optimizer

---
 src/llama_recipes/finetuning.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 5f7f8b24..79f971a0 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -226,12 +226,13 @@ def main(**kwargs):
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
+            weight_decay=train_config.weight_decay,
         )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
-            weight_decay=0.0,
+            weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
-- 
GitLab