Skip to content
Snippets Groups Projects
Commit 52077b84 authored by Ming Ding's avatar Ming Ding
Browse files

fixed multiple save bug

parent 3e304e54
No related branches found
No related tags found
No related merge requests found
...@@ -35,7 +35,9 @@ def save_checkpoint(iteration, model, optimizer, ...@@ -35,7 +35,9 @@ def save_checkpoint(iteration, model, optimizer,
lr_scheduler, args): lr_scheduler, args):
"""Save a model checkpoint.""" """Save a model checkpoint."""
if args.deepspeed: if args.deepspeed:
save_ds_checkpoint(iteration, model, lr_scheduler, args) if mpu.get_data_parallel_rank() == 0:
print('Saving Model...')
save_ds_checkpoint(iteration, model, lr_scheduler, args)
else: else:
raise ValueError("training without deepspeed is not supported.") raise ValueError("training without deepspeed is not supported.")
# Wait so everyone is done (necessary) # Wait so everyone is done (necessary)
...@@ -70,8 +72,6 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save ...@@ -70,8 +72,6 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
# Ensure tag is a string # Ensure tag is a string
tag = str(tag) tag = str(tag)
# Ensure checkpoint tag is consistent across ranks
model._checkpoint_tag_validation(tag)
# Real save via deepspeed # Real save via deepspeed
model._create_checkpoint_file(save_dir, tag, False) model._create_checkpoint_file(save_dir, tag, False)
model._save_checkpoint(save_dir, tag, client_state=client_state) model._save_checkpoint(save_dir, tag, client_state=client_state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment