Commit 7c35e218 authored by Zhengxiao Du

Merge branch 'main' into glm

parents d9a54331 52077b84
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
 class PositionEmbeddingMixin(BaseMixin):
     def __init__(self, additional_sequence_length, hidden_size,
-                 init_method_std=0.02, reinit_slice=(-1024, None)
+                 init_method_std=0.02, reinit_slice=slice(-1024, None)
                  ):
         super(PositionEmbeddingMixin, self).__init__()
         self.reinit_slice = reinit_slice
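The switch from a tuple to a slice object matters because the mixin indexes the pretrained position-embedding table with reinit_slice. A minimal sketch of the difference (the weight tensor and its shape are illustrative, not taken from the repository):

import torch

weight = torch.randn(2048, 16)        # stand-in for a pretrained position-embedding table

# slice(-1024, None) behaves like weight[-1024:], selecting the last 1024 rows.
old_rows = weight[slice(-1024, None)]
print(old_rows.shape)                 # torch.Size([1024, 16])

# A plain tuple is interpreted as multi-dimensional indexing instead:
# weight[(-1024, None)] == weight[-1024, None] picks a single row and adds an
# axis (shape [1, 16]), which is not the intended tail of the table.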
@@ -337,10 +337,10 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     # and all reduce metrics by the way
     loss_checker = lm_loss.detach()
     for name in metrics:
-        metrics[name] = metrics[name].detach()
+        metrics[name] = metrics[name].detach().clone()
         torch.distributed.all_reduce(metrics[name].data)
         metrics[name].data /= args.world_size
-        loss_checker += metrics[name]
+        loss_checker = loss_checker + metrics[name]
     if loss_checker.isnan().any() or loss_checker.isinf().any():
         print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
         return lm_loss.detach(), 1, metrics
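Both changes in this hunk guard against in-place corruption: all_reduce modifies its argument in place, and a tensor returned by .detach() still shares storage with the original, so reducing it directly would overwrite the per-rank metric; likewise, loss_checker += ... would mutate the storage shared with lm_loss. A minimal sketch of the pattern, assuming a torch.distributed process group has already been initialised:

import torch
import torch.distributed as dist

def average_metrics(metrics, world_size):
    """Average per-rank metric tensors across the data-parallel group."""
    for name in metrics:
        # detach() alone shares storage with the original tensor; clone() gives
        # the in-place all_reduce its own buffer so the local value survives.
        metrics[name] = metrics[name].detach().clone()
        dist.all_reduce(metrics[name])   # in-place sum over all ranks
        metrics[name] /= world_size      # turn the sum into a mean
    return metrics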
@@ -39,7 +39,9 @@ def save_checkpoint(iteration, model, optimizer,
                     lr_scheduler, args):
     """Save a model checkpoint."""
     if args.deepspeed:
-        save_ds_checkpoint(iteration, model, lr_scheduler, args)
+        if mpu.get_data_parallel_rank() == 0:
+            print('Saving Model...')
+        save_ds_checkpoint(iteration, model, lr_scheduler, args)
     else:
         raise ValueError("training without deepspeed is not supported.")
     # Wait so everyone is done (necessary)
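Guarding the print with the data-parallel rank keeps the log line from being repeated once per replica, while every rank still performs the actual save. A generic version of the same pattern (the helper name is illustrative, not from this repository):

import torch.distributed as dist

def print_rank_0(message):
    # Print only on the first rank of the default process group (or always,
    # when running without torch.distributed), so multi-GPU logs stay readable.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(message, flush=True)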
@@ -74,8 +76,6 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
     os.makedirs(save_dir, exist_ok=True)
     # Ensure tag is a string
     tag = str(tag)
-    # Ensure checkpoint tag is consistent across ranks
-    model._checkpoint_tag_validation(tag)
     # Real save via deepspeed
     model._create_checkpoint_file(save_dir, tag, False)
     model._save_checkpoint(save_dir, tag, client_state=client_state)