diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py
index ee10262d7554ab900ad439d2d4f9d849ab1f3ffc..c89aff1f92809411fa84b19d05e4ffd241ed5634 100644
--- a/src/llama_recipes/configs/fsdp.py
+++ b/src/llama_recipes/configs/fsdp.py
@@ -13,8 +13,7 @@ class fsdp_config:
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # SHARDED_STATE_DICT saves one file per rank and allows the world size to change on resume; alternatively, FULL_STATE_DICT can be used.
     fsdp_activation_checkpointing: bool=True
+    fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
     optimizer: str= "AdamW"
     
-    
-    
\ No newline at end of file
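Note (not part of the patch): a minimal sketch of toggling the new flag, assuming fsdp_config stays a plain dataclass as in the surrounding file; the recipe's kwargs-based config override (an assumption about the existing plumbing) would end up setting the same field.

from llama_recipes.configs.fsdp import fsdp_config

# fsdp_cpu_offload defaults to False, matching the field added above.
cfg = fsdp_config(fsdp_cpu_offload=True)
assert cfg.fsdp_cpu_offload is True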
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index a475c1cad42994ee8517db27548156bf1e35703f..5f7f8b24abcf45e1984279f4f6e6ae07345084d8 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -12,6 +12,7 @@ from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DistributedSampler
 from transformers import (
@@ -144,6 +145,7 @@ def main(**kwargs):
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
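Note (not part of the patch): a minimal, self-contained sketch of what the new cpu_offload argument does once it reaches FSDP. The toy model and the single-process torchrun launch are illustrative assumptions, not llama-recipes code.

# Launch with e.g.: torchrun --nproc_per_node 1 cpu_offload_sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

    fsdp_cpu_offload = True  # stands in for fsdp_config.fsdp_cpu_offload
    model = FSDP(
        model,
        cpu_offload=CPUOffload(offload_params=True) if fsdp_cpu_offload else None,
        device_id=torch.cuda.current_device(),
    )

    # With offload_params=True, FSDP keeps the sharded parameters (and their
    # gradients) on CPU and copies shards to the GPU only around forward and
    # backward computation, trading transfer time for GPU memory.
    print({p.device for p in model.parameters()})  # expected: {device(type='cpu')}

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

When fsdp_cpu_offload is False, the patched call passes cpu_offload=None, which matches the upstream behaviour before this change: parameters stay on the GPU.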