From aff0493f87e30ef9929373de34d532918e2590b4 Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Sat, 6 Nov 2021 17:09:07 +0800
Subject: [PATCH] Fix train_step during gradient accumulation

---
 SwissArmyTransformer/training/deepspeed_training.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index eaaea93..e748176 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -386,7 +386,9 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         timers('optimizer').stop()
         if complete or single_step:
             break
-    return lm_loss_reduced, skipped_iter, metrics
+    lm_loss_total /= count
+    metrics_total = {key: value / count for key, value in metrics_total.items()}
+    return lm_loss_total, skipped_iter, metrics_total
 
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
-- 
GitLab
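
For context, a minimal sketch in plain Python of the averaging this patch
introduces. The forward_step callback, the micro_batches iterable, and the
metric names below are hypothetical stand-ins; the real train_step in
deepspeed_training.py also runs backward/optimizer steps, timers, and
skipped-iteration handling.

    def train_step(micro_batches, model, forward_step, grad_accum_steps):
        # Accumulate loss and metrics across micro-steps instead of
        # keeping only the values of the last micro-batch.
        lm_loss_total, metrics_total, count = 0.0, {}, 0
        for batch in micro_batches:
            lm_loss, metrics = forward_step(batch, model)
            lm_loss_total += lm_loss
            for key, value in metrics.items():
                metrics_total[key] = metrics_total.get(key, 0.0) + value
            count += 1
            if count == grad_accum_steps:
                break
        # Average over the micro-steps so the returned loss/metrics
        # describe one effective optimizer step, as in the patch above.
        lm_loss_total /= count
        metrics_total = {key: value / count for key, value in metrics_total.items()}
        return lm_loss_total, metrics_total

    # Toy usage: four micro-batches whose "loss" is the batch value itself.
    loss, metrics = train_step(
        [1, 2, 3, 4], model=None,
        forward_step=lambda b, m: (float(b), {"acc": b * 0.1}),
        grad_accum_steps=4)
    # loss == 2.5 and metrics == {"acc": 0.25}, the per-step averages.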