Commit a30e32ad authored by minkowski0125

set retrieval model

parent 8584bf91
# here put the import lib
import os
import sys
import math
import random
import torch
import torch.nn.functional as F
import argparse
import numpy as np
import mpu
from arguments import get_args
from model.retrieval_model import RetrievalModel
from training.deepspeed_training import training_main
from data_utils import BinaryDataset
from tokenization import get_tokenizer
from tokenization.cogview import TextCodeTemplate
def get_masks_and_position_ids(data,
                               n_pads,
                               attention_mask=None, args=None):
    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()
    layout = args.layout

    # Attention mask: block-causal, with txt and img kept mutually invisible
    # and the two trailing retrieval tokens each restricted to one modality.
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
        attention_mask.tril_()
        attention_mask[:, :args.layout[0] - 5, args.layout[0] - 5:] = 0  # no attention from txt to img
        attention_mask[:, args.layout[0] - 5:-2, :args.layout[0] - 5] = 0  # no attention from img to txt
        attention_mask[:, -1, args.layout[0] - 5:-1] = 0  # txt retrieval token attends to txt only
        attention_mask[:, -2, :args.layout[0] - 5] = 0  # img retrieval token attends to img only
        for i in range(batch_size):  # no attention to or from padding
            attention_mask[i, :n_pads[i], :] = 0
            attention_mask[i, :, :n_pads[i]] = 0
        attention_mask.unsqueeze_(1)

    # Position ids: ordinary positions for the padded txt+img sequence; the
    # last two slots are overwritten to index the extra retrieval table.
    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
                               device=data.device)
    for i in range(batch_size):
        torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]],
                     dtype=torch.long, device=data.device)
    position_ids[:, -1] = 0  # txt retrieval token: slot 0 of the extra position table
    position_ids[:, -2] = 1  # img retrieval token: slot 1 of the extra position table
    return attention_mask, position_ids
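
# Sanity-check sketch (hypothetical helper, not called by the training path):
# build the mask for a toy layout and inspect the block structure. The sequence
# length equals layout[1]; the img and txt retrieval tokens sit in the last
# two slots.
def _demo_mask_layout():
    toy_args = argparse.Namespace(layout=[8, 12])  # hypothetical tiny layout
    data = torch.zeros(2, 12, dtype=torch.long)    # batch of 2, seq length = layout[1]
    n_pads = torch.tensor([0, 2])                  # second sample has 2 leading pads
    mask, pos = get_masks_and_position_ids(data, n_pads, args=toy_args)
    # Expect: causal blocks inside txt and inside img, txt<->img blocked,
    # row -1 (txt retrieval) seeing only txt, row -2 (img retrieval) only img,
    # and the padded rows/columns of the second sample fully zeroed.
    print(mask.shape)  # torch.Size([2, 1, 12, 12])
    print(mask[1, 0].int())
    print(pos)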
def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['text', 'n_pads']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    n_pads = data_b['n_pads'].long()
    attention_mask = None

    # Get the masks and position ids.
    attention_mask, position_ids = get_masks_and_position_ids(
        tokens,
        n_pads,
        attention_mask=attention_mask,
        args=args
    )

    # Convert to fp16 if needed.
    if args.fp16:
        attention_mask = attention_mask.half()
    return tokens, attention_mask, position_ids, n_pads
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
def parallel_contrastive_loss(lvec, rvec, args):
    device = args.device
    temp = args.retrieval_temp
    lvec = lvec * np.exp(temp)  # scaling one side scales every logit by exp(temp)
    rank = mpu.get_data_parallel_rank()
    world_size = mpu.get_data_parallel_world_size()
    batch_size_per_partition, hidden_size = lvec.shape
    batch_size = batch_size_per_partition * world_size
    split_start, split_end = rank * batch_size_per_partition, (rank + 1) * batch_size_per_partition
    arange_1d = torch.arange(0, batch_size, device=lvec.device)

    # Gather vectors from all data-parallel ranks (detached: all_gather does
    # not propagate gradients).
    broadcast_lvec, broadcast_rvec = lvec.detach(), rvec.detach()
    parallel_lvec = [torch.zeros(batch_size_per_partition, hidden_size, device=torch.device(device)).half()
                     for _ in range(world_size)]
    parallel_rvec = [torch.zeros(batch_size_per_partition, hidden_size, device=torch.device(device)).half()
                     for _ in range(world_size)]
    torch.distributed.all_gather(parallel_lvec, broadcast_lvec, group=mpu.get_data_parallel_group())
    torch.distributed.all_gather(parallel_rvec, broadcast_rvec, group=mpu.get_data_parallel_group())
    parallel_lvec = torch.cat(parallel_lvec, dim=0)
    parallel_rvec = torch.cat(parallel_rvec, dim=0)

    # Calculate logits.
    local_local_logits = lvec @ rvec.permute(1, 0)
    local_dist_logits = lvec @ parallel_rvec.permute(1, 0)
    dist_local_logits = parallel_lvec @ rvec.permute(1, 0)

    # Gather the full batch_size x batch_size logit matrix.
    broadcast_local_dist_logits = local_dist_logits.detach()
    parallel_logits = [torch.zeros(batch_size_per_partition, batch_size, device=torch.device(device)).half()
                       for _ in range(world_size)]
    torch.distributed.all_gather(parallel_logits, broadcast_local_dist_logits, group=mpu.get_data_parallel_group())
    parallel_logits = torch.cat(parallel_logits, dim=0)

    # Fill in the locally computed (non-detached) logits so gradients can flow
    # back to this rank's vectors.
    parallel_logits[split_start:split_end, :] = local_dist_logits
    parallel_logits[:, split_start:split_end] = dist_local_logits
    parallel_logits[split_start:split_end, split_start:split_end] = local_local_logits
    predicted_logits = parallel_logits[arange_1d, arange_1d]  # matched pairs on the diagonal

    # Left-to-right (txt2img) loss: log(sum(exp(logits))) - predicted_logit,
    # with the row max subtracted from both terms for numerical stability.
    left_logits_max = torch.max(parallel_logits, dim=1)[0]
    left_logits = parallel_logits.sub(left_logits_max.unsqueeze(dim=1))
    left_exp_logits = left_logits.exp()
    left_sum_exp_logits = left_exp_logits.sum(dim=1)
    left_loss = torch.log(left_sum_exp_logits) - (predicted_logits - left_logits_max)
    left_loss = left_loss.sum()

    # Right-to-left (img2txt) loss on the transposed logits.
    parallel_logits_t = parallel_logits.permute(1, 0)
    right_logits_max = torch.max(parallel_logits_t, dim=1)[0]
    right_logits = parallel_logits_t.sub(right_logits_max.unsqueeze(dim=1))
    right_exp_logits = right_logits.exp()
    right_sum_exp_logits = right_exp_logits.sum(dim=1)
    right_loss = torch.log(right_sum_exp_logits) - (predicted_logits - right_logits_max)
    right_loss = right_loss.sum()

    total_loss = (left_loss + right_loss) / 2
    return total_loss, left_loss, right_loss
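
# Reference sketch (hypothetical, single process): for world_size == 1 the
# parallel loss above reduces to a plain symmetric InfoNCE over the in-batch
# logit matrix, summed (not averaged) over the batch.
def single_process_contrastive_loss(lvec, rvec, temp=0.0):
    logits = (lvec * float(np.exp(temp))) @ rvec.t()  # [batch, batch]
    labels = torch.arange(logits.size(0), device=logits.device)
    left = F.cross_entropy(logits, labels, reduction='sum')       # txt2img
    right = F.cross_entropy(logits.t(), labels, reduction='sum')  # img2txt
    return (left + right) / 2, left, right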
def forward_step(data_iterator, model, args, timers):
    """Forward step."""
    # Get the batch.
    timers('batch generator').start()
    tokens, attention_mask, position_ids, n_pads = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
    (txt_vecs, img_vecs), *mems = model(tokens, position_ids, attention_mask)

    # L2-normalize both sides (equivalent to F.normalize(..., dim=1)).
    txt_vecs = txt_vecs / txt_vecs.pow(2).sum(dim=1).sqrt().unsqueeze(1)
    img_vecs = img_vecs / img_vecs.pow(2).sum(dim=1).sqrt().unsqueeze(1)

    loss, txt2img_loss, img2txt_loss = parallel_contrastive_loss(txt_vecs, img_vecs, args)
    return loss, {'txt2img_loss': txt2img_loss, 'img2txt_loss': img2txt_loss}
def create_dataset_function(path, args):
    tokenizer = get_tokenizer(args)
    # layout = args.layout
    layout = [64, 1088]  # [txt length, txt length + img code length]

    def process_fn(row):
        row = row.astype(np.int64)
        codes = [row[layout[i - 1]:layout[i]] for i in range(1, len(layout))]

        text = row[:layout[0]]
        text = text[text > 0][:layout[0] - 6]  # drop padding ids, cap the txt budget
        n_pad = layout[0] - 6 - len(text)
        merged = TextCodeTemplate(text, codes[0], tokenizer)
        parts = [
            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
            merged,
            np.array([tokenizer['[POS0]'], tokenizer['[POS1]']], dtype=np.int64)
        ]
        ret = np.concatenate(parts, axis=0)
        return {'text': ret,
                'n_pads': n_pad}
    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
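
# Layout sketch of one processed sample (comment only; the exact special-token
# count inside TextCodeTemplate is defined in tokenization.cogview):
#
#   [PAD] * n_pad | TextCodeTemplate(text, image code) | [POS0] [POS1]
#
# The total length comes out to layout[1], so [POS0] (the img retrieval token)
# sits at index -2 and [POS1] (the txt retrieval token) at index -1, matching
# get_masks_and_position_ids above and RetrievalModel.final_forward.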
if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    RetrievalModel.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    args.layout = [int(x) for x in args.layout.split(',')]
    training_main(args, model_cls=RetrievalModel, forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function)
\ No newline at end of file
@@ -67,3 +67,40 @@ class AttentionMixin(BaseMixin):
            self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
            self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
            self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)

class ParallelLinearMixin(BaseMixin):
    """Single model-parallel linear head; gathers the output on every rank."""
    def __init__(self, input_size, output_size, bias=True,
                 init_method=torch.nn.init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False):
        super(ParallelLinearMixin, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.parallel_linear = ColumnParallelLinear(
            input_size, output_size, bias=bias, gather_output=True,
            init_method=init_method, stride=stride,
            keep_master_weight_for_test=keep_master_weight_for_test)

    def forward(self, input_):
        return self.parallel_linear(input_)

class ParallelDoubleLayerLinearMixin(BaseMixin):
    """Two-layer model-parallel head: column-parallel then row-parallel."""
    def __init__(self, input_size, hidden_size, output_size, bias=True,
                 init_method=torch.nn.init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False):
        super(ParallelDoubleLayerLinearMixin, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.column_parallel_linear = ColumnParallelLinear(
            input_size, hidden_size, bias=bias, gather_output=False,
            init_method=init_method, stride=stride,
            keep_master_weight_for_test=keep_master_weight_for_test)
        # gather_output=False above leaves the hidden activations partitioned,
        # so the row-parallel layer must take its input as already parallel.
        self.row_parallel_linear = RowParallelLinear(
            hidden_size, output_size, bias=bias, input_is_parallel=True,
            init_method=init_method, stride=stride,
            keep_master_weight_for_test=keep_master_weight_for_test)

    def forward(self, input_):
        return self.row_parallel_linear(self.column_parallel_linear(input_))
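
# Usage sketch (hypothetical): RetrievalModel below attaches two
# ParallelLinearMixin heads and applies them in final_forward, e.g.
#
#   head = ParallelLinearMixin(hidden_size, retrieval_size)
#   vec = head(hidden_states[:, -1, :])  # [batch, retrieval_size]
#
# ParallelDoubleLayerLinearMixin is the two-layer variant; keeping the hidden
# activations partitioned between the two matmuls avoids an extra all-gather.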
\ No newline at end of file
# -*- encoding: utf-8 -*-
# here put the import lib
import os
import sys
import math
import random
import torch
import torch.nn.functional as F
from .base_model import BaseModel
from .mixins import PositionEmbeddingMixin, AttentionMixin, ParallelLinearMixin
from mpu.transformer import split_tensor_along_last_dim
from mpu.utils import sqrt
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
class RetrievalModel(BaseModel):
    def __init__(self, args, transformer=None):
        super().__init__(args, transformer=transformer)
        self.layout = args.layout
        self.txt_img_split = args.txt_img_split
        # Two extra position embeddings for the trailing retrieval tokens.
        self.mixins.append(PositionEmbeddingMixin(
            2, args.hidden_size
        ))
        # Separate projection heads for the txt and img retrieval vectors.
        self.mixins.extend([
            ParallelLinearMixin(
                args.hidden_size, args.retrieval_size),
            ParallelLinearMixin(
                args.hidden_size, args.retrieval_size)
        ])

    def reinit(self):
        pass

    def position_embedding_forward(self, position_ids, *other_tensors):
        # Pretrained table for ordinary positions; the freshly initialized
        # 2-entry table (mixins[0]) for the last two retrieval slots.
        position_embeddings = torch.cat(
            (
                self.transformer.position_embeddings(position_ids[:, :-2]),
                self.mixins[0].position_embeddings(position_ids[:, -2:])
            ),
            dim=-2
        )
        return position_embeddings

    def final_forward(self, logits, *other_tensors):
        txt_logits = logits[:, -1, :]  # hidden state at the txt retrieval token
        img_logits = logits[:, -2, :]  # hidden state at the img retrieval token
        return (self.mixins[1](txt_logits), self.mixins[2](img_logits))

    def disable_untrainable_params(self):
        pass

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('RetrievalModel', 'retrieval model configurations')
        group.add_argument('--txt-img-split', action='store_true')
        group.add_argument('--retrieval-temp', type=int, default=0)
        group.add_argument('--retrieval-mode', type=str, default='txt2img',
                           choices=['txt2img', 'img2txt', 'symmetric'])
        group.add_argument('--retrieval-hidden-size', type=int, default=2048)
        group.add_argument('--retrieval-size', type=int, default=1024)
        group.add_argument("--layout", type=str, default='64,1088')
        return parser
\ No newline at end of file
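
# Inference-time usage sketch (hypothetical, not part of this commit): score a
# candidate txt/img pair from the two vectors returned by final_forward,
# mirroring the L2 normalization applied in forward_step.
def retrieval_score(txt_vec, img_vec):
    txt_vec = txt_vec / txt_vec.norm(dim=-1, keepdim=True)
    img_vec = img_vec / img_vec.norm(dim=-1, keepdim=True)
    return (txt_vec * img_vec).sum(dim=-1)  # cosine similarity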
#! /bin/bash
# Change for multinode config
NUM_WORKERS=1
NUM_GPUS_PER_WORKER=8
MP_SIZE=1
script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")
main_dir=$(dirname "$script_dir")
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
HOST_FILE_PATH="hostfile_single"
# full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_1/ali/ali.bin.part_0.cogdata"
config_json="$script_dir/ds_config_zero.json"
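
# ds_config_zero.json is not included in this commit; a minimal sketch of what
# such a ZeRO config plausibly contains (assumed values, adjust as needed):
#   {
#     "train_micro_batch_size_per_gpu": 4,
#     "gradient_accumulation_steps": 1,
#     "fp16": { "enabled": true },
#     "zero_optimization": { "stage": 1 }
#   }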
gpt_options=" \
--experiment-name finetune-retrieval-test \
--tokenizer-type cogview \
--img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
--model-parallel-size ${MP_SIZE} \
--mode finetune \
--batch-size 4 \
--num-layers 48 \
--hidden-size 2560 \
--num-attention-heads 40 \
--train-iters 200000 \
--resume-dataloader \
--train-data ${small_data} \
--split 949,50,1 \
--distributed-backend nccl \
--lr-decay-style cosine \
--warmup .1 \
--checkpoint-activations \
--max-sequence-length 1089 \
--sandwich-ln \
--fp16 \
--save-interval 2000 \
--eval-interval 1000 \
--save $main_dir/checkpoints \
--load /workspace/dm/SwissArmyTransformer/pretrained/cogview/cogview-base
"
# --load pretrained/cogview/cogview-base
gpt_options="${gpt_options}
--deepspeed \
--deepspeed_config ${config_json} \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_retrieval.py $@ ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x
@@ -335,7 +335,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     # Check nan or inf in forward, preventing it from interfering loss scaler,
     # and all reduce metrics by the way
-    loss_checker = lm_loss.detach()
+    loss_checker = lm_loss.clone().detach()
     for name in metrics:
         metrics[name] = metrics[name].detach()
         torch.distributed.all_reduce(metrics[name].data)
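
# Rationale sketch (assumed; the commit does not state it): detach() returns a
# view sharing storage with lm_loss, so a later in-place op on loss_checker
# (e.g. an all_reduce for the nan/inf check) would silently overwrite lm_loss;
# clone() first gives loss_checker its own storage.
#
#   x = module(inp).sum()      # pretend this is lm_loss
#   bad = x.detach()
#   bad.add_(1.0)              # also changes the value x holds
#   good = x.clone().detach()
#   good.add_(1.0)             # x is unaffected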