Commit 653ab4db authored by Zhengxiao Du

Fix train_step during gradient accumulation

parent d2136a73
@@ -328,8 +328,11 @@ def train(model, optimizer, lr_scheduler,
 def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, hooks={}):
+               args, timers, hooks=None, single_step=False):
     """Single training step."""
+    if hooks is None:
+        hooks = {}
+    lm_loss_total, metrics_total, count = 0.0, {}, 0
     forward_step = hooks['forward_step']
     while True:
@@ -354,6 +357,13 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
             print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
             return lm_loss.detach(), 1, metrics
+        # Accumulate the statistics
+        lm_loss_total += lm_loss_reduced
+        for name in metrics:
+            if name not in metrics_total:
+                metrics_total[name] = 0.0
+            metrics_total[name] += metrics[name]
+        count += 1
         # Calculate gradients, reduce across processes, and clip.
         timers('backward').start()
         backward_step(optimizer, model, lm_loss, args, timers)
@@ -374,7 +384,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         else:
             raise ValueError('Currently, we only support training with deepspeed.')
         timers('optimizer').stop()
-        if complete:
+        if complete or single_step:
             break
     return lm_loss_reduced, skipped_iter, metrics
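
The change from hooks={} to hooks=None in the signature also removes a classic Python pitfall: a mutable default argument is created once, when the function is defined, so any mutation of it leaks into later calls. A minimal standalone illustration of that behavior (function names invented, not code from this repository):

# Illustration only: why a mutable default argument is risky.
def train_step_bad(data, hooks={}):
    hooks.setdefault('forward_step', lambda d: d)   # mutates the shared default dict
    return hooks

def train_step_good(data, hooks=None):
    if hooks is None:                               # a fresh dict on every call
        hooks = {}
    hooks.setdefault('forward_step', lambda d: d)
    return hooks

a = train_step_bad(1)
b = train_step_bad(2)
assert a is b          # both calls share the same dict object
c = train_step_good(1)
d = train_step_good(2)
assert c is not d      # each call gets its own dict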
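
The new lm_loss_total, metrics_total and count variables collect statistics across the micro-batches of one gradient-accumulation cycle instead of keeping only the most recent values. A hedged sketch of that general pattern, with invented names (compute_loss, micro_batches) and none of the repository's DeepSpeed specifics:

# Illustrative sketch of the gradient-accumulation pattern the new counters support;
# compute_loss and micro_batches are assumed names, not the repo's API.
def accumulate_and_step(model, optimizer, micro_batches, compute_loss):
    loss_total, metrics_total, count = 0.0, {}, 0
    optimizer.zero_grad()
    for batch in micro_batches:
        loss, metrics = compute_loss(model, batch)
        (loss / len(micro_batches)).backward()      # scale so summed grads average out
        loss_total += loss.item()
        for name, value in metrics.items():
            metrics_total[name] = metrics_total.get(name, 0.0) + value
        count += 1
    optimizer.step()
    # Report averages over the whole accumulation cycle, not just the last micro-batch.
    avg_metrics = {name: value / count for name, value in metrics_total.items()}
    return loss_total / count, avg_metrics

The hunks shown here only add the collection of the totals; how they are averaged or returned lies outside the displayed context.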
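
The new single_step flag makes the loop exit after a single forward/backward pass even when the optimizer has not yet completed an accumulation cycle. A hypothetical call, using only the parameter names visible in the diff (my_forward_step is an invented hook):

# Hypothetical usage; only the arguments shown in the diff are assumed to exist.
loss, skipped_iter, metrics = train_step(
    data_iterator, model, optimizer, lr_scheduler, args, timers,
    hooks={'forward_step': my_forward_step},   # my_forward_step is an invented name
    single_step=True,                          # break after the first micro-batch
)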