From cc356b60179ebf1adc7708ad1cd43bb4100b92cf Mon Sep 17 00:00:00 2001
From: Howard Liberty <liberty@anymemo.org>
Date: Wed, 16 Aug 2023 14:36:07 -0700
Subject: [PATCH] Add FSDP CPU offloading option

---
 src/llama_recipes/configs/fsdp.py | 3 +--
 src/llama_recipes/finetuning.py   | 2 ++
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py
index ee10262d..c89aff1f 100644
--- a/src/llama_recipes/configs/fsdp.py
+++ b/src/llama_recipes/configs/fsdp.py
@@ -13,8 +13,7 @@ class fsdp_config:
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool=True
+    fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
     optimizer: str= "AdamW"
-
-
\ No newline at end of file
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index a475c1ca..5f7f8b24 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -12,6 +12,7 @@ from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DistributedSampler
 from transformers import (
@@ -144,6 +145,7 @@ def main(**kwargs):
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
--
GitLab
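
For reference, a minimal sketch of what the new `fsdp_cpu_offload` flag wires up: when enabled, the model is wrapped with `CPUOffload(offload_params=True)`, so FSDP keeps sharded parameters in host memory and moves them to the GPU only around compute, trading speed for GPU memory. The toy model, the `fsdp_cpu_offload` variable, and the process-group setup below are illustrative placeholders, not code from this repository.

```python
# Illustrative sketch only; assumes launch via torchrun with at least one GPU
# and an env:// rendezvous so init_process_group() can pick up the settings.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stands in for fsdp_config.fsdp_cpu_offload from the patched config dataclass.
fsdp_cpu_offload = True

model = nn.Linear(4096, 4096)

# Mirrors the finetuning.py change: pass CPUOffload(offload_params=True)
# only when the flag is set, otherwise leave offloading disabled (None).
model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True) if fsdp_cpu_offload else None,
    device_id=torch.cuda.current_device(),
)
```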