From 21e8368c7ed23a823ff3e6787e4b6ac38204bacb Mon Sep 17 00:00:00 2001
From: JimChienTW <jim87112729@gmail.com>
Date: Sat, 16 Nov 2024 15:30:22 +0800
Subject: [PATCH] add freeze_LLM_only option for mllama finetuning

---
 src/llama_recipes/configs/training.py  |  1 +
 src/llama_recipes/finetuning.py        | 15 +++++++++++++--
 src/llama_recipes/utils/train_utils.py | 12 +++++++++++-
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index acdbc890..19e273d3 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze self-attention layers in the language_model. Vision model, multi_modal_projector, cross-attention will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 548184e6..e1d702e2 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
     setup,
@@ -193,8 +194,6 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,6 +234,10 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
@@ -255,6 +258,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +290,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -298,6 +307,8 @@ def main(**kwargs):
     else:
         dataset_processer = tokenizer
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d3b42ae1..cec2df78 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+def freeze_LLM_only(model):
+    """
+    Freeze self-attention layers in the language_model.
+    vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
     for i, layer in enumerate(model.base_model.model.model.layers):
-- 
GitLab
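
Notes:

freeze_LLM_only() freezes every parameter under model.language_model and then re-enables gradients only for the decoder layers whose index appears in cross_attention_layers, so the vision_model, multi_modal_projector, and cross-attention blocks are fine-tuned while the text self-attention layers stay frozen. Because this mixes frozen and trainable parameters inside the same FSDP-wrapped modules, the patch also passes use_orig_params=True to FSDP; without that flag, FSDP expects a uniform requires_grad setting across each flattened parameter group.

The sketch below is one way to check the result. It is not part of the patch; it assumes the patch is applied and that the gated meta-llama/Llama-3.2-11B-Vision checkpoint (used here only as an example model id) is available locally:

    # Sketch: apply the new helper and report how many parameters remain trainable.
    from transformers import MllamaForConditionalGeneration

    from llama_recipes.utils.train_utils import freeze_LLM_only

    model = MllamaForConditionalGeneration.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision"  # example checkpoint, assumed available
    )
    freeze_LLM_only(model)

    # Only vision_model, multi_modal_projector, and the cross-attention decoder
    # layers of the language model should report requires_grad=True.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable parameters: {trainable:,} / {total:,} "
          f"({100 * trainable / total:.2f}%)")

With the patch applied, running finetuning.py with freeze_LLM_only=True (and use_peft disabled) exercises the same code path end to end.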