Commit 91e2573a authored by Shijie Wu, committed by Matthias Reso

pass weight_decay into optimizer

parent c38bf5bd
@@ -226,12 +226,13 @@ def main(**kwargs):
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
+            weight_decay=train_config.weight_decay,
         )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
-            weight_decay=0.0,
+            weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
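The effect of this change is that the `weight_decay` value configured in `train_config` now actually reaches the optimizer, instead of being silently overridden by a hard-coded `0.0`. As a minimal sketch of what AdamW's decoupled weight decay does with that value (pure-Python illustration, not the repo's code; function and variable names here are made up):

```python
def adamw_decay_step(param, lr, weight_decay):
    """Apply only the decoupled weight-decay part of an AdamW step:
    the parameter shrinks by lr * weight_decay, independently of the
    gradient-based Adam update."""
    return param * (1.0 - lr * weight_decay)

p = 2.0
# With the old hard-coded weight_decay=0.0 the parameter is unchanged:
assert adamw_decay_step(p, lr=0.1, weight_decay=0.0) == 2.0
# With a nonzero configured value it decays toward zero each step:
print(adamw_decay_step(p, lr=0.1, weight_decay=0.5))
```

This is why the hard-coded `0.0` mattered: any `weight_decay` the user set in the training config had no regularizing effect at all before this commit.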