From aff0493f87e30ef9929373de34d532918e2590b4 Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Sat, 6 Nov 2021 17:09:07 +0800
Subject: [PATCH] Fix train_step during gradient accumulation

---
 SwissArmyTransformer/training/deepspeed_training.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index eaaea93..e748176 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -386,7 +386,9 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         timers('optimizer').stop()
         if complete or single_step:
             break
-    return lm_loss_reduced, skipped_iter, metrics
+    lm_loss_total /= count
+    metrics_total = {key: value / count for key, value in metrics_total.items()}
+    return lm_loss_total, skipped_iter, metrics_total

 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
--
GitLab
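
Note on the fix: with gradient accumulation, train_step runs several micro-batch
iterations per optimizer step and sums their losses and metrics into lm_loss_total
and metrics_total; the patch divides both by the number of accumulated iterations
before returning, so callers receive the mean per micro-batch rather than a sum
that grows with the accumulation factor. Below is a minimal, self-contained sketch
of that pattern. The names (batches, the mse_loss stand-in metric) and the loop
structure are simplified assumptions for illustration, not the actual
SwissArmyTransformer implementation.

    import torch

    # Sketch of a train_step with gradient accumulation (assumed
    # simplification, not the library's exact code).
    def train_step(batches, model, optimizer):
        lm_loss_total, metrics_total, count = 0.0, {}, 0
        optimizer.zero_grad()
        for x, y in batches:  # micro-batches of one logical step
            loss = torch.nn.functional.mse_loss(model(x), y)
            # Scale so the summed gradients match one large batch.
            (loss / len(batches)).backward()
            lm_loss_total += loss.item()
            metrics = {'mse': loss.item()}  # stand-in for real reduced metrics
            for key, value in metrics.items():
                metrics_total[key] = metrics_total.get(key, 0.0) + value
            count += 1
        optimizer.step()
        # The patched behaviour: report the mean over accumulated
        # micro-batches instead of the running sum.
        lm_loss_total /= count
        metrics_total = {key: value / count for key, value in metrics_total.items()}
        return lm_loss_total, metrics_total

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(4)]
    print(train_step(batches, model, optimizer))

Without the division, the reported loss would jump whenever the accumulation
factor changes, which makes training curves incomparable across configurations;
averaging keeps the reported value on the same scale as a single-step run.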