Commit 79dbe05a authored by Kai Wu

batch fine-tuning lmm working

parent ce299b34
@@ -23,7 +23,7 @@ def tokenize_dialogs(dialogs, images, processor):
text_prompt = processor.apply_chat_template(dialogs)
#print("text_prompt",text_prompt)
batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
batch["labels"] = copy.copy(batch["input_ids"])
label_list = []
for i in range(len(batch["input_ids"])):
dialog_tokens = batch["input_ids"][i].tolist()
labels = copy.copy(dialog_tokens)
@@ -42,14 +42,62 @@ def tokenize_dialogs(dialogs, images, processor):
# Lastly, mask every assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which is tokenized to [128006, 78191, 128007]
assistant_header_seq = [128006, 78191, 128007]
labels = replace_target(assistant_header_seq,labels)
batch["labels"][i] = torch.tensor(labels)
label_list.append(labels)
batch["labels"] = torch.tensor(label_list)
tokenizer_length = len(processor.tokenizer)
return batch
def tokenize_dialog(dialog, images, processor):
# For Llama 3 family models (vocab size above 128000), use the chat template to generate the tokens
text_prompt = processor.apply_chat_template(dialog)
#print("text_prompt",text_prompt)
batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
labels = copy.copy(batch["input_ids"].tolist()[0])
eot_indices = [i for i,n in enumerate(labels) if n == 128009]
last_idx = 0
# system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
# user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
for n, idx in enumerate(eot_indices):
current_seq = labels[last_idx:idx+1]
if check_header(prompt_header_seqs,current_seq):
# found prompt header, indicating that this seq should be masked
labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
else:
last_idx = idx+1
# Lastly, mask every assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which is tokenized to [128006, 78191, 128007]
assistant_header_seq = [128006, 78191, 128007]
labels = replace_target(assistant_header_seq,labels)
#print("labels",labels)
# print("pixel_values .shape",batch["pixel_values"].shape)
# print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
batch["labels"] = torch.tensor(labels)
# exit()
# combined_tokens = {
# # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
# # "labels": list(itertools.chain(*(t for t in labels_tokens))),
# "input_ids": dialog_tokens,
# "labels": labels,
# "attention_mask": [1]*len(dialog_tokens),
# "pixel_values": batch["pixel_values"],
# "aspect_ratio_ids": batch["aspect_ratio_ids"],
# "aspect_ratio_mask": batch["aspect_ratio_mask"],
# "cross_attention_mask": batch["cross_attention_mask"]
# }
# input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
# labels = list(itertools.chain(*(t for t in labels_tokens))),
# attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
# pixel_values = batch["pixel_values"],
# image_sizes = batch["image_sizes"]
# print("combined_tokens",combined_tokens[image_sizes])
return batch
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
# load_dataset returns a DatasetDict containing all available splits of the dataset
dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
dataset = dataset_dict[split]
dataset = dataset.select(range(100))
dataset = dataset.select(range(500))
return dataset
class VQADataCollator:
@@ -74,5 +122,20 @@ class VQADataCollator:
dialogs.append(dialog)
images.append(image)
return tokenize_dialogs(dialogs,images, self.processor)
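# Note: __callworking__ below appears to be an earlier per-sample variant kept under a non-special name; the DataLoader only invokes __call__ above.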
def __callworking__(self, samples):
for sample in samples:
image,sample_text = sample["images"],sample["messages"]
dialog = []
for line in sample_text:
content = []
messages = line["content"]
role = line["role"]
for message in messages:
if message["type"] == "image":
content.append({"type": "image"})
elif message["type"] == "text":
content.append({"type": "text", "text": message["text"].strip()})
dialog.append({"role": role,"content":content})
return tokenize_dialog(dialog,image, self.processor)
def get_data_collator(processor):
return VQADataCollator(processor)
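For context, a minimal usage sketch of how this collator would be wired into a PyTorch `DataLoader`; the checkpoint id and batch size are illustrative assumptions, not taken from this commit:

```python
from torch.utils.data import DataLoader
from transformers import AutoProcessor

# Assumed checkpoint id for illustration; any Llama 3.2 Vision processor would work here.
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
train_ds = get_custom_dataset(None, processor, "train")   # raw rows with "images" and "messages"
collator = get_data_collator(processor)                   # tokenizes and masks labels at collate time
loader = DataLoader(train_ds, batch_size=2, collate_fn=collator)
batch = next(iter(loader))  # dict with input_ids, attention_mask, pixel_values, labels, ...
```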
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import copy
from datasets import load_dataset
import itertools
import torch
# Check whether the system prompt or user prompt header token sequence appears in the current token list
def check_header(targets,seq):
for i in range(len(seq)-3):
if seq[i:i+3] in targets:
return True
return False
def replace_target(target,seq):
for i in range(len(seq)-3):
if seq[i:i+3] == target:
seq[i],seq[i+1],seq[i+2] = -100,-100,-100
return seq
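To make the two helpers above concrete, here is a toy walk-through (not part of the commit) using the Llama 3 header token ids referenced throughout this file:

```python
# A short sequence: user header, one user token, <|eot_id|>, assistant header, one answer token, <|eot_id|>.
seq = [128006, 882, 128007, 1234, 128009, 128006, 78191, 128007, 5678, 128009]

prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
print(check_header(prompt_header_seqs, seq))         # True: the user header is present
print(replace_target([128006, 78191, 128007], seq))  # assistant header replaced by -100s:
# [128006, 882, 128007, 1234, 128009, -100, -100, -100, 5678, 128009]
```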
def tokenize_dialog(dialog, images, processor):
# For Llama 3 family models (vocab size above 128000), use the chat template to generate the tokens
text_prompt = processor.apply_chat_template(dialog)
#print("text_prompt",text_prompt)
batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
labels = copy.copy(batch["input_ids"].tolist()[0])
eot_indices = [i for i,n in enumerate(labels) if n == 128009]
last_idx = 0
# system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
# user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
for n, idx in enumerate(eot_indices):
current_seq = labels[last_idx:idx+1]
if check_header(prompt_header_seqs,current_seq):
# found prompt header, indicating that this seq should be masked
labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
else:
last_idx = idx+1
# Lastly, mask every assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which is tokenized to [128006, 78191, 128007]
assistant_header_seq = [128006, 78191, 128007]
labels = replace_target(assistant_header_seq,labels)
#print("labels",labels)
# print("pixel_values .shape",batch["pixel_values"].shape)
# print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
batch["labels"] = torch.tensor(labels)
#pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
# pixel_values .shape torch.Size([1, 4, 3, 560, 560])
print("pixel_values .shape",batch["pixel_values"].shape)
# exit()
# combined_tokens = {
# # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
# # "labels": list(itertools.chain(*(t for t in labels_tokens))),
# "input_ids": dialog_tokens,
# "labels": labels,
# "attention_mask": [1]*len(dialog_tokens),
# "pixel_values": batch["pixel_values"],
# "aspect_ratio_ids": batch["aspect_ratio_ids"],
# "aspect_ratio_mask": batch["aspect_ratio_mask"],
# "cross_attention_mask": batch["cross_attention_mask"]
# }
# input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
# labels = list(itertools.chain(*(t for t in labels_tokens))),
# attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
# pixel_values = batch["pixel_values"],
# image_sizes = batch["image_sizes"]
# print("combined_tokens",combined_tokens[image_sizes])
return batch
def image_tokenize(sample, processor):
processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
images,sample_text = sample["images"],sample["messages"]
dialog = []
for line in sample_text:
content = []
messages = line["content"]
role = line["role"]
for message in messages:
if message["type"] == "image":
content.append({"type": "image"})
elif message["type"] == "text":
content.append({"type": "text", "text": message["text"].strip()})
dialog.append({"role": role,"content":content})
return tokenize_dialog(dialog,images, processor)
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
# load_dataset returns a DatasetDict containing all available splits of the dataset
dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
dataset = dataset_dict[split]
dataset = dataset.select(range(100))
tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
return tokenized_datasets
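Unlike the collate-time path above, this variant pre-tokenizes every row via `Dataset.map`, so each mapped item already carries model-ready fields. A brief sketch of what a row might look like; the checkpoint id and the exact key set are assumptions:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")  # assumed id
tokenized = get_custom_dataset(None, processor, "train")
print(tokenized[0].keys())
# e.g. dict_keys(['input_ids', 'attention_mask', 'pixel_values',
#                 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask', 'labels'])
```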
@@ -20,7 +20,7 @@ an example of the dataset looks like this:
Full-finetune
```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 1 --batch_size_training 1 --model_name llava-hf/llama3-llava-next-8b-hf --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --use-wandb --run_validation True
torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --context_length 8192 --num_epochs 3 --batch_size_training 2 --model_name nltpt/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder /home/kaiwu/work/fb_connect/finetune_11bmodel --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/vqa_dataset.py" --run_validation True --batching_strategy padding --use-wandb
```
LoRA:
......
@@ -273,6 +273,8 @@ def main(**kwargs):
dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
if custom_data_collator:
val_dl_kwargs["collate_fn"] = custom_data_collator
eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
......
@@ -146,16 +146,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
else:
batch[key] = batch[key].to(local_rank)
else:
if is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
elif torch.cuda.is_available():
batch[key] = batch[key].to('cuda:0')
with autocast():
assert(next(model.parameters()).device == batch['input_ids'].device)
#print("batch: ", batch)
pixel_values = batch['pixel_values']
print("pixel_values.shape input",pixel_values.shape)
loss = model(**batch).loss
loss = loss / gradient_accumulation_steps
#print("loss",loss)
......