diff --git a/model/mixins.py b/model/mixins.py
index 78befb8b27637deb7bf7c98bdcfde408eb28b483..125f8b0e2c7ccda41c4125892637046a041f7186 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
 
 class PositionEmbeddingMixin(BaseMixin):
     def __init__(self, additional_sequence_length, hidden_size,
-                 init_method_std=0.02, reinit_slice=(-1024, None)
+                 init_method_std=0.02, reinit_slice=slice(-1024, None)
         ):
         super(PositionEmbeddingMixin, self).__init__()
         self.reinit_slice = reinit_slice
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 90190a6f95eb775808cc058b9f912d3a9b6555d2..094d49b25f28f857d25572157071802307b1b641 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -337,10 +337,10 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     # and all reduce metrics by the way
     loss_checker = lm_loss.detach()
     for name in metrics:
-        metrics[name] = metrics[name].detach()
+        metrics[name] = metrics[name].detach().clone()
         torch.distributed.all_reduce(metrics[name].data)
         metrics[name].data /= args.world_size
-        loss_checker += metrics[name]
+        loss_checker = loss_checker + metrics[name]
     if loss_checker.isnan().any() or loss_checker.isinf().any():
         print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
         return lm_loss.detach(), 1, metrics
diff --git a/training/model_io.py b/training/model_io.py
index 8de74ea100da51f0212a53b64bf2ebd306cf786f..1655ed20999230346b0dd20c7c547c52a38fb246 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -39,7 +39,9 @@ def save_checkpoint(iteration, model, optimizer,
                     lr_scheduler, args):
     """Save a model checkpoint."""
     if args.deepspeed:
-        save_ds_checkpoint(iteration, model, lr_scheduler, args)
+        if mpu.get_data_parallel_rank() == 0:
+            print('Saving Model...')
+            save_ds_checkpoint(iteration, model, lr_scheduler, args)
     else:
         raise ValueError("training without deepspeed is not supported.")
     # Wait so everyone is done (necessary)
@@ -74,8 +76,6 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
     os.makedirs(save_dir, exist_ok=True)
     # Ensure tag is a string
     tag = str(tag)
-    # Ensure checkpoint tag is consistent across ranks
-    model._checkpoint_tag_validation(tag)
     # Real save via deepspeed
     model._create_checkpoint_file(save_dir, tag, False)
     model._save_checkpoint(save_dir, tag, client_state=client_state)
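
Note on model/mixins.py: a bare tuple is not a slice. `reinit_slice` is presumably
used later to index the position-embedding weight, and `weight[(-1024, None)]` is
multi-dimensional indexing (`weight[-1024, None]`), whereas `slice(-1024, None)` is
the object form of `weight[-1024:]`. A minimal sketch of the difference (toy shapes,
not repo code):

    import torch

    weight = torch.randn(2048, 16)   # stand-in for a position-embedding table

    # slice(-1024, None) behaves exactly like the [-1024:] syntax:
    assert torch.equal(weight[slice(-1024, None)], weight[-1024:])

    # The old tuple form is multi-dimensional indexing, not a slice:
    # weight[(-1024, None)] == weight[-1024, None], i.e. one row plus a new axis.
    wrong = weight[(-1024, None)]
    print(wrong.shape)   # torch.Size([1, 16]) -- silently not the last 1024 rows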
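
Note on training/deepspeed_training.py: both fixes address aliasing, since
`detach()` returns a view that shares storage with the original tensor. Without
`.clone()`, the in-place `all_reduce` would write the averaged value back into the
tensor produced by the forward pass; and because `loss_checker` starts as
`lm_loss.detach()`, the in-place `+=` would corrupt the storage of `lm_loss` itself
before `backward()` runs. A minimal sketch of the hazard (standalone toy values,
not repo code):

    import torch

    lm_loss = torch.tensor(2.0, requires_grad=True) * 1.0   # pretend forward output
    loss_checker = lm_loss.detach()       # view sharing storage with lm_loss

    loss_checker += 5.0                   # in-place add mutates lm_loss too
    print(lm_loss)                        # tensor(7., grad_fn=<MulBackward0>) -- corrupted

    lm_loss = torch.tensor(2.0, requires_grad=True) * 1.0
    loss_checker = lm_loss.detach()
    loss_checker = loss_checker + 5.0     # out-of-place: new tensor, no aliasing
    print(lm_loss)                        # tensor(2., grad_fn=<MulBackward0>) -- intact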
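
Note on training/model_io.py: only data-parallel rank 0 now enters the DeepSpeed
save path, so replicas no longer all write the same checkpoint files. That likely
also motivates dropping `_checkpoint_tag_validation`: DeepSpeed validates the tag
with a collective across ranks, and a collective issued from inside a rank-0-only
guard would block forever waiting for peers that never arrive. A sketch of the safe
pattern this adopts (hypothetical helper, not repo code):

    import torch.distributed as dist

    def save_rank0_only(write_files, save_dir, tag):
        # Hypothetical helper: file I/O happens on one rank only.
        if dist.get_rank() == 0:
            print('Saving Model...')
            write_files(save_dir, tag)
            # No collectives in here: the other ranks never reach this block,
            # so an all_reduce or barrier issued from inside would hang.
        # Collectives belong outside the guard, where every rank calls them.
        dist.barrier()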