Skip to content
Snippets Groups Projects
Commit d2136a73 authored by Ming Ding
Browse files

fix nan loss skip failure bug

parent b43d3553
No related branches found
No related tags found
No related merge requests found
......@@ -340,7 +340,11 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
# Check for NaN or inf in the forward loss so that it cannot interfere with the loss scaler,
# and all-reduce the metrics along the way.
loss_checker = lm_loss.detach()
lm_loss_reduced = lm_loss.detach().clone()
torch.distributed.all_reduce(lm_loss_reduced.data)
lm_loss_reduced.data = lm_loss_reduced.data / args.world_size
loss_checker = lm_loss_reduced
for name in metrics:
metrics[name] = metrics[name].detach().clone()
torch.distributed.all_reduce(metrics[name].data)
......@@ -352,7 +356,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
backward_step(optimizer, model, lm_loss, args, timers)
timers('backward').stop()
# Update parameters.
skipped_iter, complete = 0, False
......@@ -383,18 +387,12 @@ def backward_step(optimizer, model, loss, args, timers):
else:
raise ValueError('Currently, we only support training with deepspeed.')
reduced_losses = loss.view(1)
# Reduce losses for logging
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
if args.deepspeed:
# DeepSpeed's backward pass already handles the all-reduce communication.
# Reset the timer to avoid breaking timer logs below.
timers('allreduce').reset()
return reduced_losses
return
def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
"""Evaluation."""
......
......@@ -109,6 +109,8 @@ class Timers:
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
if name not in self.timers:
continue
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment