diff --git a/llama_finetuning.py b/llama_finetuning.py
index 4e324140781dabc6fd31a8915fe6293951346b9f..ccf8c68457d13597cfe96768b6f8808b94fe243b 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -134,7 +134,7 @@ def main(**kwargs):
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
-            limit_all_gathers=False,
+            limit_all_gathers=True,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py
index e917c7f2daf234d0adcafea55f4ac284428afe9f..51193e80c063adb7812edd15acda586740fa2efc 100644
--- a/model_checkpointing/checkpoint_handler.py
+++ b/model_checkpointing/checkpoint_handler.py
@@ -212,7 +212,7 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
 
 def load_optimizer_checkpoint(model, optimizer, rank, cfg):
-    """load an fdsp optimizer full_state checkpoint using scatter method
+    """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state
     dict and scatters to other ranks
     """
diff --git a/policies/activation_checkpointing_functions.py b/policies/activation_checkpointing_functions.py
index 0a1e31f427d1bedc6e7b3eb905e6614f2441be87..379bc6bfabc03f34745d995c3257181aa8ae030f 100644
--- a/policies/activation_checkpointing_functions.py
+++ b/policies/activation_checkpointing_functions.py
@@ -26,7 +26,7 @@ def apply_fsdp_checkpointing(model):
     """apply activation checkpointing to model
     returns None as model is updated directly
     """
-    print(f"--> applying fdsp activation checkpointing...")
+    print(f"--> applying fsdp activation checkpointing...")
 
     apply_activation_checkpointing(
         model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
diff --git a/utils/train_utils.py b/utils/train_utils.py
index e41f503ea1ba7140f84b202a4ecd72c88ee0285b..7421907585ba912e4afff3bf59cc42f45ff0ff06 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -84,9 +84,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')
-                outputs = model(**batch)
-                loss = outputs.loss
+
+                        batch[key] = batch[key].to('cuda:0')
+                loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 first_key = next(iter(batch))
@@ -105,7 +105,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     optimizer.step()
                     optimizer.zero_grad()
 
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)