Skip to content
Snippets Groups Projects
Commit 7c35e218 authored by Zhengxiao Du's avatar Zhengxiao Du
Browse files

Merge branch 'main' into glm

parents d9a54331 52077b84
Branches
Tags
No related merge requests found
......@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
class PositionEmbeddingMixin(BaseMixin):
def __init__(self, additional_sequence_length, hidden_size,
init_method_std=0.02, reinit_slice=(-1024, None)
init_method_std=0.02, reinit_slice=slice(-1024, None)
):
super(PositionEmbeddingMixin, self).__init__()
self.reinit_slice = reinit_slice
......
......@@ -337,10 +337,10 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
# and all reduce metrics by the way
loss_checker = lm_loss.detach()
for name in metrics:
metrics[name] = metrics[name].detach()
metrics[name] = metrics[name].detach().clone()
torch.distributed.all_reduce(metrics[name].data)
metrics[name].data /= args.world_size
loss_checker += metrics[name]
loss_checker = loss_checker + metrics[name]
if loss_checker.isnan().any() or loss_checker.isinf().any():
print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
return lm_loss.detach(), 1, metrics
......
......@@ -39,7 +39,9 @@ def save_checkpoint(iteration, model, optimizer,
lr_scheduler, args):
"""Save a model checkpoint."""
if args.deepspeed:
save_ds_checkpoint(iteration, model, lr_scheduler, args)
if mpu.get_data_parallel_rank() == 0:
print('Saving Model...')
save_ds_checkpoint(iteration, model, lr_scheduler, args)
else:
raise ValueError("training without deepspeed is not supported.")
# Wait so everyone is done (necessary)
......@@ -74,8 +76,6 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
os.makedirs(save_dir, exist_ok=True)
# Ensure tag is a string
tag = str(tag)
# Ensure checkpoint tag is consistent across ranks
model._checkpoint_tag_validation(tag)
# Real save via deepspeed
model._create_checkpoint_file(save_dir, tag, False)
model._save_checkpoint(save_dir, tag, client_state=client_state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment