diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index acdbc890b0940947098a289e6fe54b273fd0fdfa..19e273d39640bd539af328dcec4a15021cb5fdea 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze the language_model except its cross-attention layers; the vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 548184e6ab85be6d473defd4d401afb9c9f1a093..e1d702e2005b557650e63f5d6c424a03bf686daf 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
     setup,
@@ -193,8 +194,6 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,6 +234,10 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
@@ -255,6 +258,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+
+        # freeze_LLM_only mixes frozen and trainable parameters inside the same
+        # FSDP-wrapped modules; FSDP's default flat-parameter mode requires uniform
+        # requires_grad per wrapped unit, so use_orig_params must be enabled in that case
+        use_orig_params = train_config.freeze_LLM_only
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +290,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -298,6 +307,8 @@ def main(**kwargs):
     else:
         dataset_processer = tokenizer
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(
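For context on the use_orig_params change above: freeze_LLM_only leaves frozen and trainable parameters mixed inside the same FSDP-wrapped decoder modules, and FSDP's default flat-parameter mode requires uniform requires_grad within each wrapped unit, so use_orig_params=True is needed in that case. The helper below is a minimal sketch of how one might log the trainable/frozen split, e.g. right before wrapping the model in FSDP; count_trainable_parameters is a hypothetical name, not part of this PR or the llama_recipes API.

def count_trainable_parameters(model):
    """Return (trainable, total) parameter counts, e.g. to log after freeze_LLM_only(model)."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total
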
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d3b42ae1254ec308548e0bb89e381e8b42a6fee8..cec2df784bce4e2f48b4c2a1365c41b69bd3c095 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+
+def freeze_LLM_only(model):
+    """
+    Freeze the language_model except its cross-attention decoder layers; the vision_model, multi_modal_projector, and cross-attention layers remain trainable.
+    """
+    for param in model.language_model.parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
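
As a sanity check for the new freeze_LLM_only helper, the sketch below verifies that only the cross-attention decoder layers inside language_model remain trainable after the call. It assumes the same Mllama layout used above (model.language_model.model.layers and model.language_model.model.cross_attention_layers); check_freeze_LLM_only is a hypothetical helper for illustration, not part of this PR.

def check_freeze_LLM_only(model):
    """Assert that, within language_model, only cross-attention decoder layers are trainable."""
    cross_attention_layers = set(model.language_model.model.cross_attention_layers)
    for i, layer in enumerate(model.language_model.model.layers):
        expected = i in cross_attention_layers
        for name, param in layer.named_parameters():
            assert param.requires_grad == expected, (
                f"layer {i} ({name}): requires_grad={param.requires_grad}, expected {expected}"
            )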