diff --git a/inference_object.py b/inference_object.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..576635af756dbdf36d4a99ecd6edcda5343fda3a 100644
--- a/inference_object.py
+++ b/inference_object.py
@@ -0,0 +1,110 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_object.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+
+from arguments import get_args
+from model.cached_autoregressive_model import CachedAutoregressiveModel
+from model.cached_object_model import CachedObjectModel
+from model.ObjectModel import ObjectModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from tokenization import get_tokenizer
+from generation.sampling_strategies import BaseStrategy
+from generation.autoregressive_sampling import filling_sequence
+from generation.utils import timed_name, save_multiple_images, generate_continually
+
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model
+    model = CachedObjectModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    model.eval()
+
+    # define function for each query
+    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, topk=args.top_k)
+
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split('\t')
+        print('raw text: ', raw_text)
+        fields = raw_text.split(' ')
+        objects = fields[1:]  # groups of 5 per object: x, y, w, h, name
+        seq = tokenizer.parse_query(f"[ROI1] {fields[0]}", img_size=args.img_size)
+        for i in range(len(objects) // 5):
+            seq.append(tokenizer['[POS0]'])
+            seq.extend(int(x) + args.old_token_num for x in objects[i*5:i*5+4])
+            seq.extend(tokenizer.EncodeAsIds(objects[i*5 + 4]))
+        seq.extend(tokenizer.parse_query('[BASE] [BOI1] [MASK]*1024', img_size=args.img_size))
+
+        if len(seq) > 1271:
+            raise ValueError('text too long.')
+        # calibrate text length
+        txt_len = seq.index(tokenizer['[BASE]'])
+        log_attention_weights = torch.zeros(len(seq), len(seq),
+                                            device=args.device, dtype=torch.half if args.fp16 else torch.float32)
+        log_attention_weights[txt_len + 2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4  # TODO args
+        # generation
+        seq = torch.cuda.LongTensor(seq, device=args.device)
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = []
+        for tim in range(max(args.batch_size // mbz, 1)):
+            output_list.append(
+                filling_sequence(model, seq.clone(),
+                                 batch_size=min(args.batch_size, mbz),
+                                 strategy=strategy,
+                                 log_attention_weights=log_attention_weights
+                                 )
+            )
+        output_tokens = torch.cat(output_list, dim=0)
+        # decoding
+        imgs, txts = [], []
+        for seq in output_tokens:
+            txt_len = seq.tolist().index(tokenizer['[BASE]'])
+            seq = seq[txt_len:]
+            _, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
+            imgs.append(decoded_imgs[-1])  # only the last image (target)
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id)
+            os.makedirs(full_path, exist_ok=True)
+            save_multiple_images(imgs, full_path, False)
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.jpg', args.output_path)
+            save_multiple_images(imgs, full_path, True)
+
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+
+if __name__ == "__main__":
"__main__": + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--full-query', action='store_true') + py_parser.add_argument('--img-size', type=int, default=256) + ObjectModel.add_model_specific_args(py_parser) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + with torch.no_grad(): + main(args) \ No newline at end of file diff --git a/model/ObjectModel.py b/model/ObjectModel.py index c236f717da63aa2431c1b7b86e9c19d9355608f2..1000fe1f7594cc44c9cd65e9387e0f03d9b18c37 100644 --- a/model/ObjectModel.py +++ b/model/ObjectModel.py @@ -3,37 +3,55 @@ import torch.nn.functional as F from .base_model import BaseModel -from .mixins import PositionEmbeddingMixin, AttentionMixin +from .mixins import PositionEmbeddingMixin, AttentionMixin, WordEmebeddingMixin from mpu.transformer import split_tensor_along_last_dim from mpu.local_attention_function import f_similar, f_weighting from mpu.utils import sqrt from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from mpu.mappings import gather_from_model_parallel_region, copy_to_model_parallel_region class ObjectModel(BaseModel): def __init__(self, args, transformer=None): super().__init__(args, transformer=transformer) additional_seqlen = args.new_sequence_length - args.max_sequence_length self.mixins.append(PositionEmbeddingMixin( - additional_seqlen, args.hidden_size + additional_seqlen, args.hidden_size, + reinit_slice=slice(-180, None) + )) + self.mixins.append(WordEmebeddingMixin( + args.old_token_num, args.additional_token_num, args.hidden_size )) self.layout = args.layout def position_embedding_forward(self, position_ids, *other_tensors): - position = position_ids[..., :self.layout[1]] - position_plus = position_ids[..., self.layout[1]:] + # breakpoint() + position_text = position_ids[..., :self.layout[0]] + position_object = position_ids[..., self.layout[0]:self.layout[1]] + position_image = position_ids[..., self.layout[1]:] position_embeddings = torch.cat( ( - self.transformer.position_embeddings(position), - self.mixins[0].position_embeddings(position_plus) + self.transformer.position_embeddings(position_text), + self.mixins[0].position_embeddings(position_object), + self.transformer.position_embeddings(position_image) ), dim=-2 ) return position_embeddings + def word_embedding_forward(self, input_ids, *other_tensors): + return self.mixins[1].word_embeddings(input_ids) + + def final_forward(self, logits, *other_tensors): + logits = copy_to_model_parallel_region(logits) + logits = F.linear(logits, self.mixins[1].word_embeddings.weight) + return logits + @classmethod def add_model_specific_args(cls, parser): group = parser.add_argument_group('ObjectModel', 'Object model configurations') - group.add_argument("--layout", type=str, default='64,1088') - group.add_argument("--new-sequence-length", type=int, default=5185) + group.add_argument("--layout", type=str, default='64,246,1270') + group.add_argument("--old-token-num", type=int, default=58219) + group.add_argument("--additional-token-num", type=int, default=257) + group.add_argument("--new-sequence-length", type=int, default=1271) #1089 + 180 + 2 return parser \ No newline at end of file diff --git a/model/cached_object_model.py b/model/cached_object_model.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..6ef67d76ef63fbc905bc9b7155846066f1b9be99 100644 --- a/model/cached_object_model.py +++ b/model/cached_object_model.py @@ -0,0 +1,50 @@ +# -*- 
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from .base_model import BaseModel
+from .ObjectModel import ObjectModel
+from mpu.transformer import standard_attention, split_tensor_along_last_dim
+
+
+class CachedObjectModel(ObjectModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.log_attention_weights = None
+
+    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+        attn_module = self.transformer.layers[layer_id].attention
+        mem = other_tensors[layer_id] if len(other_tensors) > 0 else None
+
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        (mixed_query_layer,
+         mixed_key_layer,
+         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+        if mem is not None:  # the first time, mem is None
+            b = mixed_key_layer.shape[0]  # might change batch_size
+            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
+            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
+            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+
+        # same as training
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None,
+                                           log_attention_weights=self.log_attention_weights)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        output = attn_module.dense(context_layer)
+
+        # new mem this layer
+        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
+
+        return output, new_mem
diff --git a/model/mixins.py b/model/mixins.py
index b8063fb04001dea35a281af72683ad2b41073d64..b5e12dcce595d17d8cdfee591b8aaab564ac717f 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -69,9 +69,16 @@ class AttentionMixin(BaseMixin):
         self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
 
 class WordEmebeddingMixin(BaseMixin):
-    def __init__(self, additional_token_num, hidden_size):
+    def __init__(self, old_token_num, additional_token_num, hidden_size,
+                 init_method_std=0.02, reinit_slice=slice(-1024, None)
+                 ):
         super(WordEmebeddingMixin, self).__init__()
-        self.word_embeddings = torch.nn.Embedding(additional_token_num, hidden_size)
+        self.reinit_slice = reinit_slice
+        self.word_embeddings = torch.nn.Embedding(old_token_num + additional_token_num, hidden_size)
         torch.nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=init_method_std)
+
     def reinit(self, transformer, *pre_mixins):
-        old_weights = transformer.
\ No newline at end of file
+        old_weights = transformer.word_embedding.weight.data
+        old_len, hidden_size = old_weights.shape
+        assert hidden_size == self.word_embeddings.weight.shape[-1]
+        self.word_embeddings.weight[:old_len].data.view(-1, old_len, hidden_size).copy_(old_weights)
\ No newline at end of file
diff --git a/pretrain_objectModel.py b/pretrain_objectModel.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0428db3b72aa3cfa67e9a8f1b7d6771de5d8ea7a 100644
--- a/pretrain_objectModel.py
+++ b/pretrain_objectModel.py
@@ -0,0 +1,207 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   pretrain_objectModel.py
+@Time    :   2021/10/06 00:58:32
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+import numpy as np
+
+import mpu
+from arguments import get_args
+from model.ObjectModel import ObjectModel
+from training.deepspeed_training import training_main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
+
+
+def get_names():  # Chinese names of the COCO categories; "空" ("empty") marks unused category ids
+    return ["背景",
+            "人","自行车","汽车","摩托车","飞机","公交车","火车","卡车","船","红绿灯",
+            "消防栓","空","停车牌","停车收费表","长椅","鸟","猫","狗","马","羊",
+            "牛","大象","熊","斑马","长颈鹿","空","背包","伞","空","空",
+            "手提包","领带","手提箱","飞盘","滑雪板","滑雪板","球","风筝","棒球棒","棒球手套",
+            "滑板","冲浪板","网球拍","瓶子","空","酒杯","杯子","叉子","刀子","勺子",
+            "碗","香蕉","苹果","三明治","橘子","花椰菜","胡萝卜","热狗","比萨饼","甜甜圈",
+            "蛋糕","椅子","沙发","盆栽植物","床","空","餐桌","空","空","厕所",
+            "空","电视","笔记本","鼠标","遥控器","键盘","手机","微波炉","烤箱","烤面包机",
+            "水槽","冰箱","空","书","钟","花瓶","剪刀","泰迪熊","吹风机","牙刷"]
+
+
+def get_masks_and_position_ids(data,
+                               n_pads,
+                               object_pads,
+                               loss_mask=None,
+                               attention_mask=None, args=None):
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    # Attention mask (lower triangular).
+    if attention_mask is None:
+        assert loss_mask is not None
+        # loss_mask has n_pad (+1 CLS and [1:] then) zeros, so it is the same as attention_mask, reuse.
+        attention_mask = loss_mask[:, :seq_length].unsqueeze(-2).expand(batch_size, seq_length, seq_length).tril()
+        for i in range(batch_size):
+            attention_mask[i].fill_diagonal_(1)
+        attention_mask = attention_mask.unsqueeze(1)
+
+    # Loss mask.
+    if loss_mask is None:
+        loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device)
+
+    # Position ids.
+    # 1270
+    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
+                               device=data.device)
+    for i in range(batch_size):
+        torch.arange(64 - n_pads[i], out=position_ids[i, n_pads[i]:64],
+                     dtype=torch.long, device=data.device)
+        torch.arange(180 - object_pads[i], out=position_ids[i, 64 + object_pads[i]:64 + 180])
+        # breakpoint()
+        torch.arange(64 - n_pads[i], 64 - n_pads[i] + seq_length - (64 + 180),
+                     out=position_ids[i, 64 + 180:],
+                     dtype=torch.long, device=data.device)
+    return attention_mask, loss_mask, position_ids
+
+
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['text', 'loss_mask', 'object_pad', 'n_pad']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    # breakpoint()
+    tokens_ = data_b['text'].long()
+    loss_mask = data_b['loss_mask'].float()
+    n_pads = data_b['n_pad'].long()
+    object_pads = data_b['object_pad'].long()
+
+    labels = tokens_[:, 1:].contiguous()
+    loss_mask = loss_mask[:, 1:].contiguous()
+    tokens = tokens_[:, :-1].contiguous()
+
+    attention_mask = None
+
+    # Get the masks and position ids.
+    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+        tokens,
+        n_pads,
+        object_pads,
+        loss_mask=loss_mask,
+        attention_mask=attention_mask,
+        args=args
+    )
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+
+    return tokens, labels, loss_mask, attention_mask, position_ids
+
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+    # Forward model.
+    logits, *mems = model(tokens, position_ids, attention_mask)
+    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
+    # scaling loss mask
+    loss_mask = loss_mask.view(-1)
+
+    losses = losses.view(-1) * loss_mask
+    loss = torch.sum(losses) / loss_mask.sum()
+
+    return loss, {}
+
+
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    layout = [100, 164, 1188]
+    names = get_names()
+    tokens = []
+    for name in names:
+        tokens.append(tokenizer.EncodeAsIds(name))
+
+    # 4 + 4 + 1: one object takes 9 tokens, so 20 objects take 20 * 9 = 180 tokens
+    def process_fn(row):
+        row = row.astype(np.int64)
+        codes = [row[layout[1]:layout[2]]]
+        text = row[layout[0]:layout[1]]
+        text = text[text > 0][:63]  # [ROI]
+        object_tokens = []
+        # print(row[:100])
+        for i in range(20):
+            object = row[i * 5: (i + 1) * 5]
+            if object[0] == -1:
+                break
+            # print("object", object)
+            object[2] += object[0]  # convert (x, y, w, h) to (x1, y1, x2, y2)
+            object[3] += object[1]
+            object_tokens.append(tokenizer['[POS0]'])
+            object_tokens.extend([object[j] + args.old_token_num for j in range(4)])
+            object_tokens.extend(tokens[object[4]])
+        object_tokens = np.array(object_tokens)
+        # print(object_tokens)
+        object_pad = 180 - object_tokens.shape[-1]
+        object_tokens = np.concatenate([
+            np.array([tokenizer['[PAD]']] * object_pad, dtype=np.int64),
+            object_tokens
+        ], axis=0)
+        # 180
+        text_object = np.concatenate([text, np.array(object_tokens, dtype=np.int64)], axis=0)
+        # print(len(text), len(text_object))
+        # print(len(codes[0]))
+        merged = TextCodeTemplate(text_object, codes[0], tokenizer)
+        # print(len(merged), len(text))
+        n_pad = args.new_sequence_length - len(merged)
+        parts = [
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
+            merged
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret,
+                'loss_mask': np.array([0] * n_pad + [1] * (len(text) + 1) + [0] * object_pad + [1] * (182 - object_pad + 1025)),
+                'object_pad': object_pad,
+                'n_pad': n_pad
+                }
+
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+
+
+if __name__ == '__main__':
+    py_parser = argparse.ArgumentParser(add_help=False)
+
+    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
+
+    ObjectModel.add_model_specific_args(py_parser)
+
+    known, args_list = py_parser.parse_known_args()
+
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    args.layout = [int(x) for x in args.layout.split(',')]
+
+    training_main(args, model_cls=ObjectModel, forward_step_function=forward_step,
+                  create_dataset_function=create_dataset_function)
diff --git a/scripts/finetune_into_object.sh b/scripts/finetune_into_object.sh
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9cae1aae45e65ec3c86c63804c6e81a617c26ec1 100644
--- a/scripts/finetune_into_object.sh
+++ b/scripts/finetune_into_object.sh
@@ -0,0 +1,60 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=4
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+
+full_data='/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_coco_detection_task/coco/coco.bin.part_0.cogdata'
+
+# --mode finetune \
+# --resume-dataloader \
+# --load pretrained/cogview/cogview-base
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name finetune2-object-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --model-parallel-size ${MP_SIZE} \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 200000 \
+       --train-data ${full_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --save-interval 2000 \
+       --eval-interval 1000 \
+       --save $main_dir/checkpoints \
+       --mode finetune \
+       --resume-dataloader \
+       --load pretrained/cogview/cogview-base
+"
+       # --load pretrained/cogview/cogview-base
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_objectModel.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/scripts/kill_python.sh b/scripts/kill_python.sh
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..3900363636355cbdd970edaff9d622eef166337a 100644
--- a/scripts/kill_python.sh
+++ b/scripts/kill_python.sh
@@ -0,0 +1 @@
+pdsh -w ssh:node[1-3] "pkill -9 python"
\ No newline at end of file
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index 5a09f7712528974023124dd510d3762623e7c8d9..78e665fb2ca7c9b872e5708ae2af8b0ae2bf47b2 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -2,7 +2,7 @@
 
 # Change for multinode config
 
-NUM_WORKERS=1
+NUM_WORKERS=3
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
 
@@ -12,7 +12,6 @@ main_dir=$(dirname $script_dir)
 
 OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
 HOST_FILE_PATH="hostfile"
-HOST_FILE_PATH="hostfile_single"
 
 full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
 small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
diff --git a/scripts/text2image_object.sh b/scripts/text2image_object.sh
index 96e8a045bbb2f3d4f046b9949dc734b12ae0288e..4ab4b70b745a3c47dfdd03c1f9ef80645ef531b3 100755
--- a/scripts/text2image_object.sh
+++ b/scripts/text2image_object.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+CHECKPOINT_PATH=checkpoints/finetune2-object-test10-24-08-21
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -15,9 +15,9 @@ TOPK=200
 script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
 
-MASTER_PORT=${MASTER_PORT} python inference_cogview2.py \
+MASTER_PORT=${MASTER_PORT} python inference_object.py \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
@@ -32,10 +32,10 @@ MASTER_PORT=${MASTER_PORT} python inference_cogview2.py \
        --top_k $TOPK \
        --sandwich-ln \
        --input-source ./input.txt \
-       --output-path samples_text2image \
+       --output-path samples_text2image_object \
        --batch-size 4 \
        --max-inference-batch-size 8 \
-       --device 0 \
+       --device 7 \
        $@