diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index acdbc890b0940947098a289e6fe54b273fd0fdfa..19e273d39640bd539af328dcec4a15021cb5fdea 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze self-attention layers in the language_model. Vision model, multi_modal_projector, cross-attention will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 548184e6ab85be6d473defd4d401afb9c9f1a093..e1d702e2005b557650e63f5d6c424a03bf686daf 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
     setup,
@@ -193,8 +194,6 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,6 +234,10 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
@@ -255,6 +258,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +290,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -298,6 +307,8 @@ def main(**kwargs):
     else:
         dataset_processer = tokenizer
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d3b42ae1254ec308548e0bb89e381e8b42a6fee8..cec2df784bce4e2f48b4c2a1365c41b69bd3c095 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+
+def freeze_LLM_only(model):
+    """
+    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
     for i, layer in enumerate(model.base_model.model.model.layers):
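
A quick way to sanity-check the new `freeze_LLM_only` path is to call it on an Mllama checkpoint and count trainable parameters per top-level submodule: `vision_model` and `multi_modal_projector` should stay fully trainable, while inside `language_model` only the cross-attention decoder layers should remain trainable. This mix of frozen and trainable tensors within an FSDP-wrapped unit is presumably why the diff switches `use_orig_params` on when the flag is set, since flattened parameters with non-uniform `requires_grad` are only supported with original parameters exposed. The sketch below is illustrative and not part of this PR; the checkpoint name and the `summarize_trainable` helper are assumptions.

```python
# Illustrative sanity check only -- not part of this PR. The checkpoint name and
# the summarize_trainable helper are assumptions made for this example.
import torch
from transformers import MllamaForConditionalGeneration

from llama_recipes.utils.train_utils import freeze_LLM_only


def summarize_trainable(model: torch.nn.Module) -> None:
    # Report trainable vs. total parameter counts for each top-level submodule
    # (vision_model, language_model, multi_modal_projector for Mllama).
    for name, module in model.named_children():
        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        print(f"{name}: {trainable:,} / {total:,} parameters trainable")


if __name__ == "__main__":
    model = MllamaForConditionalGeneration.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed checkpoint
        torch_dtype=torch.bfloat16,
    )
    freeze_LLM_only(model)
    # Expected: vision_model and multi_modal_projector fully trainable; inside
    # language_model only the cross-attention decoder layers remain trainable.
    summarize_trainable(model)
```

Since `freeze_LLM_only` is a `train_config` field, it should be settable from the command line in the same way as the existing `freeze_layers` / `num_freeze_layers` options.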