diff --git a/model_checkpointing/__init__.py b/model_checkpointing/__init__.py
index d9946f413fef3e21f14f4a7279774848cc337f2a..6c6efdf807de0d8baa54ef32ad0b880358418f1e 100644
--- a/model_checkpointing/__init__.py
+++ b/model_checkpointing/__init__.py
@@ -4,8 +4,6 @@
 from .checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
-    save_distributed_model_checkpoint,
-    load_distributed_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py
index 51193e80c063adb7812edd15acda586740fa2efc..6ab976b5cd8dd9ec939bf7ac1d04bc2427b4fee4 100644
--- a/model_checkpointing/checkpoint_handler.py
+++ b/model_checkpointing/checkpoint_handler.py
@@ -44,7 +44,7 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg, verbose=True):
+def load_model_sharded(model, rank, cfg):
     # torch.manual_seed(103)
     folder_name = (
         cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
 
     folder_name = (
@@ -142,7 +142,14 @@ def save_model_checkpoint(
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         save_name = cfg.model_name + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
@@ -150,12 +157,12 @@ def save_model_checkpoint(
 
         # save model
         torch.save(cpu_state, save_full_path)
 
-        if cfg.verbose:
-            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
 
 
 
-def load_model_checkpoint(model, rank, cfg, verbose=True):
+def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
     must be called * before * passing to FSDP"""
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
 
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
-    if cfg.verbose:
-        print(f"model checkpoint loaded to rank0 cpu")
+
+    print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-    if cfg.verbose:
-        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+
+    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
         opt_save_name = (
-            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        )
         opt_save_full_path = save_dir / opt_save_name
 
@@ -211,96 +225,25 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
         print(f"--> saved {opt_save_full_path} to disk")
 
 
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
 
-    if not opt_file_path.is_file():
+    if not optimizer_checkpoint_path.is_file():
         print(
-            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
         )
         return
 
     full_osd = None
 
     if rank == 0:
-        full_osd = torch.load(opt_file_path)
-
-        if cfg.verbose:
-            print(f"loaded full osd on rank 0")
+        full_osd = torch.load(optimizer_checkpoint_path)
 
     # called from all ranks, though only rank0 has a valid param for full_osd
     sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
 
-    if cfg.verbose:
-        print(f"optimizer shard loaded on rank {rank}")
-
-
+    print(f"optimizer shard loaded on rank {rank}")
 
-def load_distributed_model_checkpoint(model, rank, cfg):
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        print(f"loading distributed checkpoint, rank {rank}...")
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-
-        checkdir = Path.cwd() / folder_name
-
-        if not checkdir.exists():
-            if rank == 0:
-                print(f"No checkpoint directory found...skipping")
-            return
-
-
-        reader = FileSystemReader(checkdir)
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            load_state_dict(state_dict, reader)
-            model.load_state_dict(state_dict)
-
-        print(f"--> local state loaded on rank {rank}")
-
-        return
-
-
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
-    # distributed checkpoint saving
-
-    # confirm type of checkpoint and save
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        # create writer to current path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-
-        writer = FileSystemWriter(
-            save_dir,
-        )
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-
-
-        # write out distributed checkpoint
-        save_state_dict(state_dict, writer)
-
-        return
diff --git a/utils/train_utils.py b/utils/train_utils.py
index 7421907585ba912e4afff3bf59cc42f45ff0ff06..500ba2a83a04a616f04e02d7fbf2bb1f5bc5ec80 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -84,7 +84,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')
 
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -137,7 +136,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
 
                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" we are about to save the models *******")
@@ -148,7 +147,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.use_peft and train_config.save_optimizer:
 
                         model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                        )