diff --git a/finetune_retrieval.py b/finetune_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1282a4d262d42def7b14153d3378bd104b8067c1
--- /dev/null
+++ b/finetune_retrieval.py
@@ -0,0 +1,233 @@
+# imports
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+import argparse
+import numpy as np
+
+import mpu
+from arguments import get_args
+from model.retrieval_model import RetrievalModel
+from training.deepspeed_training import training_main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
+
+def get_masks_and_position_ids(data,
+                               n_pads,
+                               attention_mask=None, args=None):
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    layout = args.layout
+
+    # Attention mask.
+    if attention_mask is None:
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
+        attention_mask.tril_()
+
+        attention_mask[:, :layout[0] - 5, layout[0] - 5:] = 0  # no attention from text to image tokens
+        attention_mask[:, layout[0] - 5:-2, :layout[0] - 5] = 0  # no attention from image to text tokens
+
+        attention_mask[:, -1, layout[0] - 5:-1] = 0  # the text retrieval token attends to text only
+        attention_mask[:, -2, :layout[0] - 5] = 0  # the image retrieval token attends to image only
+
+        for i in range(batch_size):  # no attention to or from padding
+            attention_mask[i, :n_pads[i], :] = 0
+            attention_mask[i, :, :n_pads[i]] = 0
+
+        attention_mask.unsqueeze_(1)
+
+    # Position ids.
+    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
+                               device=data.device)
+    for i in range(batch_size):
+        torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]],
+                     dtype=torch.long, device=data.device)
+    position_ids[:, -1] = 0
+    position_ids[:, -2] = 1
+
+    return attention_mask, position_ids
+
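+# Sample layout (a sketch, assuming the default --layout 64,1088; the exact
+# special tokens that TextCodeTemplate inserts are defined in
+# tokenization.cogview):
+#
+#   [PAD] * n_pad ++ template + text tokens ++ image codes ++ [POS0] ++ [POS1]
+#
+# Indices below layout[0] - 5 hold the padding and the text, indices
+# layout[0] - 5 through -3 hold the image codes, and the last two positions
+# are the retrieval tokens.  The block masks above keep the two modalities
+# from attending to each other; [POS0] (index -2) reads only the image codes
+# and [POS1] (index -1) reads only the text, so their final hidden states act
+# as pooled image / text representations.
+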
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['text', 'n_pads']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['text'].long()
+    n_pads = data_b['n_pads'].long()
+
+    attention_mask = None
+
+    # Get the masks and position ids.
+    attention_mask, position_ids = get_masks_and_position_ids(
+        tokens,
+        n_pads,
+        attention_mask=attention_mask,
+        args=args
+    )
+    # Convert the mask to fp16 if needed.
+    if args.fp16:
+        attention_mask = attention_mask.half()
+
+    return tokens, attention_mask, position_ids, n_pads
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
+        numerator, denominator)
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+def parallel_contrastive_loss(lvec, rvec, args):
+    device = args.device
+    temp = args.retrieval_temp
+    lvec = lvec * np.exp(temp)  # fixed logit scale exp(temp); exp(0) = 1 by default
+
+    rank = mpu.get_data_parallel_rank()
+    world_size = mpu.get_data_parallel_world_size()
+
+    batch_size_per_partition, hidden_size = lvec.shape
+    batch_size = batch_size_per_partition * world_size
+    split_start, split_end = rank * batch_size_per_partition, (rank + 1) * batch_size_per_partition
+    arange_1d = torch.arange(0, batch_size, device=lvec.device)
+
+    # Gather the vectors from all data-parallel ranks (detached, since
+    # all_gather does not backpropagate; the buffers assume fp16 training).
+    broadcast_lvec, broadcast_rvec = lvec.detach(), rvec.detach()
+    parallel_lvec = [torch.zeros(batch_size_per_partition, hidden_size, device=torch.device(device)).half() for _ in range(world_size)]
+    parallel_rvec = [torch.zeros(batch_size_per_partition, hidden_size, device=torch.device(device)).half() for _ in range(world_size)]
+
+    torch.distributed.all_gather(parallel_lvec, broadcast_lvec, group=mpu.get_data_parallel_group())
+    torch.distributed.all_gather(parallel_rvec, broadcast_rvec, group=mpu.get_data_parallel_group())
+
+    parallel_lvec = torch.cat(parallel_lvec, dim=0)
+    parallel_rvec = torch.cat(parallel_rvec, dim=0)
+
+    # Calculate logits.
+    local_local_logits = lvec @ rvec.permute(1, 0)
+    local_dist_logits = lvec @ parallel_rvec.permute(1, 0)
+    dist_local_logits = parallel_lvec @ rvec.permute(1, 0)
+
+    # Gather the local-vs-global logits from all ranks.
+    broadcast_local_dist_logits = local_dist_logits.detach()
+    parallel_logits = [torch.zeros(batch_size_per_partition, batch_size, device=torch.device(device)).half() for _ in range(world_size)]
+    torch.distributed.all_gather(parallel_logits, broadcast_local_dist_logits, group=mpu.get_data_parallel_group())
+    parallel_logits = torch.cat(parallel_logits, dim=0)
+
+    # Fill the local rows and columns back in with graph-connected logits
+    # so that backward produces gradients for the local examples.
+    parallel_logits[split_start:split_end, :] = local_dist_logits
+    parallel_logits[:, split_start:split_end] = dist_local_logits
+    parallel_logits[split_start:split_end, split_start:split_end] = local_local_logits
+
+    # Left-to-right loss (lvec queries over rvec keys):
+    # loss = log(sum(exp(logits))) - predicted_logit, computed on logits
+    # shifted by the row-wise max for numerical stability.  The predicted
+    # (diagonal) logit is taken from the shifted matrix so the shift cancels.
+    left_logits_max = torch.max(parallel_logits, dim=1)[0]
+    left_logits = parallel_logits.sub(left_logits_max.unsqueeze(dim=1))
+    left_exp_logits = left_logits.exp()
+    left_sum_exp_logits = left_exp_logits.sum(dim=1)
+    left_loss = torch.log(left_sum_exp_logits) - left_logits[arange_1d, arange_1d]
+    left_loss = left_loss.sum()
+
+    # Right-to-left loss, computed the same way on the transposed logits.
+    parallel_logits_t = parallel_logits.permute(1, 0)
+    right_logits_max = torch.max(parallel_logits_t, dim=1)[0]
+    right_logits = parallel_logits_t.sub(right_logits_max.unsqueeze(dim=1))
+    right_exp_logits = right_logits.exp()
+    right_sum_exp_logits = right_exp_logits.sum(dim=1)
+    right_loss = torch.log(right_sum_exp_logits) - right_logits[arange_1d, arange_1d]
+    right_loss = right_loss.sum()
+
+    # Sum over the global batch, averaged over the two directions.
+    total_loss = (left_loss + right_loss) / 2
+
+    return total_loss, left_loss, right_loss
+
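+# Shape sketch for parallel_contrastive_loss (illustrative numbers, not from
+# the source): with world_size = 4 and 4 samples per rank, parallel_logits is
+# a 16 x 16 similarity matrix whose entry [i, j] is <lvec_i, rvec_j>, with the
+# matching pairs on the diagonal.  Every rank rebuilds the full matrix from
+# detached all_gather buffers and then overwrites its own row and column
+# stripes with logits still connected to the autograd graph, so each rank's
+# backward pass covers exactly its own examples; the usual data-parallel
+# gradient all-reduce then assembles the gradient of the full symmetric loss.
+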
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, attention_mask, position_ids, n_pads = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+
+    # Forward model.
+    (txt_vecs, img_vecs), *mems = model(tokens, position_ids, attention_mask)
+
+    # L2-normalize both retrieval vectors.
+    txt_vecs = txt_vecs / txt_vecs.pow(2).sum(dim=1).sqrt().unsqueeze(1)
+    img_vecs = img_vecs / img_vecs.pow(2).sum(dim=1).sqrt().unsqueeze(1)
+
+    loss, txt2img_loss, img2txt_loss = parallel_contrastive_loss(txt_vecs, img_vecs, args)
+
+    return loss, {'txt2img_loss': txt2img_loss, 'img2txt_loss': img2txt_loss}
+
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer(args)
+    layout = [64, 1088]  # matches the default --layout
+    def process_fn(row):
+        row = row.astype(np.int64)
+        codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
+
+        text = row[:layout[0]]
+        text = text[text > 0][:layout[0] - 6]  # keep nonzero tokens; 6 slots stay reserved for template and retrieval tokens
+        n_pad = layout[0] - 6 - len(text)
+        merged = TextCodeTemplate(text, codes[0], tokenizer)
+        parts = [
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
+            merged,
+            np.array([tokenizer['[POS0]'], tokenizer['[POS1]']], dtype=np.int64)
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret,
+                'n_pads': n_pad}
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+
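+# Worked example for process_fn (a sketch; the concrete numbers assume
+# layout = [64, 1088]): a row whose text field holds 10 nonzero tokens yields
+# n_pad = 64 - 6 - 10 = 48, and the sample becomes
+#
+#   [PAD] * 48  ++  TextCodeTemplate(text, image_codes)  ++  [POS0] [POS1]
+#
+# Left-padding keeps every sample the same length with the text/image
+# boundary at a fixed offset, which is what the hard-coded indices in
+# get_masks_and_position_ids rely on.
+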
+if __name__ == '__main__':
+    py_parser = argparse.ArgumentParser(add_help=False)
+
+    RetrievalModel.add_model_specific_args(py_parser)
+
+    known, args_list = py_parser.parse_known_args()
+
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    args.layout = [int(x) for x in args.layout.split(',')]
+
+    training_main(args, model_cls=RetrievalModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
\ No newline at end of file
diff --git a/model/mixins.py b/model/mixins.py
index 78befb8b27637deb7bf7c98bdcfde408eb28b483..618e7910d924707a1967008350a3c2821ea3e6d0 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -67,3 +67,40 @@ class AttentionMixin(BaseMixin):
         self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
         self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
         self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
+
+class ParallelLinearMixin(BaseMixin):
+    """A single ColumnParallelLinear projection head (output gathered)."""
+    def __init__(self, input_size, output_size, bias=True,
+                 init_method=torch.nn.init.xavier_normal_, stride=1,
+                 keep_master_weight_for_test=False):
+        super(ParallelLinearMixin, self).__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+
+        self.parallel_linear = ColumnParallelLinear(
+            input_size, output_size, bias=bias, gather_output=True,
+            init_method=init_method, stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test)
+
+    def forward(self, input_):
+        return self.parallel_linear(input_)
+
+class ParallelDoubleLayerLinearMixin(BaseMixin):
+    """Two stacked parallel linears: a ColumnParallelLinear whose partitioned
+    output feeds a RowParallelLinear (the standard Megatron pairing, hence
+    input_is_parallel=True)."""
+    def __init__(self, input_size, hidden_size, output_size, bias=True,
+                 init_method=torch.nn.init.xavier_normal_, stride=1,
+                 keep_master_weight_for_test=False):
+        super(ParallelDoubleLayerLinearMixin, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+
+        self.column_parallel_linear = ColumnParallelLinear(
+            input_size, hidden_size, bias=bias, gather_output=False,
+            init_method=init_method, stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test)
+        self.row_parallel_linear = RowParallelLinear(
+            hidden_size, output_size, bias=bias, input_is_parallel=True,
+            init_method=init_method, stride=stride,
+            keep_master_weight_for_test=keep_master_weight_for_test)
+
+    def forward(self, input_):
+        return self.row_parallel_linear(self.column_parallel_linear(input_))
\ No newline at end of file
diff --git a/model/retrieval_model.py b/model/retrieval_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..96fea46e790cb3b43ea48308cb2b26bbce4a0fdc
--- /dev/null
+++ b/model/retrieval_model.py
@@ -0,0 +1,65 @@
+# -*- encoding: utf-8 -*-
+# imports
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .mixins import PositionEmbeddingMixin, AttentionMixin, ParallelLinearMixin
+
+from mpu.transformer import split_tensor_along_last_dim
+from mpu.utils import sqrt
+from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
+
+
+class RetrievalModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.layout = args.layout
+        self.txt_img_split = args.txt_img_split
+
+        # mixins[0]: two extra position embeddings for [POS0] / [POS1];
+        # mixins[1] / mixins[2]: the text and image retrieval heads.
+        self.mixins.append(PositionEmbeddingMixin(
+            2, args.hidden_size
+        ))
+        self.mixins.extend([
+            ParallelLinearMixin(
+                args.hidden_size, args.retrieval_size),
+            ParallelLinearMixin(
+                args.hidden_size, args.retrieval_size)
+        ])
+
+    def reinit(self):
+        pass
+
+    def position_embedding_forward(self, position_ids, *other_tensors):
+        position_embeddings = torch.cat(
+            (
+                self.transformer.position_embeddings(position_ids[:, :-2]),
+                self.mixins[0].position_embeddings(position_ids[:, -2:])
+            ),
+            dim=-2
+        )
+        return position_embeddings
+
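+    # final_forward returns a (text_vec, image_vec) pair: the final hidden
+    # states at the two retrieval positions ([POS1] at index -1, [POS0] at
+    # index -2), each projected by its own ParallelLinearMixin head into the
+    # shared args.retrieval_size space.  finetune_retrieval.py L2-normalizes
+    # both before the contrastive loss.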
+    def final_forward(self, logits, *other_tensors):
+        txt_logits = logits[:, -1, :]
+        img_logits = logits[:, -2, :]
+        return (self.mixins[1](txt_logits), self.mixins[2](img_logits))
+
+    def disable_untrainable_params(self):
+        pass
+
+    @classmethod
+    def add_model_specific_args(cls, parser):
+        group = parser.add_argument_group('RetrievalModel', 'retrieval model configurations')
+        group.add_argument('--txt-img-split', action='store_true')
+        group.add_argument('--retrieval-temp', type=float, default=0.)
+        group.add_argument('--retrieval-mode', type=str, default='txt2img',
+                           choices=['txt2img', 'img2txt', 'symmetric'])
+        group.add_argument('--retrieval-hidden-size', type=int, default=2048)
+        group.add_argument('--retrieval-size', type=int, default=1024)
+        group.add_argument("--layout", type=str, default='64,1088')
+        return parser
\ No newline at end of file
diff --git a/scripts/finetune_into_retrieval.sh b/scripts/finetune_into_retrieval.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c0a38d2af5091cede59757fe0f47e1103c77a389
--- /dev/null
+++ b/scripts/finetune_into_retrieval.sh
@@ -0,0 +1,59 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath "$0")
+script_dir=$(dirname "$script_path")
+main_dir=$(dirname "$script_dir")
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+# HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+# full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_1/ali/ali.bin.part_0.cogdata"
+
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name finetune-retrieval-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --model-parallel-size ${MP_SIZE} \
+       --mode finetune \
+       --batch-size 4 \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 200000 \
+       --resume-dataloader \
+       --train-data ${small_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --save-interval 2000 \
+       --eval-interval 1000 \
+       --save $main_dir/checkpoints \
+       --load /workspace/dm/SwissArmyTransformer/pretrained/cogview/cogview-base
+"
+       # --load pretrained/cogview/cogview-base
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_retrieval.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 90190a6f95eb775808cc058b9f912d3a9b6555d2..2d74dbf8fa7ffd278d109eba03853cd3ad895f6a 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -335,7 +335,9 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
 
     # Check nan or inf in forward, preventing it from interfering loss scaler,
     # and all reduce metrics by the way
-    loss_checker = lm_loss.detach()
+    # clone() before detach(): detach() alone shares storage with lm_loss, so
+    # later in-place reductions on loss_checker would corrupt the loss used
+    # for backward.
+    loss_checker = lm_loss.clone().detach()
     for name in metrics:
         metrics[name] = metrics[name].detach()
         torch.distributed.all_reduce(metrics[name].data)