Skip to content
Snippets Groups Projects
Commit 3e304e54 authored by Ming Ding's avatar Ming Ding
Browse files

fix loss_checker bug & slice

parent 8584bf91
Branches
Tags
No related merge requests found
......@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
class PositionEmbeddingMixin(BaseMixin):
def __init__(self, additional_sequence_length, hidden_size,
init_method_std=0.02, reinit_slice=(-1024, None)
init_method_std=0.02, reinit_slice=slice(-1024, None)
):
super(PositionEmbeddingMixin, self).__init__()
self.reinit_slice = reinit_slice
......
......@@ -337,10 +337,10 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
# and all reduce metrics by the way
loss_checker = lm_loss.detach()
for name in metrics:
metrics[name] = metrics[name].detach()
metrics[name] = metrics[name].detach().clone()
torch.distributed.all_reduce(metrics[name].data)
metrics[name].data /= args.world_size
loss_checker += metrics[name]
loss_checker = loss_checker + metrics[name]
if loss_checker.isnan().any() or loss_checker.isinf().any():
print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
return lm_loss.detach(), 1, metrics
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment