diff --git a/README.md b/README.md
index 370be8974ef3d3d9fe77c080844261c2c3b29115..be2b4becde125dd0d6a8e0b0ebdee90ab147c5b4 100644
--- a/README.md
+++ b/README.md
@@ -125,6 +125,16 @@ torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --mode
 
 ```
 
+### Fine-tuning using FSDP on 70B Model
+
+If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs.
+
+```bash
+
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
 ### Multi GPU Multi Node:
 
 ```bash
diff --git a/configs/training.py b/configs/training.py
index 7b0e82d44363e563a2420ef6a2e7977ad8c539c8..088295ed05b5d1c431674420d20eff6613e05fb0 100644
--- a/configs/training.py
+++ b/configs/training.py
@@ -7,7 +7,8 @@ from typing import ClassVar
 @dataclass
 class train_config:
     model_name: str="PATH/to/LLAMA/7B"
-    enable_fsdp: bool= False 
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
     num_epochs: int=3
diff --git a/docs/multi_gpu.md b/docs/multi_gpu.md
index 5c8fdf353fb2b77cb8facaaeff2ba9bb1a8e511a..f4961df9e0923f1fdb62790491cdd2367de5d293 100644
--- a/docs/multi_gpu.md
+++ b/docs/multi_gpu.md
@@ -62,6 +62,16 @@ torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --mode
 
 ```
 
+### Fine-tuning using FSDP on 70B Model
+
+If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs.
+
+```bash
+
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
 **Multi GPU multi node**:
 
 Here we use a slurm script to schedule a job with slurm over multiple nodes.
diff --git a/llama_finetuning.py b/llama_finetuning.py
index b7c2dc60355c203fae73c808c8a595baeeb4d5c3..5a2f86267cb2423127e3d28786409f50db52f09b 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -2,68 +2,47 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List, Union
 
 import fire
 import torch
-import transformers
-from datasets import load_dataset
-import os.path as osp
-from tqdm import tqdm
-
-# Unused imports removed
-from utils import fsdp_auto_wrap_policy
+import torch.distributed as dist
+import torch.optim as optim
+from peft import get_peft_model, prepare_model_for_int8_training
+from pkg_resources import packaging
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
+    LlamaConfig,
     default_data_collator,
-    BitsAndBytesConfig
-)
-import torch.distributed as dist
-# Unused imports removed
-from utils.train_utils import (
-    set_tokenizer_params,
-    train,
-    evaluation,
-    freeze_transformer_layers,
-    check_frozen_layers_peft_model,
-    setup,
-    setup_environ_flags,
-    cleanup,
-    clear_gpu_cache,
-    get_parameter_dtypes,
-    print_model_size,
-    get_policies  
 )
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from utils.dataset_utils import get_preprocessed_dataset
+import policies
+from configs import fsdp_config, train_config
+from policies import AnyPrecisionAdamW
 
+from utils import fsdp_auto_wrap_policy
 from utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
 )
-from peft import get_peft_model, TaskType, prepare_model_for_int8_training
-import configs
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
+from utils.dataset_utils import get_preprocessed_dataset
+
+from utils.train_utils import (
+    train,
+    freeze_transformer_layers,
+    setup,
+    setup_environ_flags,
+    clear_gpu_cache,
+    print_model_size,
+    get_policies
 )
-from torch.utils.data import DistributedSampler
-import policies
-from policies import AnyPrecisionAdamW
-from configs import fsdp_config, train_config
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from pkg_resources import packaging
-import torch
-import torch.cuda.nccl as nccl
-import torch.distributed as dist
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 
 def main(**kwargs):
@@ -82,18 +61,43 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(rank)
+        torch.cuda.set_device(local_rank)
+        clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
-    
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-     
+
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
+        """
+        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
+        this avoids cpu oom when loading large models like llama 70B, in which case
+        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
+        overhead and currently requires latest nightly.
+        """
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with torch.device("meta"):
+                model = LlamaForCausalLM(llama_config)
+
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
@@ -106,11 +110,11 @@ def main(**kwargs):
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-    
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-        
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
@@ -119,7 +123,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-            
+
                 "pad_token": "<PAD>",
             }
         )
@@ -127,16 +131,16 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-    
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-            
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-   
+
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -144,6 +148,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=train_config.low_cpu_fsdp,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
@@ -151,14 +158,14 @@ def main(**kwargs):
         model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-    
+
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-    
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -185,7 +192,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-        
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -207,7 +214,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-        
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
@@ -229,7 +236,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader, 
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,
diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py
index 5bb8ff11193c50f7fbb9776229746eecdc3d0195..b097df97daf75b3bc9ed57d89d02619e0dfecd4f 100644
--- a/model_checkpointing/checkpoint_handler.py
+++ b/model_checkpointing/checkpoint_handler.py
@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = model.state_dict()
+        checkpoint = {"model": model.state_dict()}
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
             print(f"checkpoint after load_state_dict()")
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint["model"])
     if rank == 0:
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
diff --git a/utils/train_utils.py b/utils/train_utils.py
index 8113ef173bb748a083e88a899ac8033f131e5421..d4eab21a2574c390b8ed117aa6b415d97cb3c757 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -141,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
           
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
+            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp: