Commit 21e8368c authored by JimChienTW

add freeze_LLM_only option for mllama finetuning

parent b9fc1069
@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False  # Freeze self-attention layers in the language_model. Vision model, multi_modal_projector, cross-attention will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
......
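A minimal usage sketch (not part of the diff above): the recipe's main(**kwargs) entry point forwards keyword overrides into train_config, so the new flag can be enabled like any other option. The model id and launch details below are illustrative assumptions.

# Hypothetical example of turning on the new option in a fine-tuning run.
# With enable_fsdp=True this is normally launched under torchrun so the
# distributed environment variables are already set.
from llama_recipes.finetuning import main

main(
    model_name="meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed vision model id
    enable_fsdp=True,
    use_peft=False,           # freezing applies to the full fine-tuning path, not PEFT
    freeze_LLM_only=True,     # option added by this commit
)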
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
     setup,
@@ -193,8 +194,6 @@ def main(**kwargs):
     )
     model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,6 +234,10 @@ def main(**kwargs):
     if not train_config.use_peft and train_config.freeze_layers:
         freeze_transformer_layers(model, train_config.num_freeze_layers)
 
+    if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+        freeze_LLM_only(model)
+
     mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
     # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer in vision models
@@ -255,6 +258,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
 
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
+
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +290,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -298,6 +307,8 @@ def main(**kwargs):
     else:
         dataset_processer = tokenizer
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
......
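Note on the use_orig_params change above (not part of the commit itself): by default FSDP flattens each wrapped unit into a single FlatParameter and requires every parameter in that unit to share the same requires_grad. With freeze_LLM_only, frozen language-model parameters and trainable parameters (cross-attention layers, multi_modal_projector, vision model) can land in the same unit, notably the root wrap, so use_orig_params=True is needed to preserve per-parameter requires_grad flags. A hedged sketch of how one might spot such mixed units before wrapping; the helper name is hypothetical:

# Report modules that mix frozen and trainable parameters, i.e. units that
# default FSDP flattening (use_orig_params=False) would reject. Ancestor
# modules (e.g. the root and language_model) naturally show up as mixed too.
import torch.nn as nn

def modules_with_mixed_requires_grad(model: nn.Module) -> list:
    mixed = []
    for name, module in model.named_modules():
        flags = {p.requires_grad for p in module.parameters()}
        if len(flags) > 1:
            mixed.append(name)
    return mixed

# Usage (after freeze_LLM_only(model), before FSDP wrapping):
# print(modules_with_mixed_requires_grad(model))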
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
         if i < num_layer:
             for param in layer.parameters():
                 param.requires_grad = False
 
+def freeze_LLM_only(model):
+    """
+    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
+
 
 def check_frozen_layers_peft_model(model):
     for i, layer in enumerate(model.base_model.model.model.layers):
......
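A small verification sketch (not from the commit): after calling freeze_LLM_only on an mllama model, a per-submodule breakdown should show vision_model and multi_modal_projector fully trainable, and only the cross-attention layers trainable inside language_model. The helper below is hypothetical:

# Summarize trainable vs. total parameter counts per top-level submodule.
def summarize_trainable(model):
    for name, module in model.named_children():
        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        print(f"{name}: {trainable:,} / {total:,} parameters trainable")

# Usage (assumes `model` is an MllamaForConditionalGeneration with freeze_LLM_only applied):
# summarize_trainable(model)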