diff --git a/recipes/quickstart/finetuning/README.md b/recipes/quickstart/finetuning/README.md
index bee4db7f565c45e5a9066e748c5089c686093c7c..46d58aa6cfd58ae8387cefb9a3ba29d963556bce 100644
--- a/recipes/quickstart/finetuning/README.md
+++ b/recipes/quickstart/finetuning/README.md
@@ -54,6 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze self-attention layers in the language_model. The vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
diff --git a/recipes/quickstart/finetuning/finetune_vision_model.md b/recipes/quickstart/finetuning/finetune_vision_model.md
index 6f7d64f64c1b2f2183c0478f7c212db8bf48cdf1..d0868796895c71a0b5524b1e4f66bbcc5bb368c3 100644
--- a/recipes/quickstart/finetuning/finetune_vision_model.md
+++ b/recipes/quickstart/finetuning/finetune_vision_model.md
@@ -18,6 +18,12 @@ For **LoRA finetuning with FSDP**, we can run the following code:
 ```bash
   torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
 ```
+
+For **finetuning with the LLM frozen using FSDP**, we can run the following code:
+
+```bash
+  torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --freeze_LLM_only True
+```
 **Note**: `--batching_strategy padding` is needed as the vision model will not work with `packing` method.
 
 For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py
index acdbc890b0940947098a289e6fe54b273fd0fdfa..19e273d39640bd539af328dcec4a15021cb5fdea 100644
--- a/src/llama_recipes/configs/training.py
+++ b/src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze self-attention layers in the language_model. The vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 548184e6ab85be6d473defd4d401afb9c9f1a093..75ef6337f3c49d6be679a8c498f91aa6e9ff82b2 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -38,8 +38,10 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
+    print_frozen_model_status,
     setup,
     setup_environ_flags,
     train,
@@ -194,7 +196,7 @@ def main(**kwargs):
         model.resize_token_embeddings(len(tokenizer))
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-    
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,7 +237,14 @@
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
-            
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
@@ -255,6 +264,11 @@
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +296,7 @@
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -297,7 +312,7 @@
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
-    
+
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d3b42ae1254ec308548e0bb89e381e8b42a6fee8..c594b6a1e6555bba31524d972b4adf41995e7bb3 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-                    
+
+def freeze_LLM_only(model):
+    """
+    Freeze self-attention layers in the language_model. The vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
     for i, layer in enumerate(model.base_model.model.model.layers):
@@ -476,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
-
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
+    """
+    Print the frozen status of the model's modules and the number of trainable parameters after freezing.
+
+    Args:
+        model: The PyTorch model.
+        config: The training config, used for the model name.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print("After freezing the model:")
+        print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")
+
+        module_states = {}
+        # Iterate over all parameters
+        for name, param in model.named_parameters():
+            # Extract the top-level module name (e.g., "vision_model", "language_model")
+            top_module = name.split(".")[0]
+
+            # Initialize a record for the top-level module
+            if top_module not in module_states:
+                module_states[top_module] = {"frozen": [], "unfrozen": []}
+
+            # Group parameters into frozen or unfrozen
+            if param.requires_grad:
+                module_states[top_module]["unfrozen"].append(name)
+            else:
+                module_states[top_module]["frozen"].append(name)
+
+        print("--> Model state after freezing:")
+        # Analyze and print the results
+        for module, states in module_states.items():
+            frozen_params = states["frozen"]
+            unfrozen_params = states["unfrozen"]
+
+            if frozen_params and unfrozen_params:
+                # Mixed state: both frozen and unfrozen parameters
+                print(f"    {module}: Mixed")
+            elif frozen_params:
+                # All parameters are frozen
+                print(f"    {module}: Frozen")
+            else:
+                # All parameters are unfrozen
+                print(f"    {module}: Unfrozen")
+        print("")
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
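A note for reviewers, not part of the patch: the sketch below illustrates the freezing pattern this diff introduces, using toy stand-in modules (the names `ToyLanguageModel` and `ToyMllama` are hypothetical, not real classes) in place of the actual Mllama model. Only `freeze_LLM_only` mirrors the function added to `train_utils.py`; the summary loop at the end condenses what `print_frozen_model_status` reports per top-level module.

```python
# Toy stand-ins (hypothetical, for illustration only) -- not the real Mllama classes.
import torch.nn as nn


class ToyLanguageModel(nn.Module):
    """Mimics the language_model layout: .model.layers plus .model.cross_attention_layers."""

    def __init__(self, num_layers: int = 4, cross_attention_layers=(1, 3)):
        super().__init__()
        self.model = nn.Module()
        self.model.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(num_layers))
        self.model.cross_attention_layers = list(cross_attention_layers)


class ToyMllama(nn.Module):
    """Mimics the top-level modules the diff refers to."""

    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(8, 8)
        self.multi_modal_projector = nn.Linear(8, 8)
        self.language_model = ToyLanguageModel()


def freeze_LLM_only(model):
    # Same logic as the function added in train_utils.py: freeze every language_model
    # parameter, then re-enable only the cross-attention layers.
    for _, param in model.language_model.named_parameters():
        param.requires_grad = False
    for i, layer in enumerate(model.language_model.model.layers):
        if i in model.language_model.model.cross_attention_layers:
            for param in layer.parameters():
                param.requires_grad = True


model = ToyMllama()
freeze_LLM_only(model)

# Condensed version of what print_frozen_model_status reports per top-level module.
states = {}
for name, param in model.named_parameters():
    states.setdefault(name.split(".")[0], set()).add(param.requires_grad)
for module, flags in sorted(states.items()):
    status = "Mixed" if len(flags) == 2 else ("Unfrozen" if True in flags else "Frozen")
    print(f"  {module}: {status}")
# Expected: language_model: Mixed, multi_modal_projector: Unfrozen, vision_model: Unfrozen
```

The mixed frozen/unfrozen state inside `language_model` is also why the patch passes `use_orig_params=True` to `FSDP` when `freeze_LLM_only` is set: in the default flat-parameter mode FSDP expects uniform `requires_grad` within each wrapped unit, whereas `use_orig_params=True` allows frozen and trainable parameters to share one.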