From 21e8368c7ed23a823ff3e6787e4b6ac38204bacb Mon Sep 17 00:00:00 2001
From: JimChienTW <jim87112729@gmail.com>
Date: Sat, 16 Nov 2024 15:30:22 +0800
Subject: [PATCH] add freeze_LLM_only option for mllama finetuning

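Setting freeze_LLM_only=True (full finetuning of mllama, not PEFT) freezes the
self-attention layers of the language_model and keeps the vision_model,
multi_modal_projector, and cross-attention layers trainable. Because frozen and
trainable parameters then share the same FSDP-wrapped modules, the FSDP wrapper
is created with use_orig_params=True.

As a quick sanity check, trainable parameters per submodule can be counted
after calling freeze_LLM_only(model). The snippet below is an illustrative
sketch, not part of this patch; count_trainable is a hypothetical helper and
the submodule names follow the mllama layout referenced in the patch:

    def count_trainable(module):
        # Sum the sizes of parameters that still receive gradients.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    freeze_LLM_only(model)
    print(count_trainable(model.vision_model))           # > 0
    print(count_trainable(model.multi_modal_projector))  # > 0
    print(count_trainable(model.language_model))         # only cross-attention layers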
---
 src/llama_recipes/configs/training.py  |  1 +
 src/llama_recipes/finetuning.py        | 15 +++++++++++++--
 src/llama_recipes/utils/train_utils.py | 12 +++++++++++-
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index acdbc890..19e273d3 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze the self-attention layers of the language_model. The vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 548184e6..e1d702e2 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
     setup,
@@ -193,8 +194,6 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,6 +234,10 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+
+        # Freeze only the language model for mllama: vision_model, multi_modal_projector, and cross-attention layers stay trainable
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
@@ -255,6 +258,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+
+        # FSDP flattens parameters per wrapped module; when freeze_LLM_only is set,
+        # frozen and trainable parameters end up in the same wrapped module, which
+        # requires use_orig_params=True.
+        use_orig_params = train_config.freeze_LLM_only
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +290,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -298,6 +307,8 @@ def main(**kwargs):
     else:
         dataset_processer = tokenizer
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d3b42ae1..cec2df78 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+
+def freeze_LLM_only(model):
+    """
+    Freeze the self-attention layers of the language_model. The vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned.
+    """
+    for param in model.language_model.parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
-- 
GitLab