Commit bedb96b7 authored by Hamid Shojanazeri

fixing the full state path in checkpoint handler

parent 74bde65a
@@ -4,8 +4,6 @@
 from .checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
-    save_distributed_model_checkpoint,
-    load_distributed_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
@@ -44,7 +44,7 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-def load_model_sharded(model, rank, cfg, verbose=True):
+def load_model_sharded(model, rank, cfg):
     # torch.manual_seed(103)
     folder_name = (
         cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         print(f"Sharded state checkpoint loaded from {load_dir}")
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
     folder_name = (
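
For reference, a minimal sketch of the sharded save/load pattern these two helpers are built around: every rank writes and reads only its own shards under SHARDED_STATE_DICT. Function and directory names are hypothetical, and this assumes PyTorch >= 2.0 where torch.distributed.checkpoint provides FileSystemWriter/FileSystemReader (the project may import the same classes from a different module path).

# Sketch only, not the project's code; function names and directories are made up.
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


def save_sharded_sketch(model, save_dir):
    # Every rank participates; each writes its own shard of the model state dict.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(save_dir),
        )


def load_sharded_sketch(model, load_dir):
    # Every rank participates; each reads back only the shards it owns.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(load_dir),
        )
        model.load_state_dict(state_dict["model"])
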
@@ -142,7 +142,14 @@ def save_model_checkpoint(
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         save_name = cfg.model_name + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
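
This is the change the commit message refers to: full-state checkpoints are now written under the same dist_checkpoint_root_folder/dist_checkpoint_folder-model_name directory that the sharded code paths use, instead of cfg.checkpoint_folder. A minimal sketch of the full-state save pattern, with hypothetical helper names and assuming an FSDP-wrapped model plus a cfg object carrying the fields shown in the diff:

# Sketch only, not the project's function; cfg fields mirror those in the diff above.
from pathlib import Path

import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def save_full_state_sketch(model, rank, cfg, epoch):
    # Gather a full (unsharded) state dict, offloaded to CPU and materialized only on rank 0.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy):
        cpu_state = model.state_dict()
    if rank == 0:
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)
        save_full_path = save_dir / f"{cfg.model_name}-{epoch}.pt"
        torch.save(cpu_state, save_full_path)
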
@@ -150,12 +157,12 @@
         # save model
         torch.save(cpu_state, save_full_path)
-        if cfg.verbose:
-            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
-def load_model_checkpoint(model, rank, cfg, verbose=True):
+def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
     must be called * before * passing to FSDP"""
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
-    if cfg.verbose:
-        print(f"model checkpoint loaded to rank0 cpu")
+    print(f"model checkpoint loaded to rank0 cpu")
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
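
As the docstring above stresses, the full-state load is a rank-0, CPU-side operation that has to happen before the model is wrapped in FSDP. A minimal sketch of that ordering; the helper name and checkpoint path are hypothetical:

# Sketch only; illustrates the "load on rank 0 before FSDP wrapping" ordering.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_full_checkpoint_rank0_sketch(model, rank, checkpoint_path):
    # Only rank 0 loads the full checkpoint into the still-unwrapped module.
    if rank == 0:
        model_checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(model_checkpoint)
    return model


# usage sketch (hypothetical file name); wrap with FSDP only after loading,
# e.g. FSDP(model, sync_module_states=True) to broadcast rank 0's weights:
# model = load_full_checkpoint_rank0_sketch(model, rank, "llama-7b-3.pt")
# model = FSDP(model, sync_module_states=True)
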
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
-    if cfg.verbose:
-        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
     if rank == 0:
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         opt_save_name = (
-            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         )
         opt_save_full_path = save_dir / opt_save_name
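
For the optimizer side, FSDP.full_optim_state_dict consolidates the sharded optimizer state onto rank 0, which then writes a single file under the same folder scheme. A minimal sketch with a hypothetical helper name and a caller-supplied save path:

# Sketch only, not the project's function; save_path is supplied by the caller here.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_full_optimizer_state_sketch(model, optimizer, rank, save_path):
    # Consolidate the sharded optimizer state dict onto rank 0 (default rank0_only behavior).
    optim_state = FSDP.full_optim_state_dict(model, optimizer)
    if rank == 0:
        torch.save(optim_state, save_path)
        print(f"--> saved {save_path} to disk")
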
@@ -211,96 +225,25 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
         print(f"--> saved {opt_save_full_path} to disk")
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
-    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
-    if not opt_file_path.is_file():
+    if not optimizer_checkpoint_path.is_file():
         print(
-            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
         )
         return
     full_osd = None
     if rank == 0:
-        full_osd = torch.load(opt_file_path)
-        if cfg.verbose:
-            print(f"loaded full osd on rank 0")
+        full_osd = torch.load(optimizer_checkpoint_path)
     # called from all ranks, though only rank0 has a valid param for full_osd
     sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
-    if cfg.verbose:
-        print(f"optimizer shard loaded on rank {rank}")
+    print(f"optimizer shard loaded on rank {rank}")
-def load_distributed_model_checkpoint(model, rank, cfg):
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        print(f"loading distributed checkpoint, rank {rank}...")
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        checkdir = Path.cwd() / folder_name
-        if not checkdir.exists():
-            if rank == 0:
-                print(f"No checkpoint directory found...skipping")
-            return
-        reader = FileSystemReader(checkdir)
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            load_state_dict(state_dict, reader)
-            model.load_state_dict(state_dict)
-        print(f"--> local state loaded on rank {rank}")
-        return
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
-    # distributed checkpoint saving
-    # confirm type of checkpoint and save
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        # create writer to current path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-        writer = FileSystemWriter(
-            save_dir,
-        )
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            # write out distributed checkpoint
-            save_state_dict(state_dict, writer)
-        return
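
With the new load_optimizer_checkpoint signature, the caller resolves the checkpoint path and passes it in directly; only rank 0 reads the file, and FSDP.scatter_full_optim_state_dict then distributes each rank's shard. A minimal sketch of that pattern (hypothetical helper name; the sharded result is what optimizer.load_state_dict consumes):

# Sketch only, not the project's function; mirrors the scatter pattern shown in the diff.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_and_scatter_optimizer_sketch(model, optimizer, optimizer_checkpoint_path, rank):
    full_osd = None
    if rank == 0:
        # Only rank 0 materializes the full optimizer state dict in memory.
        full_osd = torch.load(optimizer_checkpoint_path)
    # Called from all ranks; full_osd may be None everywhere except rank 0.
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
    optimizer.load_state_dict(sharded_osd)
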
@@ -84,7 +84,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
                     batch[key] = batch[key].to('cuda:0')
             loss = model(**batch).loss
             loss = loss / gradient_accumulation_steps
@@ -137,7 +136,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" we are about to save the models *******")
@@ -148,7 +147,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.use_peft and train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                        )