diff --git a/recipes/quickstart/finetuning/datasets/vqa_dataset.py b/recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
similarity index 95%
rename from recipes/quickstart/finetuning/datasets/vqa_dataset.py
rename to recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
index 05c0f596481e4947aaece85c1eeb7c06c30ace6e..ba0c5f6a4ecdba2c1fae3d4e4cc7b56668fdb860 100644
--- a/recipes/quickstart/finetuning/datasets/vqa_dataset.py
+++ b/recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
@@ -48,18 +48,19 @@ def tokenize_dialogs(dialogs, images, processor):
                 labels[i] = -100
         label_list.append(labels)
     batch["labels"] = torch.tensor(label_list)
-    tokenizer_length = len(processor.tokenizer)
     return batch


 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
-    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d")
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
     dataset = dataset_dict['train']
+    # For quick testing we only use 2000 samples; comment out the following line to use the full dataset
+    dataset = dataset.select(range(2000))
     dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
     return dataset

-class VQADataCollator:
+class OCRVQADataCollator:
     def __init__(self, processor):
         self.processor = processor
         self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
@@ -88,4 +89,4 @@ class VQADataCollator:
             images.append([image])
         return tokenize_dialogs(dialogs,images, self.processor)
 def get_data_collator(processor):
-    return VQADataCollator(processor)
+    return OCRVQADataCollator(processor)
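For readers skimming the hunk above, the following is a minimal sketch of the dataset-loading pattern that `get_custom_dataset` now follows. It assumes the Hugging Face `datasets` library is installed; the helper name `load_ocrvqa_subset` and the `quick_test_samples` parameter are illustrative only and not part of the repo, while `load_dataset`, `select`, and `train_test_split` come straight from the patch.

```python
# Hedged sketch of the loading pattern used by get_custom_dataset above.
# The helper name and the quick_test_samples argument are hypothetical.
from datasets import load_dataset


def load_ocrvqa_subset(split="train", split_ratio=0.9, quick_test_samples=2000):
    # the_cauldron ships everything under the "train" split of the DatasetDict
    dataset = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")["train"]
    if quick_test_samples is not None:
        # Cap the dataset for quick experiments, mirroring dataset.select(range(2000)) above
        dataset = dataset.select(range(quick_test_samples))
    # Deterministic split, matching the seed and ratio used in the recipe
    splits = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)
    return splits[split]
```

Under these assumptions, `load_ocrvqa_subset("test")` would return roughly 200 held-out samples, which is what `--run_validation True` evaluates against.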
diff --git a/recipes/quickstart/finetuning/finetune_vision_model.md b/recipes/quickstart/finetuning/finetune_vision_model.md
index 91366a9f3cef16f8b17832d9e6e9cef769dfc15c..63983f2a694411680f01398cd0a2678d63b21a59 100644
--- a/recipes/quickstart/finetuning/finetune_vision_model.md
+++ b/recipes/quickstart/finetuning/finetune_vision_model.md
@@ -1,27 +1,20 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
-This recipe steps you through how to finetune a Llama 3.2 vision model on the VQA task using the [the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) dataset.
-
-### Concepts
-Model Architecture
-Our Meta Llama 3.2 11B and 90B models consist of two main components: (1) an image encoder, (2) an image adapter.
-
-[Model Architecture PICTURE]
-
-We need have a new processor class added, that will handle the image processing and text tokenization. A processor example looks like this:
-
-
+This recipe steps you through how to finetune a Llama 3.2 vision model on the VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.

 ### Fine-tuning steps
+We created an example script [ocrvqa_dataset.py](./datasets/ocrvqa_dataset.py) that loads the OCRVQA dataset with the `get_custom_dataset` function and provides an `OCRVQADataCollator` class to process the image dataset.

 For **full finetuning with FSDP**, we can run the following code:
+
 ```bash
- torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding
+ torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding
 ```

 For **LoRA finetuning with FSDP**, we can run the following code:
+
 ```bash
- torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
+ torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
 ```

 **Note**: `--batching_strategy padding` is needed as the vision model will not work with `packing` method.
@@ -29,7 +22,10 @@ For more details about the finetuning configurations, please read the [finetunin

 ### How to use custom dataset to fine-tune vision model

-1. Create a new dataset python file under `recipes/quickstart/finetuning/dataset` folder
-2. In this python file, you need to define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the dataloading.
-3. In this python file, you need to define a `get_data_collator(processor)` that returns a custom data collartor that can be used by the Pytorch Data Loader.
+In order to use a custom dataset, please follow the steps below:
+
+1. Create a new dataset python file under `recipes/quickstart/finetuning/dataset` folder.
+2. In this python file, you need to define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the data loading.
+3. In this python file, you need to define a `get_data_collator(processor)` that returns a custom data collator that can be used by the Pytorch Data Loader.
 4. This custom data collator class must have a `__call__(self, samples)` function that converts the image and text samples into the actual inputs that vision model expects.
+5. Run the `torchrun` command from the section above, changing `--custom_dataset.file` to the new dataset python file and adjusting the learning rate accordingly.
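To make the five documented steps above concrete, here is a hedged skeleton of what such a custom dataset file could look like. The file name `my_vision_dataset.py`, the collator class name, and the `"messages"`/`"images"` fields are placeholders for whatever your dataset actually provides; only the `get_custom_dataset` / `get_data_collator` entry points and the `__call__(self, samples)` signature are what the recipe requires.

```python
# my_vision_dataset.py -- hypothetical skeleton for the custom-dataset contract.
# Only the two entry points and the collator's __call__ signature are mandated
# by the recipe; the dataset name and field names below are assumptions.
from datasets import load_dataset


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # Replace with your own image/text dataset; must return the requested split.
    dataset = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")["train"]
    return dataset.train_test_split(
        test_size=1 - split_ratio, shuffle=True, seed=42
    )[split]


class MyVisionDataCollator:
    def __init__(self, processor):
        self.processor = processor
        # Pad on the right during training, as OCRVQADataCollator does.
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        texts, images = [], []
        for sample in samples:
            # Build a chat-style prompt and collect the image(s); the field
            # names here depend entirely on your dataset schema.
            texts.append(
                self.processor.apply_chat_template(sample["messages"], tokenize=False)
            )
            images.append(sample["images"])
        batch = self.processor(text=texts, images=images, padding=True, return_tensors="pt")
        # Simple label construction: copy input_ids and ignore padded positions.
        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch


def get_data_collator(processor):
    return MyVisionDataCollator(processor)
```

In practice you will likely also want to mask prompt and image tokens in the labels, which is what `tokenize_dialogs` in ocrvqa_dataset.py handles for the OCRVQA example.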
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 1587611b3a2888b400add200d9dde017e0961cdd..609a94200eb356d1f021b09359d86cf437910e4f 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -14,11 +14,6 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-from torch.distributed.fsdp.wrap import (
-    always_wrap_policy,
-    ModuleWrapPolicy,
-    transformer_auto_wrap_policy,
-)
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
@@ -26,7 +21,6 @@ from transformers import (
     AutoTokenizer,
     BitsAndBytesConfig,
     LlamaForCausalLM,
-    LlamaConfig,
     AutoProcessor,
     MllamaForConditionalGeneration
 )
@@ -152,7 +146,8 @@ def main(**kwargs):
         raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id

     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py
index 9916f8975cbd98c642ff326a9b5a94e623334877..c5f4976d762dd963fc93eb6582ccfe3c64594e59 100644
--- a/src/llama_recipes/utils/config_utils.py
+++ b/src/llama_recipes/utils/config_utils.py
@@ -14,7 +14,6 @@ from peft import (
 )
 from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq
-from functools import partial

 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index d7f3d9dcab7ebab9a93ad392cc397601462d174a..dec024520c2f490498df79178b3fa1baa842a319 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -360,7 +360,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
-                #outputs = model(**batch,use_cache=False)
                 outputs = model(**batch)
                 loss = outputs.loss
             if train_config.save_metrics:
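The finetuning.py hunk above only falls back to `eos_token_id` as the padding token when the tokenizer does not already define one. A small, hedged sketch of that guard in isolation (the model name is simply the one used elsewhere in this patch):

```python
# Sketch of the pad-token guard introduced in finetuning.py above: keep an
# existing pad token if the tokenizer ships one, otherwise fall back to EOS.
from transformers import AutoTokenizer


def ensure_pad_token(model_name: str = "meta-llama/Llama-3.2-11B-Vision-Instruct"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
```

This matters for right-padded batches like the ones built by the data collator: if a tokenizer that already ships a dedicated pad token had it silently overwritten with EOS, padding positions and genuine end-of-sequence tokens could become indistinguishable when constructing labels.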