diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py index ee10262d7554ab900ad439d2d4f9d849ab1f3ffc..c89aff1f92809411fa84b19d05e4ffd241ed5634 100644 --- a/src/llama_recipes/configs/fsdp.py +++ b/src/llama_recipes/configs/fsdp.py @@ -13,8 +13,7 @@ class fsdp_config: sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. fsdp_activation_checkpointing: bool=True + fsdp_cpu_offload: bool=False pure_bf16: bool = False optimizer: str= "AdamW" - - \ No newline at end of file diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py index a475c1cad42994ee8517db27548156bf1e35703f..5f7f8b24abcf45e1984279f4f6e6ae07345084d8 100644 --- a/src/llama_recipes/finetuning.py +++ b/src/llama_recipes/finetuning.py @@ -12,6 +12,7 @@ from peft import get_peft_model, prepare_model_for_int8_training from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ) +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload from torch.optim.lr_scheduler import StepLR from torch.utils.data import DistributedSampler from transformers import ( @@ -144,6 +145,7 @@ def main(**kwargs): model = FSDP( model, auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, + cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, sharding_strategy=fsdp_config.sharding_strategy, device_id=torch.cuda.current_device(),