diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index 413a7c205135d4020ed2df990e30f262caf23d39..eaaea93606fb5f78b4a0818ed235652230a3d994 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -328,8 +328,11 @@ def train(model, optimizer, lr_scheduler,
 
 
 def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, hooks={}):
+               args, timers, hooks=None, single_step=False):
     """Single training step."""
+    if hooks is None:
+        hooks = {}
+    lm_loss_total, metrics_total, count = 0.0, {}, 0
     forward_step = hooks['forward_step']
     
     while True:
@@ -354,6 +357,13 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
             print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
             return lm_loss.detach(), 1, metrics
 
+        # Accumulate the statistics
+        lm_loss_total += lm_loss_reduced
+        for name in metrics:
+            if name not in metrics_total:
+                metrics_total[name] = 0.0
+            metrics_total[name] += metrics[name]
+        count += 1
         # Calculate gradients, reduce across processes, and clip.
         timers('backward').start()
         backward_step(optimizer, model, lm_loss, args, timers)
@@ -374,7 +384,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         else:
             raise ValueError('Currently, we only support training with deepspeed.')
         timers('optimizer').stop()
-        if complete:
+        if complete or single_step:
             break
     return lm_loss_reduced, skipped_iter, metrics