From 3e304e54f326e609575d71cd8641011e9edaf08e Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Tue, 19 Oct 2021 09:05:12 +0000 Subject: [PATCH] fix loss_checker bug & slice --- model/mixins.py | 2 +- training/deepspeed_training.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/model/mixins.py b/model/mixins.py index 78befb8..125f8b0 100644 --- a/model/mixins.py +++ b/model/mixins.py @@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module): class PositionEmbeddingMixin(BaseMixin): def __init__(self, additional_sequence_length, hidden_size, - init_method_std=0.02, reinit_slice=(-1024, None) + init_method_std=0.02, reinit_slice=slice(-1024, None) ): super(PositionEmbeddingMixin, self).__init__() self.reinit_slice = reinit_slice diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py index 90190a6..094d49b 100644 --- a/training/deepspeed_training.py +++ b/training/deepspeed_training.py @@ -337,10 +337,10 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, # and all reduce metrics by the way loss_checker = lm_loss.detach() for name in metrics: - metrics[name] = metrics[name].detach() + metrics[name] = metrics[name].detach().clone() torch.distributed.all_reduce(metrics[name].data) metrics[name].data /= args.world_size - loss_checker += metrics[name] + loss_checker = loss_checker + metrics[name] if loss_checker.isnan().any() or loss_checker.isinf().any(): print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!') return lm_loss.detach(), 1, metrics -- GitLab