diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index 413a7c205135d4020ed2df990e30f262caf23d39..eaaea93606fb5f78b4a0818ed235652230a3d994 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -328,8 +328,11 @@ def train(model, optimizer, lr_scheduler,
 
 
 def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, hooks={}):
+               args, timers, hooks=None, single_step=False):
     """Single training step."""
+    if hooks is None:
+        hooks = {}
+
     lm_loss_total, metrics_total, count = 0.0, {}, 0
     forward_step = hooks['forward_step']
     while True:
@@ -354,6 +357,13 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
             print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
             return lm_loss.detach(), 1, metrics
 
+        # Accumulate the statistics
+        lm_loss_total += lm_loss_reduced
+        for name in metrics:
+            if name not in metrics_total:
+                metrics_total[name] = 0.0
+            metrics_total[name] += metrics[name]
+        count += 1
         # Calculate gradients, reduce across processes, and clip.
         timers('backward').start()
         backward_step(optimizer, model, lm_loss, args, timers)
@@ -374,7 +384,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         else:
             raise ValueError('Currently, we only support training with deepspeed.')
         timers('optimizer').stop()
-        if complete:
+        if complete or single_step:
             break
 
     return lm_loss_reduced, skipped_iter, metrics
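
The switch from `hooks={}` to `hooks=None` with an explicit in-body check guards against Python's mutable-default-argument pitfall, where a single dict is created once at function-definition time and shared by every call that omits the argument. Below is a minimal, generic sketch of that pitfall; it is not code from this repository, and `bad`/`good` are illustrative names only.

```python
# Generic illustration (not repository code) of the mutable-default pitfall
# that the hooks={} -> hooks=None change guards against.

def bad(hooks={}):
    # The same dict object is reused by every call that omits `hooks`.
    hooks.setdefault('calls', 0)
    hooks['calls'] += 1
    return hooks

def good(hooks=None):
    if hooks is None:
        hooks = {}  # a fresh dict per call, as train_step now does
    hooks.setdefault('calls', 0)
    hooks['calls'] += 1
    return hooks

assert bad() == {'calls': 1}
assert bad() == {'calls': 2}   # state leaked from the previous call
assert good() == {'calls': 1}
assert good() == {'calls': 1}  # no shared state between calls
```

The new `single_step` flag makes the loop exit after the first optimizer call even when `complete` is still False; a hypothetical caller (not shown in this diff) could invoke `train_step(data_iterator, model, optimizer, lr_scheduler, args, timers, hooks=hooks, single_step=True)` to run exactly one micro-batch.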