diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index eaaea93606fb5f78b4a0818ed235652230a3d994..e7481769970e6e95525159b5eaeb1a3a4d389b53 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -386,7 +386,9 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         timers('optimizer').stop()
         if complete or single_step:
             break
-    return lm_loss_reduced, skipped_iter, metrics
+    lm_loss_total /= count
+    metrics_total = {key: value / count for key, value in metrics_total.items()}
+    return lm_loss_total, skipped_iter, metrics_total
 
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""