From bd22f407d5472bf935a7cb6ff22934dbdc2a9205 Mon Sep 17 00:00:00 2001
From: Kai Wu <kaiwu@meta.com>
Date: Mon, 23 Sep 2024 19:21:14 -0700
Subject: [PATCH] change vision finetuning recipe to the ai2d dataset

---
 .../finetuning/datasets/vqa_dataset.py        | 96 ++++---------------
 .../finetuning/finetune_vision_model.md       | 24 +++--
 src/llama_recipes/finetuning.py               | 14 ++-
 src/llama_recipes/utils/train_utils.py        |  7 +-
 4 files changed, 44 insertions(+), 97 deletions(-)

diff --git a/recipes/quickstart/finetuning/datasets/vqa_dataset.py b/recipes/quickstart/finetuning/datasets/vqa_dataset.py
index 50772bb5..05c0f596 100644
--- a/recipes/quickstart/finetuning/datasets/vqa_dataset.py
+++ b/recipes/quickstart/finetuning/datasets/vqa_dataset.py
@@ -51,57 +51,12 @@ def tokenize_dialogs(dialogs, images, processor):
     tokenizer_length = len(processor.tokenizer)
     return batch
 
-def tokenize_dialog(dialog, images, processor):
-    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
-    text_prompt = processor.apply_chat_template(dialog)
-    #print("text_prompt",text_prompt)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")    
-    labels = copy.copy(batch["input_ids"].tolist()[0])
-    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
-    last_idx = 0
-    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
-    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
-    for n, idx in enumerate(eot_indices):
-        current_seq = labels[last_idx:idx+1]
-        if check_header(prompt_header_seqs,current_seq):
-            # found prompt header, indicating that this seq should be masked
-            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-        else:
-            last_idx = idx+1
-        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
-    assistant_header_seq = [128006, 78191, 128007]
-    labels = replace_target(assistant_header_seq,labels)
-    #print("labels",labels)
-    # print("pixel_values .shape",batch["pixel_values"].shape)
-    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
 
-    batch["labels"] = torch.tensor(labels)
-    # exit()
-    # combined_tokens = {
-    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
-    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
-    #     "input_ids": dialog_tokens,
-    #     "labels": labels,
-    #     "attention_mask": [1]*len(dialog_tokens),
-    #     "pixel_values": batch["pixel_values"],
-    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
-    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
-    #     "cross_attention_mask": batch["cross_attention_mask"]
-    # }
-    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
-    # labels = list(itertools.chain(*(t for t in labels_tokens))),
-    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
-    # pixel_values =  batch["pixel_values"],
-    # image_sizes = batch["image_sizes"]
-#    print("combined_tokens",combined_tokens[image_sizes])
-    
-    return batch
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
-    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
-    dataset = dataset_dict[split]
-    dataset = dataset.select(range(500))
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d")
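+    # only the "train" split is used; carve out local train/test splits with a fixed seed for reproducibility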
+    dataset = dataset_dict['train']
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
     return dataset
 
 class VQADataCollator:
@@ -111,35 +66,26 @@ class VQADataCollator:
     def __call__(self, samples):
         dialogs,images = [],[]
         for sample in samples:
-            image,sample_text = sample["images"],sample["messages"]
+            image_list, sample_list = sample["images"], sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only one image per sample is supported")
+            image = image_list[0].convert("RGB")  # convert the single image to RGB
             dialog = []
-            for line in sample_text:
-                content = []
-                messages = line["content"]
-                role = line["role"]
-                for message in messages:
-                    if message["type"] == "image":
-                        content.append({"type": "image"})
-                    elif message["type"] == "text":
-                        content.append({"type": "text", "text": message["text"].strip()})
-                dialog.append({"role": role,"content":content})
+            for sample_dict in sample_list:
+                if not dialog:
+                    # attach the image token only to the first user turn
+                    dialog += [
+                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
+                    ]
+                else:
+                    dialog += [
+                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
+                    ]
             dialogs.append(dialog)
-            images.append(image)
+            images.append([image])  # the processor expects one list of images per dialog
         return tokenize_dialogs(dialogs,images, self.processor)
-    def __callworking__(self, samples):
-        for sample in samples:
-            image,sample_text = sample["images"],sample["messages"]
-            dialog = []
-            for line in sample_text:
-                content = []
-                messages = line["content"]
-                role = line["role"]
-                for message in messages:
-                    if message["type"] == "image":
-                        content.append({"type": "image"})
-                    elif message["type"] == "text":
-                        content.append({"type": "text", "text": message["text"].strip()})
-                dialog.append({"role": role,"content":content})
-            return tokenize_dialog(dialog,image, self.processor)
 def get_data_collator(processor):
     return VQADataCollator(processor)
diff --git a/recipes/quickstart/finetuning/finetune_vision_model.md b/recipes/quickstart/finetuning/finetune_vision_model.md
index ad76cdf1..91366a9f 100644
--- a/recipes/quickstart/finetuning/finetune_vision_model.md
+++ b/recipes/quickstart/finetuning/finetune_vision_model.md
@@ -1,5 +1,5 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
-Here we discuss fine-tuning Meta Llama 3.2 11B and 90B models.
+This recipe walks you through fine-tuning a Meta Llama 3.2 vision model (11B or 90B) on a VQA task, using the `ai2d` subset of [the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron).
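+
+Each `ai2d` sample pairs a single diagram image with one or more question/answer turns. The custom data collator in `recipes/quickstart/finetuning/datasets/vqa_dataset.py` only relies on the `images` and `texts` fields, so a sample looks roughly like this (an illustrative sketch, not the exact schema; the string values are placeholders):
+
+```python
+sample = {
+    "images": ["<a single PIL.Image.Image>"],
+    "texts": [
+        {"user": "<question about the diagram>", "assistant": "<answer>"},
+        # ...possibly more question/answer turns about the same image
+    ],
+}
+```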
 
 ### Concepts
 Model Architecture
@@ -12,18 +12,24 @@ We need have a new processor class added, that will handle the image processing
 
 
 ### Fine-tuning steps
-1. Download the dataset:
-an example of the dataset looks like this:
-2. Processor example looks like this
 
-3. Load the dataset
 
-Full-finetune
+For **full finetuning with FSDP**, run the following command:
 ```bash
-  torchrun --nnodes 1 --nproc_per_node 8  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name nltpt/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_11bmodel --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py"  --run_validation True --batching_strategy padding  --use-wandb
+  torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py"  --run_validation True --batching_strategy padding
 ```
 
-LoRA:
+For **LoRA finetuning with FSDP**, run the following command:
 ```bash
-  torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 1 --batch_size_training 1 --model_name llava-hf/llama3-llava-next-8b-hf --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --use-wandb  --run_validation True
+  torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py"  --run_validation True --batching_strategy padding  --use_peft --peft_method lora
 ```
+**Note**: `--batching_strategy padding` is required because the vision model does not work with the `packing` batching strategy.
+
+For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
+
+### How to use a custom dataset to fine-tune the vision model
+
+1. Create a new dataset Python file under the `recipes/quickstart/finetuning/datasets` folder.
+2. In this file, define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the data loading.
+3. In the same file, define a `get_data_collator(processor)` function that returns a custom data collator usable by the PyTorch DataLoader.
+4. The custom data collator class must implement a `__call__(self, samples)` method that converts the image and text samples into the actual inputs the vision model expects (see the sketch below).
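+
+The sketch below shows the overall shape of such a file. The module and class names are illustrative, the label masking is omitted for brevity, and `datasets/vqa_dataset.py` in this folder is the complete version used by the commands above:
+
+```python
+# custom_vqa_dataset.py -- illustrative skeleton of a custom vision dataset file
+from datasets import load_dataset
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load the raw data and carve out local train/test splits
+    dataset = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d")["train"]
+    return dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
+
+
+class CustomDataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right"
+
+    def __call__(self, samples):
+        texts, images = [], []
+        for sample in samples:
+            image = sample["images"][0].convert("RGB")
+            dialog = [
+                {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample["texts"][0]["user"]}]},
+                {"role": "assistant", "content": [{"type": "text", "text": sample["texts"][0]["assistant"]}]},
+            ]
+            texts.append(self.processor.apply_chat_template(dialog))
+            images.append([image])
+        batch = self.processor(images=images, text=texts, padding=True, return_tensors="pt")
+        # NOTE: for real training, build "labels" with the prompt tokens masked out
+        # (see tokenize_dialogs in datasets/vqa_dataset.py); cloning input_ids is only a placeholder
+        batch["labels"] = batch["input_ids"].clone()
+        return batch
+
+
+def get_data_collator(processor):
+    return CustomDataCollator(processor)
+```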
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index cc8643ae..1587611b 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -22,6 +22,7 @@ from torch.distributed.fsdp.wrap import (
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
     LlamaForCausalLM,
@@ -125,7 +126,8 @@ def main(**kwargs):
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    if "11B" in train_config.model_name or "90B" in train_config.model_name:
+    config = AutoConfig.from_pretrained(train_config.model_name)
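+    # dispatch on config.model_type: "mllama" for Llama 3.2 vision models, "llama" for text-only models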
+    if config.model_type == "mllama":
         is_vision = True
         model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
@@ -136,7 +138,7 @@ def main(**kwargs):
     )
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
-    else:
+    elif config.model_type == "llama":
         is_vision = False
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
@@ -146,7 +148,8 @@ def main(**kwargs):
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
         )
-    print(model)
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use a llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -190,7 +193,6 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
-        print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -218,8 +220,6 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-    print("-------------------")
-    print("FSDP model", model)
     dataset_config = generate_dataset_config(train_config, kwargs)
     if is_vision:
         dataset_processer = processor
@@ -306,8 +306,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
-   
     results = train(
         model,
         train_dataloader,
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index 3eaf62c2..d7f3d9dc 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -132,7 +132,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             with profile(train_config,local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
-                    #print("batch: ", batch)
                     # stop when the maximum number of training steps is reached
                     if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                         max_steps_reached = True
@@ -151,10 +150,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             elif torch.cuda.is_available():
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
-                        assert(next(model.parameters()).device == batch['input_ids'].device)
                         loss = model(**batch).loss
                     loss = loss / gradient_accumulation_steps
-                    #print("loss",loss)
                     if train_config.save_metrics:
                         train_step_loss.append(loss.detach().float().item())
                         train_step_perplexity.append(float(torch.exp(loss.detach().float())))
@@ -175,7 +172,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             pbar.update(1)
                     else:
                         # regular backpropagation when fp16 is not used
-                        #print("loss123",loss)
                         loss.backward()
                         if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                             if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
@@ -364,7 +360,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
-                outputs = model(**batch,use_cache=False)
+                outputs = model(**batch)
                 loss = outputs.loss
                 if train_config.save_metrics:
                     val_step_loss.append(loss.detach().float().item())
-- 
GitLab