diff --git a/recipes/quickstart/finetuning/datasets/vqa_dataset.py b/recipes/quickstart/finetuning/datasets/vqa_dataset.py
index 38eb933bf84603886517fb03e058f15ae07c1c9f..f3ff552bedbeb6398c85996f334355b25c1ca3e8 100644
--- a/recipes/quickstart/finetuning/datasets/vqa_dataset.py
+++ b/recipes/quickstart/finetuning/datasets/vqa_dataset.py
@@ -23,7 +23,7 @@ def tokenize_dialogs(dialogs, images, processor):
     text_prompt = processor.apply_chat_template(dialogs)
     #print("text_prompt",text_prompt)
     batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
-    batch["labels"] = copy.copy(batch["input_ids"])
+    label_list = []
     for i in range(len(batch["input_ids"])):
         dialog_tokens = batch["input_ids"][i].tolist()
         labels = copy.copy(dialog_tokens)
@@ -42,14 +42,62 @@ def tokenize_dialogs(dialogs, images, processor):
             # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
         assistant_header_seq = [128006, 78191, 128007]
         labels = replace_target(assistant_header_seq,labels)
-        batch["labels"][i] = torch.tensor(labels)
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
     return batch
 
+def tokenize_dialog(dialog, images, processor):
+    # Single-dialog variant of tokenize_dialogs: tokenize one sample (text + images) with the
+    # processor's chat template and build its loss-masked labels.
+    text_prompt = processor.apply_chat_template(dialog)
+    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
+    labels = copy.copy(batch["input_ids"].tolist()[0])
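+    # 128009 is the token id of <|eot_id|>, which marks the end of each message turn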
+    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+    for n, idx in enumerate(eot_indices):
+        current_seq = labels[last_idx:idx+1]
+        if check_header(prompt_header_seqs,current_seq):
+            # found prompt header, indicating that this seq should be masked
+            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+        else:
+            last_idx = idx+1
+        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq,labels)
+    #print("labels",labels)
+    # print("pixel_values .shape",batch["pixel_values"].shape)
+    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
+
+    batch["labels"] = torch.tensor(labels)
+    # exit()
+    # combined_tokens = {
+    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    #     "input_ids": dialog_tokens,
+    #     "labels": labels,
+    #     "attention_mask": [1]*len(dialog_tokens),
+    #     "pixel_values": batch["pixel_values"],
+    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
+    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
+    #     "cross_attention_mask": batch["cross_attention_mask"]
+    # }
+    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
+    # labels = list(itertools.chain(*(t for t in labels_tokens))),
+    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
+    # pixel_values =  batch["pixel_values"],
+    # image_sizes = batch["image_sizes"]
+#    print("combined_tokens",combined_tokens[image_sizes])
+    
+    return batch
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
     dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
     dataset = dataset_dict[split]
-    dataset = dataset.select(range(100))
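+    # keep only the first 500 examples so this quickstart run stays small; remove the select() to train on the full split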
+    dataset = dataset.select(range(500))
     return dataset
 
 class VQADataCollator:
@@ -74,5 +122,20 @@ class VQADataCollator:
             dialogs.append(dialog)
             images.append(image)
         return tokenize_dialogs(dialogs,images, self.processor)
+    def __callworking__(self, samples):
+        for sample in samples:
+            image,sample_text = sample["images"],sample["messages"]
+            dialog = []
+            for line in sample_text:
+                content = []
+                messages = line["content"]
+                role = line["role"]
+                for message in messages:
+                    if message["type"] == "image":
+                        content.append({"type": "image"})
+                    elif message["type"] == "text":
+                        content.append({"type": "text", "text": message["text"].strip()})
+                dialog.append({"role": role,"content":content})
+            return tokenize_dialog(dialog,image, self.processor)
 def get_data_collator(processor):
     return VQADataCollator(processor)
diff --git a/recipes/quickstart/finetuning/datasets/vqa_dataset_old.py b/recipes/quickstart/finetuning/datasets/vqa_dataset_old.py
deleted file mode 100644
index 09fa4047ab118cc810e348bdcd8937a6de05f47c..0000000000000000000000000000000000000000
--- a/recipes/quickstart/finetuning/datasets/vqa_dataset_old.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
-
-
-import copy
-from datasets import load_dataset
-import itertools
-import torch
-# check system prompt token seq or user prompt token seq is in the current token list
-def check_header(targets,seq):
-    for i in range(len(seq)-3):
-        if seq[i:i+3] in targets:
-            return True
-    return False
-def replace_target(target,seq):
-    for i in range(len(seq)-3):
-        if seq[i:i+3] == target:
-            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
-    return seq
-def tokenize_dialog(dialog, images, processor):
-    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
-    text_prompt = processor.apply_chat_template(dialog)
-    #print("text_prompt",text_prompt)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")    
-    labels = copy.copy(batch["input_ids"].tolist()[0])
-    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
-    last_idx = 0
-    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
-    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
-    for n, idx in enumerate(eot_indices):
-        current_seq = labels[last_idx:idx+1]
-        if check_header(prompt_header_seqs,current_seq):
-            # found prompt header, indicating that this seq should be masked
-            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-        else:
-            last_idx = idx+1
-        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
-    assistant_header_seq = [128006, 78191, 128007]
-    labels = replace_target(assistant_header_seq,labels)
-    #print("labels",labels)
-    # print("pixel_values .shape",batch["pixel_values"].shape)
-    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
-
-    batch["labels"] = torch.tensor(labels)
-    #pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
-    batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
-    # pixel_values .shape torch.Size([1, 4, 3, 560, 560])
-    print("pixel_values .shape",batch["pixel_values"].shape)
-    # exit()
-    # combined_tokens = {
-    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
-    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
-    #     "input_ids": dialog_tokens,
-    #     "labels": labels,
-    #     "attention_mask": [1]*len(dialog_tokens),
-    #     "pixel_values": batch["pixel_values"],
-    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
-    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
-    #     "cross_attention_mask": batch["cross_attention_mask"]
-    # }
-    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
-    # labels = list(itertools.chain(*(t for t in labels_tokens))),
-    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
-    # pixel_values =  batch["pixel_values"],
-    # image_sizes = batch["image_sizes"]
-#    print("combined_tokens",combined_tokens[image_sizes])
-    
-    return batch
-def image_tokenize(sample, processor):
-    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
-    images,sample_text = sample["images"],sample["messages"]
-    dialog = []
-    for line in sample_text:
-        content = []
-        messages = line["content"]
-        role = line["role"]
-        for message in messages:
-            if message["type"] == "image":
-                content.append({"type": "image"})
-            elif message["type"] == "text":
-                content.append({"type": "text", "text": message["text"].strip()})
-        dialog.append({"role": role,"content":content})
-    return tokenize_dialog(dialog,images, processor)
-
-
-def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
-    # load_dataset will return DatasetDict that contains all the data in the train set
-    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
-    dataset = dataset_dict[split]
-    dataset = dataset.select(range(100))
-    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
-    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
-    return tokenized_datasets
diff --git a/recipes/quickstart/finetuning/finetune_vision_model.md b/recipes/quickstart/finetuning/finetune_vision_model.md
index 32635f49820c24cf2f3b6a7dc453cb1f48c22d03..ad76cdf17b9139038e1f5284eb1082e570dd7d35 100644
--- a/recipes/quickstart/finetuning/finetune_vision_model.md
+++ b/recipes/quickstart/finetuning/finetune_vision_model.md
@@ -20,7 +20,7 @@ an example of the dataset looks like this:
 
 Full-finetune
 ```bash
-  torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 1 --batch_size_training 1 --model_name llava-hf/llama3-llava-next-8b-hf --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --use-wandb  --run_validation True
+  torchrun --nnodes 1 --nproc_per_node 8  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py"  --run_validation True --batching_strategy padding  --use-wandb
 ```
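+
+Note that `--batching_strategy padding` is used here because the multimodal samples cannot be packed into fixed-length blocks; each batch is instead padded to its longest sequence.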
 
 LoRA:
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 55803ae48c8507863a3099f266e8ea9fd1d92ded..c750cdc26e84020cfca011f393722e117666e9dc 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -273,6 +273,8 @@ def main(**kwargs):
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
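+        # use the custom collator for validation as well, so eval batches are built the same way as training batches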
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index bbbfe96a88fb084ec7458f2eb4a98fbeb9eb7c21..3eaf62c2e3605556863a56b17e1e5c76f03ecf4f 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -146,16 +146,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
-
                             if is_xpu_available():
                                 batch[key] = batch[key].to('xpu:0')
-                            else:
+                            elif torch.cuda.is_available():
                                 batch[key] = batch[key].to('cuda:0')
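+                            # if neither XPU nor CUDA is available, the batch simply stays on CPU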
                     with autocast():
                         assert(next(model.parameters()).device == batch['input_ids'].device)
-                        #print("batch: ", batch)
-                        pixel_values = batch['pixel_values']
-                        print("pixel_values.shape input",pixel_values.shape)
                         loss = model(**batch).loss
                     loss = loss / gradient_accumulation_steps
                     #print("loss",loss)