diff --git a/arguments.py b/arguments.py index b06c987ea51f430abef05a49caa2df6734fa8a98..a9e722aa247c317239666c5f82906be55fbfb433 100755 --- a/arguments.py +++ b/arguments.py @@ -147,6 +147,8 @@ def add_training_args(parser): group.add_argument('--warmup', type=float, default=0.01, help='percentage of data to warmup on (.01 = 1% of all ' 'training iters). Default 0.01') + group.add_argument('--restart-iter', type=int, default=0, + help='restart with warmup from this iteration.') # model checkpointing group.add_argument('--save', type=str, default=None, help='Output directory to save checkpoints to.') @@ -301,19 +303,20 @@ def add_sparse_args(parser): group.add_argument("--key-window-times", type=int, default=6) group.add_argument("--num-pivot", type=int, default=768) # for cuda_2d - group.add_argument("--kernel-size", type=int, default=11) + group.add_argument("--kernel-size", type=int, default=9) group.add_argument("--kernel-size2", type=int, default=7) - group.add_argument("--layout", type=str, default='0,64,1088,5184') + group.add_argument("--layout", type=str, default='64,1088,5184') return parser def make_sparse_config(args): + args.layout = [int(x) for x in args.layout.split(',')] sparse_config = argparse.Namespace(sparse_type=args.sparse_type) if args.sparse_type == 'standard': pass if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation': sparse_config.kernel_size = args.kernel_size sparse_config.kernel_size2 = args.kernel_size2 - sparse_config.layout = [int(x) for x in args.layout.split(',')] + sparse_config.layout = args.layout elif args.sparse_type == 'torch_1d': raise NotImplementedError args.sparse_config = sparse_config diff --git a/data_utils/datasets.py b/data_utils/datasets.py index 877b53bd896c6bc55c48514f7ed1327226292c7b..7b1203839f85f14874d96647c4f7c0fd7be7d042 100755 --- a/data_utils/datasets.py +++ b/data_utils/datasets.py @@ -80,16 +80,16 @@ class BinaryDataset(Dataset): def __getitem__(self, index): return self.process_fn(self.bin[index]) -def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): +def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): + kwargs_to_dataset = {} tokenizer = get_tokenizer() - if args.finetune and args.max_position_embeddings_finetune > args.max_position_embeddings: - ml = args.max_position_embeddings_finetune + if args.layout[-1] > args.max_position_embeddings: + ml = args.layout[-1] else: ml = args.max_position_embeddings def pad_to_len(ret): - if len(ret) < ml: # pad return np.concatenate((ret, np.array([tokenizer['[PAD]']] * (ml - len(ret)))), @@ -117,14 +117,27 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): } elif dataset_type == 'CompactBinaryDataset': + layout = args.layout DS_CLASS = BinaryDataset + kwargs_to_dataset['length_per_sample'] = layout[-1] def process_fn(row): - text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024 - text = text[text>-1] - ret = TextCodeTemplate(text, code) - ret, attention_mask_sep = pad_to_len(ret) + row = row.astype(np.int64) + # THIS IS Reverse order, TODO + lens = list(reversed([layout[i] - layout[i-1] for i in range(1, len(layout))])) + codes = [row[layout[0]: layout[0]+lens[0]]] + if len(lens) > 1: + codes.append(row[layout[0]+lens[0]: layout[0]+lens[0]+lens[1]]) + text = row[:layout[0]] + text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1] + n_pad = layout[0]-3-len(text) + parts = [ + np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64), + 
TextCodeTemplate(text, codes[-1]), + *reversed(codes[:-1]) + ] + ret = np.concatenate(parts, axis=0) return {'text': ret, - 'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep)) + 'loss_mask': np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS] } elif dataset_type == 'BinaryDataset': DS_CLASS = BinaryDataset @@ -134,5 +147,5 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): 'loss_mask': loss_mask } - return DS_CLASS(path, process_fn) + return DS_CLASS(path, process_fn, **kwargs_to_dataset) diff --git a/data_utils/vqvae_tokenizer.py b/data_utils/vqvae_tokenizer.py index 13addb42fa8e20b8ae368a8e5f49b1eec206b17f..56ee251126fdcb901d4b88cea114342b1dfccdb7 100755 --- a/data_utils/vqvae_tokenizer.py +++ b/data_utils/vqvae_tokenizer.py @@ -50,12 +50,15 @@ class VQVAETokenizer(object): self.device = device self.image_tokens = model.quantize_t.n_embed self.num_tokens = model.quantize_t.n_embed + self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]) def __len__(self): return self.num_tokens - def EncodeAsIds(self, img): + def EncodeAsIds(self, img, add_normalization=False): assert len(img.shape) == 4 # [b, c, h, w] + if add_normalization: + img = self.tr_normalize(img) return img2code(self.model, img) def DecodeIds(self, code, shape=None): @@ -78,7 +81,6 @@ class VQVAETokenizer(object): img = tr(Image.open(path)) if img.shape[0] == 4: img = img[:-1] - tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]) - img = tr_normalize(img) + img = self.tr_normalize(img) img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w] return img \ No newline at end of file diff --git a/draw_diff.py b/draw_diff.py index 6dd12a980c6e3b2ea859726d9ad1202331ed22b4..b4baab18c797077231138786e377f66adee62f62 100644 --- a/draw_diff.py +++ b/draw_diff.py @@ -23,9 +23,10 @@ transform = transforms.Compose([ ]) img = torchvision.io.read_image('bao.jpeg') img = transform(img) / 255. 
-a = np.array(loadbao('bao.txt')) -b = np.array(loadbao('bao2.txt')) -for t in (a-b>2).nonzero()[0]: +a = np.array(loadbao('bao2.txt')) +b = np.array(loadbao('bao3.txt')) +for t in (b-a>1).nonzero()[0]: x,y = t // 32, t % 32 sq(img, x*16, y*16, 15, 15) +print(a.mean(), b.mean()) torchvision.utils.save_image(img, 'example_bao.jpg') diff --git a/generate_samples.py b/generate_samples.py index aff3e469b9f458bb93e63166994a663f31532933..4ae8fa02da9c751920e649b057705aab2a5587c4 100755 --- a/generate_samples.py +++ b/generate_samples.py @@ -41,7 +41,7 @@ from pretrain_gpt2 import get_model import math from copy import deepcopy from tqdm import tqdm -from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local +from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local, filling_sequence_cuda_2d from torchvision.utils import save_image import torch.distributed as dist @@ -167,7 +167,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template= # profile = line_profiler.LineProfiler(model.module.forward) # profile = line_profiler.LineProfiler(standard_attention) # profile.enable() - fill_fn = filling_sequence_local if args.generation_task == 'cuda-2d generation' else filling_sequence + fill_fn = filling_sequence_cuda_2d if args.generation_task == 'cuda-2d generation' else filling_sequence output_tokens_list.append(fill_fn(model, seq.clone(), args)) # torch.cuda.empty_cache() # profile.disable() # åœæ¢åˆ†æž @@ -212,7 +212,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template= def generate_images_continually(model, args): if args.generation_task == 'text2image': - query_template = '[ROI1] {} [BASE] [BOI1] [Image200]bao.jpeg' + query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' elif args.generation_task == 'image2text': query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] [MASK]*20' elif args.generation_task == 'low-level super-resolution': @@ -222,7 +222,7 @@ def generate_images_continually(model, args): elif args.generation_task == 'post-selection': query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] {}' elif args.generation_task == 'cuda-2d generation': - query_template = '[CLS] {} [BASE] [Image200]bao.jpeg [MASK]*4096' + query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1] [MASK]*4096' else: raise NotImplementedError for raw_text, seq, output_path in get_context(args, query_template): diff --git a/generation/__init__.py b/generation/__init__.py index b73ab2313cef87d5306e653894797d211f4a2c36..94f6ea1d0e968244e57ee1bf024dadb9fd1bd654 100755 --- a/generation/__init__.py +++ b/generation/__init__.py @@ -1,3 +1,4 @@ from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score from .magnify import magnify -from .local_sampling import filling_sequence_local \ No newline at end of file +from .local_sampling import filling_sequence_local +from .cuda_2d_sampling import filling_sequence_cuda_2d \ No newline at end of file diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fe7f83935a99dcdbbbe5da57fe506c5ee700ab --- /dev/null +++ b/generation/cuda_2d_sampling.py @@ -0,0 +1,121 @@ +from .sampling import * +import math +import sys +from copy import deepcopy +from torchvision.utils import save_image +def filling_sequence_cuda_2d( + model, + seq, + args, + 
mems=None, + invalid_slices=[], + iterative_step=20, + **kwargs): + ''' + seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ] + ''' + tokenizer = get_tokenizer() + invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)] + device = seq.device + assert args.sparse_config.sparse_type == 'cuda_2d' + std_config = deepcopy(args.sparse_config) + std_config.sparse_type = 'standard' + sparse_config = args.sparse_config + # split two parts + seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1 + # generate a batch of seq0 + model.module.transformer.reset_sparse_config(std_config) + args.sparse_config = std_config + output0 = filling_sequence(model, seq0, args) + model.module.transformer.reset_sparse_config(sparse_config) + args.sparse_config = sparse_config + model.module.transformer.max_memory_length = 0 + + + # filter bad generation & select top N=2, TODO + output0 = output0 + + from torchvision import transforms + tr = transforms.Compose([ + transforms.Resize(512), + ]) + imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth + blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blured image as init value + + # pad seq to desired shape + n_pad = args.layout[1] - len(seq0) + batch_size = output0.shape[0] + assert n_pad > 0, "You should truncate long input before filling." + seq = torch.cat(( + torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype) + .unsqueeze(0).expand(batch_size, n_pad), + output0, + seq1.unsqueeze(0).expand(batch_size, len(seq1)) + ), dim=1 + ) # [b, layout[-1]] + + # init + step_cnt = 0 + tokens = seq[:, :-1].clone() + unfixed = (seq < 0) + # tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens) + tokens[:, -4095:] = blur64[:, :-1] + attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device) + attention_mask[n_pad:, :n_pad] = 0 + position_ids = torch.cat(( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, args.layout[1] - n_pad), + torch.arange(0, args.layout[2]-args.layout[1]))).to(device) + # iterate + imgs = [] + # import pdb;pdb.set_trace() + while unfixed.sum() > 0: + print(unfixed.sum()) + logits, *_dump = model(tokens, position_ids, attention_mask) + step_cnt += 1 + last_logits = logits + + # warmup + real_topk = 5 + real_temp = 2 - min(1,((step_cnt) / iterative_step)) * 1.9 + # sampling + for invalid_slice in invalid_slices: # forbide to generate other tokens + logits[..., invalid_slice] = -float('Inf') + assert args.top_k > 0 + tk_value, tk_idx = torch.topk(logits, real_topk, dim=-1) + tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1]) + prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1) + prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1) + # update unfixed + choice = 1 + if choice == 0 and step_cnt > 5: + mprob = tk_probs.max(dim=-1)[0].view(*(tk_value.shape[:2])) + dprob = (mprob[:, 1:] < 0.5) & ((mprob[:, :-1] > 0.8)| (unfixed[:, 1:-1].logical_not())) + new_fixed = unfixed.clone() + new_fixed[:, 2:] &= dprob + else: + new_fixed = unfixed & False # TODO + new_fixed[:, -1] = True + unfixed &= new_fixed.logical_not() + # update seq and tokens + seq[new_fixed] = prev[new_fixed[:, 1:]] + tokens = seq[:, :-1].clone() + tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]] + + if step_cnt == iterative_step: + seq[:, :-1][unfixed[:, :-1]] = 
tokens[unfixed[:, :-1]] # if reach iterative_step + n_unfixed = unfixed.sum(dim=-1).tolist() + print(f'Exit with {n_unfixed} unfixed tokens.') + break + if args.debug: + from torchvision.utils import save_image + seqt = seq.clone() + seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step + imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt]) + if args.debug: + imgs = torch.cat(imgs, dim=0) + save_image(imgs, f'steps{device}.jpg', normalize=True) + + model.module.transformer.max_memory_length = args.max_memory_length + + return seq \ No newline at end of file diff --git a/generation/sampling.py b/generation/sampling.py index 4c3acd2501e50f74589884963a87a76840d019dd..8d614724a10a3f1ea1733d4c3fe23a602b70ed11 100755 --- a/generation/sampling.py +++ b/generation/sampling.py @@ -132,8 +132,8 @@ def filling_sequence( tmp = -F.log_softmax(logits, dim=-1) tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0] - for i in range(1,len(tmp)): - print(i, tmp[i].item()) + # for i in range(1,len(tmp)): + # print(i, tmp[i].item()) index = counter print(tmp[1:].mean(), file=sys.stderr) elif seq[counter + 1] >= 0: # provided diff --git a/learning_rates.py b/learning_rates.py index 4749ae4891383d0cbf11212d085cfd696d463131..b68e8be1d26559e41ab318c70493b388efd43e01 100755 --- a/learning_rates.py +++ b/learning_rates.py @@ -17,13 +17,16 @@ import torch from torch.optim.lr_scheduler import _LRScheduler import math +from utils import print_rank_0 + class AnnealingLR(_LRScheduler): """Anneals the learning rate from start to zero along a cosine curve.""" DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] - def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5): + def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5, restart_iter=0): + self.restart_iter = restart_iter assert warmup_iter <= num_iters self.optimizer = optimizer self.start_lr = start_lr @@ -38,13 +41,16 @@ class AnnealingLR(_LRScheduler): def get_lr(self): # https://openreview.net/pdf?id=BJYwwY9ll pg. 
4 - if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: - return float(self.start_lr) * self.num_iters / self.warmup_iter + real_num_iters = self.num_iters - self.restart_iter + real_end_iter = self.end_iter - self.restart_iter + # print_rank_0(f'real_num_iters: {real_num_iters}') + if self.warmup_iter > 0 and real_num_iters <= self.warmup_iter: + return float(self.start_lr) * real_num_iters / self.warmup_iter else: if self.decay_style == self.DECAY_STYLES[0]: - return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter) + return self.start_lr*((real_end_iter-(real_num_iters-self.warmup_iter))/real_end_iter) elif self.decay_style == self.DECAY_STYLES[1]: - decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter) + decay_step_ratio = min(1.0, (real_num_iters - self.warmup_iter) / real_end_iter) return self.start_lr / self.decay_ratio * ( (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1) elif self.decay_style == self.DECAY_STYLES[2]: @@ -73,8 +79,9 @@ class AnnealingLR(_LRScheduler): return sd def load_state_dict(self, sd): + import pdb;pdb.set_trace() # self.start_lr = sd['start_lr'] - self.warmup_iter = sd['warmup_iter'] + # self.warmup_iter = sd['warmup_iter'] self.num_iters = sd['num_iters'] # self.end_iter = sd['end_iter'] self.decay_style = sd['decay_style'] diff --git a/model/gpt2_modeling.py b/model/gpt2_modeling.py index 6c47e98ada30ff7fb73881f4d93ba960e2693900..c854f1bce00f0a0397ebb796c1f7f17f7e3c1b4b 100755 --- a/model/gpt2_modeling.py +++ b/model/gpt2_modeling.py @@ -41,15 +41,14 @@ def gpt2_get_params_for_weight_decay_optimization(module): if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): no_weight_decay_params['params'].extend( [p for p in list(module_._parameters.values()) - if p is not None]) + if p is not None and p.requires_grad]) else: weight_decay_params['params'].extend( [p for n, p in list(module_._parameters.items()) - if p is not None and n != 'bias']) + if p is not None and n != 'bias' and p.requires_grad]) no_weight_decay_params['params'].extend( [p for n, p in list(module_._parameters.items()) - if p is not None and n == 'bias']) - + if p is not None and n == 'bias' and p.requires_grad]) return weight_decay_params, no_weight_decay_params @@ -74,7 +73,8 @@ class GPT2Model(torch.nn.Module): sandwich_ln, checkpoint_num_layers=1, parallel_output=True, - sparse_config=argparse.Namespace(sparse_type='standard') + sparse_config=argparse.Namespace(sparse_type='standard'), + finetune=False ): super(GPT2Model, self).__init__() @@ -99,7 +99,8 @@ class GPT2Model(torch.nn.Module): checkpoint_activations, checkpoint_num_layers, sandwich_ln=sandwich_ln, - sparse_config=sparse_config + sparse_config=sparse_config, + finetune=finetune ) def forward(self, input_ids, position_ids, attention_mask, *mems): @@ -120,3 +121,6 @@ class GPT2Model(torch.nn.Module): return (logits_parallel, *hidden_layers) return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers) + + def init_plus_from_old(self): + self.transformer.init_plus_from_old() diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py index 7adc0dfbb13e25aba14aa64fa1f0285a457d2bbf..53a92de78d6b0409e520dbb139774d18d37af6f7 100755 --- a/mpu/sparse_transformer.py +++ b/mpu/sparse_transformer.py @@ -75,7 +75,8 @@ class GPT2ParallelSelfAttention(torch.nn.Module): """ def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, - init_method, 
output_layer_init_method=None,sparse_config=None): + init_method, output_layer_init_method=None,sparse_config=None, + finetune=False): super(GPT2ParallelSelfAttention, self).__init__() # Set output layer initialization if not provided. if output_layer_init_method is None: @@ -111,11 +112,30 @@ class GPT2ParallelSelfAttention(torch.nn.Module): get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint - # self.offset_bias = torch.nn.Parameter( - # torch.ones(num_attention_heads, sparse_config.kernel_size**2//2+1 +sparse_config.kernel_size2**2)) - self.sparse_config = sparse_config + if finetune: + # build new branch + self.query_key_value_plus = ColumnParallelLinear(hidden_size, 3*hidden_size, + stride=3, + gather_output=False, + init_method=init_method) + self.dense_plus = RowParallelLinear(hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + + def init_plus_from_old(self): + self.query_key_value_plus.weight.data.copy_(self.query_key_value.weight.data) + if hasattr(self.query_key_value_plus, 'bias') and hasattr(self.query_key_value, 'bias'): + self.query_key_value_plus.bias.data.copy_(self.query_key_value.bias.data) + + self.dense_plus.weight.data.copy_(self.dense.weight.data) + if hasattr(self.dense_plus, 'bias') and hasattr(self.dense, 'bias'): + self.dense_plus.bias.data.copy_(self.dense.bias.data) + def reset_sparse_config(self, config): + self.sparse_config = config + def _transpose_for_scores(self, tensor): """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. @@ -128,16 +148,15 @@ class GPT2ParallelSelfAttention(torch.nn.Module): def forward(self, hidden_states, mask, mem=None): - # import pdb;pdb.set_trace() sparse_config = self.sparse_config - # hidden_states: [b, s, h] - # ltor_mask: [1, 1, s, s] - - # Attention heads. [b, s, hp] - query_length = hidden_states.size(1) + layout = sparse_config.layout + if sparse_config.sparse_type == 'cuda_2d': + assert hidden_states.size(1) == sparse_config.layout[-1] + # [PAD]... [ROI1] text ... 
[BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + hidden_states_plus = hidden_states[:, layout[1]:] + hidden_states = hidden_states[:, :layout[1]] mixed_raw_layer = self.query_key_value(hidden_states) - (mixed_query_layer, mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) @@ -145,38 +164,45 @@ class GPT2ParallelSelfAttention(torch.nn.Module): memk, memv = split_tensor_along_last_dim(mem, 2) mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1) mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1) + + if sparse_config.sparse_type == 'cuda_2d': + mixed_raw_layer_plus = self.query_key_value_plus(hidden_states_plus) + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer_plus, 3) dropout_fn = self.attention_dropout if self.training else None - if sparse_config.sparse_type in ['standard', 'torch_1d']: - # Reshape and transpose [b, np, s, hn] + if sparse_config.sparse_type == 'standard': query_layer = self._transpose_for_scores(mixed_query_layer) key_layer = self._transpose_for_scores(mixed_key_layer) value_layer = self._transpose_for_scores(mixed_value_layer) - if sparse_config.sparse_type == 'standard': - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) - else: - context_layer = sparse_attention_1d(query_layer, key_layer, value_layer, sparse_config.pivot_idx, - mask, sparse_config.query_window, sparse_config.key_window_times, dropout_fn) - # inference: context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx) + context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + \ (self.hidden_size_per_partition,) - # [b, s, hp] context_layer = context_layer.view(*new_context_layer_shape) elif sparse_config.sparse_type == 'cuda_2d': - context_layer = sparse_attention_2dfull(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition, - sparse_config.layout, mask, sparse_config.kernel_size, - kernel_size2=sparse_config.kernel_size2, - attention_dropout=dropout_fn - ) - - # Output. [b, s, h] - output = self.dense(context_layer) + context_layer0, context_layer1 = sparse_attention_2d_light( + mixed_query_layer, mixed_key_layer, mixed_value_layer, + q1, k1, v1, + mask, + n_head=self.num_attention_heads_per_partition, + text_len=sparse_config.layout[0], + kernel_size=sparse_config.kernel_size, + kernel_size2=sparse_config.kernel_size2, + attention_dropout=dropout_fn + ) + + if sparse_config.sparse_type == 'cuda_2d': + output_0 = self.dense(context_layer0) + output_1 = self.dense_plus(context_layer1) + output = torch.cat((output_0, output_1), dim=1) + else: + output = self.dense(context_layer) + if self.training: output = self.output_dropout(output) @@ -284,7 +310,8 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): init_method, output_layer_init_method=None, sandwich_ln=True, - sparse_config=argparse.Namespace(sparse_type='standard') + sparse_config=argparse.Namespace(sparse_type='standard'), + finetune=False ): super(GPT2ParallelTransformerLayer, self).__init__() # Set output layer initialization if not provided. @@ -302,7 +329,8 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): output_dropout_prob, init_method, output_layer_init_method=output_layer_init_method, - sparse_config=sparse_config + sparse_config=sparse_config, + finetune=finetune ) # Layernorm on the input data. 
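Reviewer note (not part of the patch): the hunks above add a parallel "_plus" projection branch for the high-resolution tokens and copy the pretrained weights into it via init_plus_from_old. A minimal self-contained sketch of that pattern, assuming plain torch.nn.Linear in place of the model-parallel ColumnParallelLinear/RowParallelLinear and using an illustrative class name:

# Reviewer sketch only: the "plus branch" pattern introduced above,
# with nn.Linear standing in for the model-parallel layers.
import torch
import torch.nn as nn

class PlusBranchSketch(nn.Module):  # hypothetical name, for illustration
    def __init__(self, hidden_size, finetune=False):
        super().__init__()
        self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)
        if finetune:
            # New branch for the high-res tokens; same shapes as the
            # pretrained projections so their weights can be copied over.
            self.query_key_value_plus = nn.Linear(hidden_size, 3 * hidden_size)
            self.dense_plus = nn.Linear(hidden_size, hidden_size)

    @torch.no_grad()
    def init_plus_from_old(self):
        # Mirrors the idea of GPT2ParallelSelfAttention.init_plus_from_old:
        # start the new branch from the pretrained weights, not a random init.
        self.query_key_value_plus.weight.copy_(self.query_key_value.weight)
        self.query_key_value_plus.bias.copy_(self.query_key_value.bias)
        self.dense_plus.weight.copy_(self.dense.weight)
        self.dense_plus.bias.copy_(self.dense.bias)

m = PlusBranchSketch(hidden_size=16, finetune=True)
m.init_plus_from_old()
for name, p in m.named_parameters():
    # Freeze everything except the new branch, as setup_model_and_optimizer now does.
    p.requires_grad_('_plus' in name)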
@@ -324,6 +352,10 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): self.sparse_config = sparse_config + def reset_sparse_config(self, config): + self.sparse_config = config + self.attention.reset_sparse_config(config) + def forward(self, hidden_states, ltor_mask, mem=None): # hidden_states: [b, s, h] # ltor_mask: [1, 1, s, s] @@ -419,7 +451,8 @@ class GPT2ParallelTransformer(torch.nn.Module): init_method_std=0.02, use_scaled_init_for_output_weights=True, sandwich_ln=True, - sparse_config=argparse.Namespace(sparse_type='standard') + sparse_config=argparse.Namespace(sparse_type='standard'), + finetune=False ): super(GPT2ParallelTransformer, self).__init__() # Store activation checkpoiting flag. @@ -441,12 +474,11 @@ class GPT2ParallelTransformer(torch.nn.Module): # Initialize the position embeddings. torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) - # TODO: after testing, this is not useful. - # self.img_type_embeddings = torch.nn.Parameter(torch.Tensor(64, hidden_size)) - # torch.nn.init.normal_(self.img_type_embeddings, mean=0.0, std=init_method_std) - # self.txt_type_embeddings = torch.nn.Parameter(torch.Tensor(hidden_size)) - # torch.nn.init.normal_(self.txt_type_embeddings, mean=0.0, std=init_method_std) - + if finetune: + self.position_embeddings_plus = torch.nn.Embedding(4096, # FIXME + hidden_size) + # Initialize the position embeddings. + torch.nn.init.normal_(self.position_embeddings_plus.weight, mean=0.0, std=init_method_std) def get_layer(layer_id): return GPT2ParallelTransformerLayer( @@ -458,7 +490,8 @@ class GPT2ParallelTransformer(torch.nn.Module): unscaled_init_method(init_method_std), output_layer_init_method=output_layer_init_method, sandwich_ln=sandwich_ln, - sparse_config=sparse_config + sparse_config=sparse_config, + finetune=finetune ) # Transformer layers. @@ -474,6 +507,16 @@ class GPT2ParallelTransformer(torch.nn.Module): checkpoint = deepspeed.checkpointing.checkpoint self.sparse_config = sparse_config + def init_plus_from_old(self): + self.position_embeddings_plus.weight.data.view(4, 1024, -1).copy_(self.position_embeddings.weight.data[-1024:]) # FIXME + for layer in self.layers: + layer.attention.init_plus_from_old() + + def reset_sparse_config(self, config): + self.sparse_config = config + for layer in self.layers: + layer.reset_sparse_config(config) + def forward(self, hidden_states, position_ids, attention_mask, *mems): batch_size, query_length = hidden_states.size()[:2] @@ -495,13 +538,14 @@ class GPT2ParallelTransformer(torch.nn.Module): return m attention_mask = build_mask_matrix(query_length, key_length, sep) - # ===================== Image & Text Type Embedding ======================== # - # TODO: after testing, this is not useful. 
- # extend_len = (key_length + 63) // 64 - # hidden_states = hidden_states + txt_indices_bool.unsqueeze(-1) * self.txt_type_embeddings.view(1, 1, -1) + \ - # img_indices_bool.unsqueeze(-1) * self.img_type_embeddings.expand(extend_len, 64, -1).reshape(extend_len * 64, -1)[memory_length: key_length] - # ===================== END OF BLOCK ======================= # - position_embeddings = self.position_embeddings(position_ids) + + if self.sparse_config.sparse_type == 'cuda_2d': + position = position_ids[..., :self.sparse_config.layout[1]] + position_plus = position_ids[..., self.sparse_config.layout[1]:] + position_embeddings = torch.cat( + (self.position_embeddings(position), self.position_embeddings_plus(position_plus)), dim=-2) + else: + position_embeddings = self.position_embeddings(position_ids) hidden_states = hidden_states + position_embeddings hidden_states = self.embedding_dropout(hidden_states) @@ -539,10 +583,8 @@ class GPT2ParallelTransformer(torch.nn.Module): hidden_states = checkpoint(custom(l, l + chunk_length), *args) l += chunk_length else: - assert self.sparse_config.sparse_type == 'standard' for i, layer in enumerate(self.layers): args = [hidden_states, attention_mask_saved] - if mems: mem_i = mems[i] elif self.max_memory_length > 0: @@ -555,11 +597,6 @@ class GPT2ParallelTransformer(torch.nn.Module): # Final layer norm. output = self.final_layernorm(hidden_states) - # if self.max_memory_length > 0: # TODO cache - # if self.sparse_config.sparse_type != 'cuda_2d': - # mem_layers = self.update_mems(mem_layers, mems) - # else: - # pass # handle update outside the model, because mems is not the full cached qkv. return (output, *mem_layers) @@ -609,245 +646,65 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte attention_probs = attention_dropout(attention_probs) # Context layer. # [b, np, s, hn] - - context_layer = torch.matmul(attention_probs, value_layer) - return context_layer - -def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None): - ''' Sparse Attention - Args: - q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window - pivot_idx: [b, num_pivots] - pivot_attention_mask: [b, s, num_pivots] - query_window: . 
- key_window_times: key_window = query_window * key_window_times - ''' - - b, n_head, s, hn = q.shape - b, n_piv = pivot_idx.shape - w = query_window - - pivot_idx_dummy = pivot_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn) - # ===================== Pivot Attention ======================== # - pivot_k, pivot_v = torch.gather(k, 2, pivot_idx_dummy), torch.gather(v, 2, pivot_idx_dummy) - attention_scores = torch.matmul(q, pivot_k.transpose(-1, -2)) - pivot_attention_mask = pivot_attention_mask.unsqueeze(1) - - attention_scores_pivot = torch.mul(attention_scores, pivot_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - pivot_attention_mask) - - attention_scores_pivot = attention_scores_pivot + math.log(s // n_piv) - # ===================== Window Attention ======================= # - window_k = _chunk(k, query_window, key_window_times) - window_v = _chunk(v, query_window, key_window_times) - # window_k [b, n_head, s // w up int, w*times, hn] - - if s % w == 0: # training # TODO args check - assert k.shape[2] == s - assert window_k.shape[2] == s // w - window_q = q.view(b, n_head, s // w, w, hn) - attention_scores = torch.matmul(window_q, window_k.transpose(-1, -2)) - window_attention_mask = torch.ones((w, w * key_window_times), dtype=attention_scores.dtype, device=q.device).tril_(diagonal=w * (key_window_times - 1)) - attention_scores_window = torch.mul(attention_scores, window_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - window_attention_mask) - for t in range(1, key_window_times): - attention_scores_window[:, :, t - 1, :, :w * key_window_times - w * t] -= 10000.0 - else: - raise ValueError('The seq_len must be exactly divided by window_size.') - # ===================== Joint Softmax ======================= # - attention_scores_window = attention_scores_window.view(b, n_head, s, w * key_window_times) - attention_scores = torch.cat((attention_scores_pivot, attention_scores_window), dim=-1) - attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) - - if attention_dropout is not None: - with get_cuda_rng_tracker().fork(): - attention_probs = attention_dropout(attention_probs) - - context_layer = torch.matmul(attention_probs[..., :-w * key_window_times], pivot_v) + torch.einsum('bcgwk,bcgkh->bcgwh', attention_probs[..., -w * key_window_times:].view(b, n_head, s // w, w, w * key_window_times), window_v).view(b, n_head, s, hn) + context_layer = torch.matmul(attention_probs, value_layer) return context_layer -# def sparse_attention_inference_1d(q, k, v, pivot_and_window_idx, **kwargs): -# '''the inference process of sparse attention. -# The Qs are in the same block, but seq_len mod window size might != 0. - -# The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs. - -# ''' -# b, n_head, sq, hn = q.shape -# sk = k.shape[2] -# _b, n_piv = pivot_and_window_idx.shape - -# pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn) -# pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy) -# attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2)) -# if sq > 1: -# query_part_scores = attention_scores[:, :, -sq:, -sq:] -# m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000. 
-# m.triu_(diagonal=1) -# query_part_scores += m - -# attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) - -# context_layer = torch.matmul(attention_probs, pivot_v) -# return context_layer -def transpose_and_split(x, layout, n_head): - x = x.transpose(1, 2) - x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2]) - x_text = x[..., :layout[0]] - x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous() - x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous() - return x, x_text, x0, x1 - -def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): +def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): ''' - q, k, v: [batch_size, 64+1024+4096, hidden_size] + q0, k0, v0: [batch_size, 1088, hidden_size] + q1, k1, v1: [batch_size, 4096, h2] n_head: int - layout: [endoftext/startofpad, startof0, startof1, endofall] - attention_mask_text2d: [batch_size, sq_len, endoftext] + attention_mask: [batch_size, 1088, 1088] ''' from .local_attention_function import f_similar, f_weighting - b, sq_len, hn = q.shape - alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) - - q = q / math.sqrt(hn // n_head) # normalization - - q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext] - k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head) - v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) - # import pdb; pdb.set_trace() - # all to text - scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) - scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) - # 0 to 0 - scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True) - # 1 to 1 + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + assert v1.shape[-1] == h0, 'q1, k1 can be smaller, but v1 cannot.' 
+ h = h0 // n_head + l0, l1 = int(math.sqrt(s0-text_len)+0.0001), int(math.sqrt(s1)+0.0001) + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + attention_probs0 = F.softmax(attention_scores, dim=-1) + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) - # 1 to 0 - scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] - # softmax - # if 'offset_bias' in kwargs: - # p1, p2 = kernel_size**2//2 + 1, kernel_size2**2 - # offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2) - # scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1] - # scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1] - # scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:] - - scores_0 = torch.cat( - (scores_all_to_text[:, layout[1]:layout[2]], - scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), - dim=-1) - scores_1 = torch.cat( - (scores_all_to_text[:, layout[2]:layout[3]], - scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), - scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), - dim=-1) - probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] - probs_0 = F.softmax(scores_0, dim=-1) # - probs_1 = F.softmax(scores_1, dim=-1) - - if attention_dropout is not None: - with get_cuda_rng_tracker().fork(): - probs_0 = attention_dropout(probs_0) - probs_1 = attention_dropout(probs_1) - # weighting - pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype) - probs_all_to_text = torch.cat(( - probs_text, - pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]), - probs_0[:, :, :layout[0]], - probs_1[:, :, :layout[0]] - ), dim=1) - - context_all_to_text = torch.einsum('bhij,bhcj->bihc', - probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), - v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) - - context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True) - - context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) + # attention_probs1 = F.softmax(scores_1_to_1, dim=-1) - context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True) - - context_all_to_01 =torch.cat( + # cross attention + k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous() + scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field] + scores_1 = torch.cat( ( - pad.expand(b*n_head, hn//n_head, layout[1]), - context_0_to_0.view(b*n_head, hn//n_head, 
layout[2]-layout[1]), - (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2]) - ), dim=-1).view(b, hn, -1).transpose(1, 2) - return context_all_to_text + context_all_to_01 - - -def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): - ''' - q, k, v: [batch_size, 64+1024+4096, hidden_size] - n_head: int - layout: [endoftext/startofpad, startof0, startof1, endofall] - attention_mask_text2d: [batch_size, sq_len, endoftext] - ''' - from .local_attention_function import f_similar, f_weighting - b, sq_len, hn = q.shape - alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) - - q = q / math.sqrt(hn // n_head) # normalization - - q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext] - k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head) - v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) - # import pdb; pdb.set_trace() - # all to text - scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) - scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) - # 0 to 0 - if not hasattr(sparse_attention_2dfull, 'attention_mask0'): - sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_() - attention_mask0 = sparse_attention_2dfull.attention_mask0 - scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0) - # 1 to 1 - scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) - # 1 to 0 - scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] - # softmax - - scores_0 = torch.cat( - (scores_all_to_text[:, layout[1]:layout[2]], - scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), + scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]), + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), dim=-1) - scores_1 = torch.cat( - (scores_all_to_text[:, layout[2]:layout[3]], - scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), - scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), - dim=-1) - probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] - probs_0 = F.softmax(scores_0, dim=-1) # - probs_1 = F.softmax(scores_1, dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) if attention_dropout is not None: with get_cuda_rng_tracker().fork(): - probs_0 = attention_dropout(probs_0) - probs_1 = attention_dropout(probs_1) - # weighting - pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype) - probs_all_to_text = torch.cat(( - probs_text, - pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]), - probs_0[:, :, :layout[0]], - probs_1[:, :, :layout[0]] - ), dim=1) - - context_all_to_text = torch.einsum('bhij,bhcj->bihc', - probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), - v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) - - context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0)) - - context_1_to_0 = f_weighting(v0, 
probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) - - context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True) - - context_all_to_01 =torch.cat( - ( - pad.expand(b*n_head, hn//n_head, layout[1]), - context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]), - (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2]) - ), dim=-1).view(b, hn, -1).transpose(1, 2) - return context_all_to_text + context_all_to_01 + attention_probs0 = attention_dropout(attention_probs0) + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True) + context1_to_1 = context1_to_1.view(b, n_head * h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0) + v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0) + context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False) + context1_to_0 = context1_to_0.view(b, n_head * h, l1**2) + return context0.transpose(1, 2).reshape(b, s0, h0), (context1_to_0 + context1_to_1).transpose(-1, -2) \ No newline at end of file diff --git a/mpu/unused_codes.py b/mpu/unused_codes.py new file mode 100644 index 0000000000000000000000000000000000000000..23256e777d9053c75c0ce02ff619bdbf314cbd7d --- /dev/null +++ b/mpu/unused_codes.py @@ -0,0 +1,247 @@ + +def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None): + ''' Sparse Attention + Args: + q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window + pivot_idx: [b, num_pivots] + pivot_attention_mask: [b, s, num_pivots] + query_window: . 
+ key_window_times: key_window = query_window * key_window_times + ''' + + b, n_head, s, hn = q.shape + b, n_piv = pivot_idx.shape + w = query_window + + pivot_idx_dummy = pivot_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn) + # ===================== Pivot Attention ======================== # + pivot_k, pivot_v = torch.gather(k, 2, pivot_idx_dummy), torch.gather(v, 2, pivot_idx_dummy) + attention_scores = torch.matmul(q, pivot_k.transpose(-1, -2)) + pivot_attention_mask = pivot_attention_mask.unsqueeze(1) + + attention_scores_pivot = torch.mul(attention_scores, pivot_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - pivot_attention_mask) + + attention_scores_pivot = attention_scores_pivot + math.log(s // n_piv) + # ===================== Window Attention ======================= # + window_k = _chunk(k, query_window, key_window_times) + window_v = _chunk(v, query_window, key_window_times) + # window_k [b, n_head, s // w up int, w*times, hn] + + if s % w == 0: # training # TODO args check + assert k.shape[2] == s + assert window_k.shape[2] == s // w + window_q = q.view(b, n_head, s // w, w, hn) + attention_scores = torch.matmul(window_q, window_k.transpose(-1, -2)) + window_attention_mask = torch.ones((w, w * key_window_times), dtype=attention_scores.dtype, device=q.device).tril_(diagonal=w * (key_window_times - 1)) + attention_scores_window = torch.mul(attention_scores, window_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - window_attention_mask) + for t in range(1, key_window_times): + attention_scores_window[:, :, t - 1, :, :w * key_window_times - w * t] -= 10000.0 + else: + raise ValueError('The seq_len must be exactly divided by window_size.') + # ===================== Joint Softmax ======================= # + attention_scores_window = attention_scores_window.view(b, n_head, s, w * key_window_times) + attention_scores = torch.cat((attention_scores_pivot, attention_scores_window), dim=-1) + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs = attention_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs[..., :-w * key_window_times], pivot_v) + torch.einsum('bcgwk,bcgkh->bcgwh', attention_probs[..., -w * key_window_times:].view(b, n_head, s // w, w, w * key_window_times), window_v).view(b, n_head, s, hn) + + return context_layer + + +def transpose_and_split(x, layout, n_head): + x = x.transpose(1, 2) + x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2]) + x_text = x[..., :layout[0]] + x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous() + x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous() + return x, x_text, x0, x1 + +def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): + ''' + q, k, v: [batch_size, 64+1024+4096, hidden_size] + n_head: int + layout: [endoftext/startofpad, startof0, startof1, endofall] + attention_mask_text2d: [batch_size, sq_len, endoftext] + ''' + from .local_attention_function import f_similar, f_weighting + b, sq_len, hn = q.shape + alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) + + q = q / math.sqrt(hn // n_head) # normalization + + q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, 
hn_per_head, endoftext] + k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head) + v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) + # import pdb; pdb.set_trace() + # all to text + scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) + scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) + # 0 to 0 + scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True) + # 1 to 1 + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) + # 1 to 0 + scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] + # softmax + # if 'offset_bias' in kwargs: + # p1, p2 = kernel_size**2//2 + 1, kernel_size2**2 + # offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2) + # scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1] + # scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1] + # scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:] + + scores_0 = torch.cat( + (scores_all_to_text[:, layout[1]:layout[2]], + scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), + dim=-1) + scores_1 = torch.cat( + (scores_all_to_text[:, layout[2]:layout[3]], + scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), + scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), + dim=-1) + probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] + probs_0 = F.softmax(scores_0, dim=-1) # + probs_1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + probs_0 = attention_dropout(probs_0) + probs_1 = attention_dropout(probs_1) + # weighting + pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype) + probs_all_to_text = torch.cat(( + probs_text, + pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]), + probs_0[:, :, :layout[0]], + probs_1[:, :, :layout[0]] + ), dim=1) + + context_all_to_text = torch.einsum('bhij,bhcj->bihc', + probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), + v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) + + context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True) + + context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) + + context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True) + + context_all_to_01 =torch.cat( + ( + pad.expand(b*n_head, hn//n_head, layout[1]), + context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]), + (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2]) + ), dim=-1).view(b, hn, -1).transpose(1, 2) + return context_all_to_text + context_all_to_01 + + +def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): + ''' + q, k, v: [batch_size, 64+1024+4096, hidden_size] + n_head: int + layout: [endoftext/startofpad, startof0, startof1, endofall] + attention_mask_text2d: [batch_size, sq_len, endoftext] + ''' + from .local_attention_function import f_similar, f_weighting + b, sq_len, hn = q.shape + 
alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) + + q = q / math.sqrt(hn // n_head) # normalization + + q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext] + k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head) + v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) + # import pdb; pdb.set_trace() + # all to text + scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) + scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) + # 0 to 0 + if not hasattr(sparse_attention_2dfull, 'attention_mask0'): + sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_() + attention_mask0 = sparse_attention_2dfull.attention_mask0 + scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0) + # 1 to 1 + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) + # 1 to 0 + scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] + # softmax + + scores_0 = torch.cat( + (scores_all_to_text[:, layout[1]:layout[2]], + scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), + dim=-1) + scores_1 = torch.cat( + (scores_all_to_text[:, layout[2]:layout[3]], + scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), + scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), + dim=-1) + probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] + probs_0 = F.softmax(scores_0, dim=-1) # + probs_1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + probs_0 = attention_dropout(probs_0) + probs_1 = attention_dropout(probs_1) + # weighting + pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype) + probs_all_to_text = torch.cat(( + probs_text, + pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]), + probs_0[:, :, :layout[0]], + probs_1[:, :, :layout[0]] + ), dim=1) + + context_all_to_text = torch.einsum('bhij,bhcj->bihc', + probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), + v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) + + context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0)) + + context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) + + context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True) + + context_all_to_01 =torch.cat( + ( + pad.expand(b*n_head, hn//n_head, layout[1]), + context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]), + (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2]) + ), dim=-1).view(b, hn, -1).transpose(1, 2) + return context_all_to_text + context_all_to_01 + + +if args.sparse_config.sparse_type == 'cuda_2d': + layout = args.sparse_config.layout + unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000. + unpad_indices[:, -1] = 9000. 
+ starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1] + layout[0] = starts.max().item() + attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device) + for i in range(batch_size): + attention_mask[i, :, starts[i]:layout[1]] = 0 + attention_mask[:, :layout[0]].tril_() + attention_mask = attention_mask.unsqueeze(1) +elif args.sparse_config.sparse_type == 'standard': + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device) + attention_mask.tril_() + # attention_mask = torch.zeros((seq_length, seq_length), device=data.device) + # h = w = 32 + # k1=9 + # layout = [10, 64, 64+h*w, 64+h*w*5] + # for i in range(layout[1]): + # attention_mask[i, :i+1] = 1 + # for i in range(layout[1], layout[2]): + # x = (i - layout[1]) // w + # y = (i - layout[1]) % w + # lx = max(0, x - k1 // 2) + # ly = max(0, y - k1 // 2) + # rx = min(h-1, x + k1 // 2) + # ry = min(w-1, y + k1 // 2) + # attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1 + # attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 + # attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) \ No newline at end of file diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py index 784153e8b65c70c7004525dfdf208fd69c9cf35a..87844669294b4455e8a1687d405ac83e977f9ca2 100755 --- a/pretrain_gpt2.py +++ b/pretrain_gpt2.py @@ -73,7 +73,8 @@ def get_model(args, sparse_config=None): checkpoint_num_layers=args.checkpoint_num_layers, parallel_output=True, sparse_config=sparse_config if sparse_config is not None else args.sparse_config, - sandwich_ln=args.sandwich_ln + sandwich_ln=args.sandwich_ln, + finetune=args.finetune ) if mpu.get_data_parallel_rank() == 0: @@ -163,7 +164,7 @@ def get_learning_rate_scheduler(optimizer, args): num_iters = args.lr_decay_iters else: num_iters = args.train_iters - num_iters = max(1, num_iters) + num_iters = max(1, num_iters - args.restart_iter) init_step = -1 warmup_iter = args.warmup * num_iters lr_scheduler = AnnealingLR(optimizer, @@ -172,7 +173,9 @@ def get_learning_rate_scheduler(optimizer, args): num_iters=num_iters, decay_style=args.lr_decay_style, last_iter=init_step, - decay_ratio=args.lr_decay_ratio) + decay_ratio=args.lr_decay_ratio, + restart_iter=args.restart_iter + ) return lr_scheduler @@ -182,6 +185,12 @@ def setup_model_and_optimizer(args): model = get_model(args) + if args.finetune: + model.requires_grad_(False) + for name, param in model.named_parameters(): + if name.find('_plus') > 0: + param.requires_grad_(True) + param_groups = get_optimizer_param_groups(model) if args.train_data is not None: @@ -213,38 +222,18 @@ def get_masks_and_position_ids(data, # Attention mask (lower triangular). if attention_mask is None: - # single direction, [PAD]s are at the end of the seq, so doesn't matter. if args.sparse_config.sparse_type == 'cuda_2d': - layout = args.sparse_config.layout - unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000. - unpad_indices[:, -1] = 9000. - starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1] - layout[0] = starts.max().item() - attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device) + # single direction, [PAD]s are at the start of the seq. + assert loss_mask is not None + # loss_mask has n_pad(+1 CLS and [1:] then) zeros, so it is the same as attention_mask, reuse. 
+ attention_mask = loss_mask[:, :args.layout[1]].unsqueeze(-2).expand(batch_size, args.layout[1], args.layout[1]).tril() for i in range(batch_size): - attention_mask[i, :, starts[i]:layout[1]] = 0 - attention_mask[:, :layout[0]].tril_() + attention_mask[i].fill_diagonal_(1) attention_mask = attention_mask.unsqueeze(1) elif args.sparse_config.sparse_type == 'standard': attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device) attention_mask.tril_() - # attention_mask = torch.zeros((seq_length, seq_length), device=data.device) - # h = w = 32 - # k1=9 - # layout = [10, 64, 64+h*w, 64+h*w*5] - # for i in range(layout[1]): - # attention_mask[i, :i+1] = 1 - # for i in range(layout[1], layout[2]): - # x = (i - layout[1]) // w - # y = (i - layout[1]) % w - # lx = max(0, x - k1 // 2) - # ly = max(0, y - k1 // 2) - # rx = min(h-1, x + k1 // 2) - # ry = min(w-1, y + k1 // 2) - # attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1 - # attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 - # attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) - elif args.sparse_config.sparse_type == 'torch_1d': + else: raise NotImplementedError # Loss mask. @@ -252,26 +241,19 @@ def get_masks_and_position_ids(data, loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device) # Position ids. - if args is not None and args.finetune and args.max_position_embeddings < args.max_position_embeddings_finetune: - # for each sample, find [ROI2] and split - # ([ROI1] text... [BOI1] img... [EOI1] [ROI2]<pos_id==1089> ...) - start_token = get_tokenizer()['[ROI2]'] - tmp = torch.nonzero(data == start_token, as_tuple=False) - start_token_poses = [100000] * batch_size - for x, y in tmp: - start_token_poses[x] = min(start_token_poses[x], y) - assert 100000 not in start_token_poses, 'Some samples do not have [ROI2]!' 
+ if args.sparse_config.sparse_type == 'cuda_2d': + assert loss_mask is not None + layout = args.layout + assert seq_length == layout[-1] + n_pads = seq_length - loss_mask.sum(dim=-1).long() position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long, device=data.device) for i in range(batch_size): - sep = start_token_poses[i] - torch.arange(start=0, end=sep, out=position_ids[i, :sep], + torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]], dtype=torch.long, device=data.device) - second_pos = 0 # reuse - torch.arange(start=second_pos, end=second_pos + seq_length - sep, - out=position_ids[i, sep:], + torch.arange(layout[2] - layout[1], + out=position_ids[i, layout[1]:], dtype=torch.long, device=data.device) - position_ids[position_ids >= args.max_position_embeddings] = args.max_position_embeddings - 1 else: position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) @@ -298,28 +280,9 @@ def get_batch(data_iterator, args, timers): tokens_ = data_b['text'].long() loss_mask = data_b['loss_mask'].float() - #FIXME change order - if args.sparse_config.sparse_type == 'cuda_2d': - assert args.sparse_config.layout[-1] == 64+32**2+64**2 - tokens_new = tokens_.clone() - tokens_new[:, 64:64+32**2] = tokens_[:, 64+64**2:] - tokens_new[:, 64+32**2:] = tokens_[:, 64:64+64**2] - tokens_ = tokens_new - # only 32# - # tokens_ = tokens_[:, :64+32**2] - # loss_mask = loss_mask[:, :64+32**2] - - if args.dataset_type == 'BinaryDataset': - labels = tokens_.contiguous() - loss_mask = loss_mask.contiguous() - tokenizer = get_tokenizer() - cls_token = torch.zeros(tokens_.shape[0], 1, dtype=tokens_.dtype, device=tokens_.device) + tokenizer['[CLS]'] - tokens = torch.cat((cls_token, tokens_[:, :-1]), dim=1) - tokens[:, 64] = tokenizer['[BASE]'] - else: - labels = tokens_[:, 1:].contiguous() - loss_mask = loss_mask[:, 1:].contiguous() - tokens = tokens_[:, :-1].contiguous() + labels = tokens_[:, 1:].contiguous() + loss_mask = loss_mask[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() attention_mask = None @@ -344,7 +307,6 @@ def forward_step(data_iterator, model, args, timers, mems): timers('batch generator').start() tokens, labels, loss_mask, attention_mask, position_ids = get_batch( data_iterator, args, timers) - timers('batch generator').stop() # split img & txt positions, [PAD] not included # TODO check enough @@ -352,7 +314,6 @@ def forward_step(data_iterator, model, args, timers, mems): img_txt_sep = tokenizer.img_tokenizer.num_tokens img_indices_bool = (tokens.detach() < img_txt_sep) & (loss_mask > 0) txt_indices_bool = (~img_indices_bool) & (loss_mask > 0) - # Forward model. 
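For reference, a small self-contained sketch of what the rewritten cuda_2d branch of get_masks_and_position_ids produces for one left-padded sample; the toy sizes (4 text slots, a 2x2 base grid, a 4x4 high-resolution grid) and the hand-set number of padded positions are invented for illustration:

import torch

layout = [4, 4 + 2 * 2, 4 + 2 * 2 + 4 * 4]       # toy stand-in for [64, 1088, 5184]
seq_length = layout[-1]
batch_size = 1

# pretend this sample has 3 masked positions at the front ([PAD]s / [CLS]),
# so loss_mask is 0 there and 1 everywhere else
loss_mask = torch.ones(batch_size, seq_length)
loss_mask[0, :3] = 0

# attention mask: causal over the first layout[1] positions, padded columns zeroed,
# diagonal forced to 1 so padded rows still attend to themselves
attention_mask = loss_mask[:, :layout[1]].unsqueeze(-2).expand(batch_size, layout[1], layout[1]).tril()
for i in range(batch_size):
    attention_mask[i].fill_diagonal_(1)
attention_mask = attention_mask.unsqueeze(1)     # [batch, 1, layout[1], layout[1]]

# position ids: one ramp over the unpadded text + base-resolution part, then a fresh
# ramp over the high-resolution part (which presumably indexes the newly added,
# finetuned position embedding rather than the original 1089-slot table)
n_pads = (seq_length - loss_mask.sum(dim=-1)).long()
position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)
for i in range(batch_size):
    s = int(n_pads[i])
    torch.arange(layout[1] - s, out=position_ids[i, s:layout[1]])
    torch.arange(layout[2] - layout[1], out=position_ids[i, layout[1]:])

print(position_ids[0].tolist())   # [0, 0, 0, 0, 1, 2, 3, 4, 0, 1, ..., 15]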
logits, *mems = model(tokens, position_ids, attention_mask, *mems) losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), @@ -363,15 +324,14 @@ losses = losses.view(-1) * loss_mask loss = torch.sum(losses) / loss_mask.sum() - # ===================== Log partial losses ======================== # if args.sparse_config.sparse_type == 'cuda_2d': img_indices_bool2 = img_indices_bool.clone() - img_indices_bool2[:, :args.sparse_config.layout[2]] = False + img_indices_bool2[:, :args.sparse_config.layout[1]] = False img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1) torch.distributed.all_reduce(img_loss2.data) img_loss2.data = img_loss2.data / args.world_size - img_indices_bool[:, args.sparse_config.layout[2]:] = False + img_indices_bool[:, args.sparse_config.layout[1]:] = False else: img_loss2 = 0 img_indices_bool = img_indices_bool.view(-1) @@ -386,9 +346,6 @@ txt_loss.data = txt_loss.data / args.world_size # ===================== END OF BLOCK ======================= # - # import pdb;pdb.set_trace() - # with open('tmp_save.bin', 'wb') as fout: - # torch.save(tokens, fout) return loss, mems, img_loss, txt_loss, img_loss2 @@ -471,7 +428,6 @@ timers('backward').start() lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers) timers('backward').stop() - # Update parameters. skipped_iter, complete = 0, False timers('optimizer').start() @@ -674,7 +630,14 @@ def evaluate(data_iterator, model, args, timers, verbose=False): def evaluate_and_print_results(prefix, data_iterator, model, args, timers, verbose=False, step=None, summary_writer=None): """Helper function to evaluate and dump results on screen.""" + # import line_profiler + # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward) + # profile.enable() + # torch.cuda.empty_cache() lm_loss = evaluate(data_iterator, model, args, timers, verbose) + # profile.disable() # stop profiling + # import sys + # profile.print_stats(sys.stdout) lm_ppl = math.exp(min(20, lm_loss)) report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step) diff --git a/random_display.py b/random_display.py index c59c51adaf56937df7c1bcef072759890006ca21..32096ddbf43b589665943d94bab91b701bec4501 100644 --- a/random_display.py +++ b/random_display.py @@ -6,12 +6,12 @@ import torch import random test_dir = 'tmp' # bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin' -bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing005/quanjing005.bin.part_0.cogdata' +bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing003/quanjing003.bin.part_0.cogdata' bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False) args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None) tokenizer = get_tokenizer(args) -bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(16)] +bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(32)] for x in bin_ds: end = x.tolist().index(-1) print(tokenizer.DecodeIds(x[:end])[0]) diff --git a/scripts/cuda_2d_text2image.sh b/scripts/cuda_2d_text2image.sh index
46783ccda9f37af232fd7f46de97c96463d5fccf..3f03974a9f479821f415107e4cefb84efae97ed2 100755 --- a/scripts/cuda_2d_text2image.sh +++ b/scripts/cuda_2d_text2image.sh @@ -1,16 +1,16 @@ #!/bin/bash -CHECKPOINT_PATH=data/checkpoints/cogview-fixgrad-small08-25-09-38 +CHECKPOINT_PATH=data/checkpoints/cogview-long # CHECKPOINT_PATH=data/checkpoints/cogview-compare -NLAYERS=16 -NHIDDEN=1024 -NATT=16 +NLAYERS=48 +NHIDDEN=2560 +NATT=40 MAXSEQLEN=5184 MASTER_PORT=$(shuf -n 1 -i 10000-65535) MPSIZE=1 #SAMPLING ARGS -TEMP=1.05 +TEMP=1.03 #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p TOPK=100 TOPP=0 @@ -25,22 +25,24 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \ --hidden-size $NHIDDEN \ --load $CHECKPOINT_PATH \ --num-attention-heads $NATT \ - --max-position-embeddings 5184 \ + --max-position-embeddings 1089 \ --fp16 \ --temperature $TEMP \ --top_k $TOPK \ --top_p $TOPP \ --sandwich-ln \ --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ - --sparse-type standard \ + --sparse-type cuda_2d \ --max-position-embeddings-finetune $MAXSEQLEN \ --generation-task "cuda-2d generation" \ --input-source ./input.txt \ - --output-path samples_text2image \ - --batch-size 2 \ + --output-path samples_cuda_2d2 \ + --batch-size 3 \ --max-inference-batch-size 4 \ --device 0 \ - --sparse-type standard \ + --finetune \ + --no-load-optim \ + --sparse-type cuda_2d \ $@ diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh index a03cf273fcb587611f27c318d469a96d62af2366..a1245d3f2bd70fac8f5f1809ac6c7cdf79f0f68d 100755 --- a/scripts/pretrain_multiple_nodes.sh +++ b/scripts/pretrain_multiple_nodes.sh @@ -2,7 +2,7 @@ # Change for multinode config -NUM_WORKERS=10 +NUM_WORKERS=19 NUM_GPUS_PER_WORKER=8 MP_SIZE=1 @@ -12,23 +12,22 @@ main_dir=$(dirname $script_dir) # OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" -HOST_FILE_PATH="hostfile2" +HOST_FILE_PATH="hostfile" # OPTIONS_NCCL="" # HOST_FILE_PATH="hostfile_single" small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata" full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin" -config_json="$script_dir/ds_config.json" +config_json="$script_dir/ds_config_zero.json" gpt_options=" \ - --experiment-name cogview-fixgrad-small-test \ + --experiment-name cogview-base-continue-long \ --img-tokenizer-num-tokens 8192 \ - --dataset-type BinaryDataset \ + --dataset-type CompactBinaryDataset \ --model-parallel-size ${MP_SIZE} \ - --num-layers 16 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ - --save $main_dir/data/checkpoints \ + --num-layers 48 \ + --hidden-size 2560 \ + --num-attention-heads 40 \ --train-iters 300000 \ --resume-dataloader \ --train-data ${full_data} \ @@ -38,24 +37,35 @@ gpt_options=" \ --warmup .1 \ --checkpoint-activations \ --deepspeed-activation-checkpointing \ - --max-position-embeddings 5184 \ + --max-position-embeddings 1089 \ --max-memory-length 0 \ --sandwich-ln \ - --txt-loss-scale 10 \ + --txt-loss-scale 0.1 \ --sparse-type cuda_2d \ --fp16 \ --save-interval 2000 \ - --load data/checkpoints/cogview-compare + --no-load-optim \ + --no-save-optim \ + --eval-interval 1000 \ + --save /root/checkpoints \ + --fast-load \ + --load data/checkpoints/cogview-continue \ + --finetune " - # + +# --finetune + # --save $main_dir/data/checkpoints \ + # --restart-iter 
199000 + + gpt_options="${gpt_options} - --deepspeed \ - --deepspeed_config ${config_json} \ + --deepspeed \ + --deepspeed_config ${config_json} \ " - + run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}" echo ${run_cmd} diff --git a/scripts/testnan.sh b/scripts/testnan.sh index a263aa1747a80ac99eb8efa1f57584cff67826b8..2095c0cbad61617442c22799f1938d63fe3cbec2 100755 --- a/scripts/testnan.sh +++ b/scripts/testnan.sh @@ -21,11 +21,11 @@ config_json="$script_dir/ds_config.json" gpt_options=" \ --experiment-name cogview-testlocal \ --img-tokenizer-num-tokens 8192 \ - --dataset-type BinaryDataset \ + --dataset-type CompactBinaryDataset \ --model-parallel-size ${MP_SIZE} \ - --num-layers 16 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ + --num-layers 48 \ + --hidden-size 2560 \ + --num-attention-heads 40 \ --save $main_dir/data/checkpoints \ --train-iters 100000 \ --resume-dataloader \ @@ -36,15 +36,19 @@ gpt_options=" \ --warmup .1 \ --checkpoint-activations \ --deepspeed-activation-checkpointing \ - --max-position-embeddings 5184 \ + --max-position-embeddings 1089 \ --max-memory-length 0 \ - --txt-loss-scale 2 \ + --txt-loss-scale 1 \ --sandwich-ln \ - --sparse-type cuda_2d \ + --sparse-type standard \ --save-interval 2500 \ - --load data/checkpoints/cogview-fixgrad-small08-25-09-38 + --fp16 \ + --eval-iters 1000 \ + --load pretrained/cogview/cogview-base " - # --fp16 \ + # + # --load data/checkpoints/cogview-fixgrad-small08-25-09-38 + gpt_options="${gpt_options} diff --git a/scripts/text2image.sh b/scripts/text2image.sh index 38509fbd8c2fd8ad07e6d59249a827bc9a0b4e8e..fdb3bf3a17ddac30b3954a50d9f219691ed13f12 100755 --- a/scripts/text2image.sh +++ b/scripts/text2image.sh @@ -6,7 +6,8 @@ # NHIDDEN=1024 # NATT=16 -CHECKPOINT_PATH=pretrained/cogview/cogview-base +CHECKPOINT_PATH=data/checkpoints/cogview-continue +# CHECKPOINT_PATH=pretrained/cogview/cogview-base NLAYERS=48 NHIDDEN=2560 NATT=40 @@ -42,8 +43,8 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \ --generation-task text2image \ --input-source ./input.txt \ --output-path samples_text2image \ - --batch-size 4 \ - --max-inference-batch-size 4 \ + --batch-size 8 \ + --max-inference-batch-size 8 \ --device 0 \ $@ diff --git a/test_sparse_attention.py b/test_sparse_attention.py index abac1ed99ef230c2054eb997fe815c5cd119c9e0..9716123c45b87b6078201b45937baad70ed7d32a 100644 --- a/test_sparse_attention.py +++ b/test_sparse_attention.py @@ -4,8 +4,7 @@ from tqdm import tqdm import torch import numpy as np -from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d, sparse_attention_2dfull - +from mpu.sparse_transformer import standard_attention, sparse_attention_2d_light def test_sparse_attention_1d(): s, w, times = 4096 + 128, 128, 2 num_pivot = 768 @@ -80,7 +79,7 @@ def test_sparse_attention_2d(): device = 'cuda' b, n_head, hn = 2, 16, 1024 h = w = 32 - layout = [10, 64, 64+h*w, 64+h*w*5] + layout = [64, 64+h*w, 64+h*w*5] k1 = 9 k2 = 7 k1h = k1*2-1 @@ -94,9 +93,6 @@ def test_sparse_attention_2d(): m = mask[0] for i in range(layout[1]): m[i, :i+1] = 1 - m[layout[1]:, :layout[0]] = 1 - for i in tqdm(range(layout[1], layout[2])): - m[i, layout[1]:i+1] = 1 # for i in tqdm(range(layout[1], layout[2])): # x = (i - layout[1]) // w # y = (i - layout[1]) % w @@ -106,15 +102,15 @@ def test_sparse_attention_2d(): # ry = min(w-1, y + k1 // 2) # m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] 
= 1 # m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 - for i in tqdm(range(layout[2], layout[3])): - x = (i - layout[2]) // (2*w) - y = (i - layout[2]) % (2*w) + for i in tqdm(range(layout[1], layout[2])): + x = (i - layout[1]) // (2*w) + y = (i - layout[1]) % (2*w) lx = max(0, x - k1h // 2) ly = max(0, y - k1 // 2) rx = min(2*h-1, x + k1h // 2) ry = min(2*w-1, y + k1 // 2) - m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1 - m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1 + m[i, layout[1]:layout[2]].view(h*2, w*2)[lx:x, ly:ry+1] = 1 + m[i, layout[1]:layout[2]].view(h*2, w*2)[x, ly:y+1] = 1 x = x // 2 y = y // 2 @@ -122,7 +118,7 @@ def test_sparse_attention_2d(): ly = max(0, y - k2 // 2) rx = min(h-1, x + k2 // 2) ry = min(w-1, y + k2 // 2) - m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1 + m[i, layout[0]:layout[1]].view(h, w)[lx:rx+1, ly:ry+1] = 1 mask[1:] = mask[0] # mask[1][layout[1]:, layout[0]-1] = 0 @@ -133,15 +129,18 @@ def test_sparse_attention_2d(): torch.cuda.synchronize() t0 = time.time() qkv_tmp = qkv.view(3, b, layout[-1], n_head, hn//n_head).permute(0, 1, 3, 2, 4).contiguous() - r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[3], hn) + r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[2], hn) torch.cuda.synchronize() t1 = time.time() - r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2) + # r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2) + qkv20, qkv21 = qkv2[:, :, :layout[1]], qkv2[:, :, layout[1]:] + r20, r21 = sparse_attention_2d_light(*qkv20, *qkv21, mask[...,:layout[1],:layout[1]].unsqueeze(1), n_head, layout[0],kernel_size=k1, kernel_size2=k2) + r2 = torch.cat((r20, r21), dim=1) torch.cuda.synchronize() t2 = time.time() print('times: standard ', t1-t0, ' sparse ', t2-t1) - print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max()) + print(( (r1[:,:layout[1]]-r2[:,:layout[1]]).abs() / (r1[:,:layout[1]].abs()+r2[:,:layout[1]].abs())).max()) print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max()) qkv.retain_grad() l2 = r2[:,layout[1]:].sum() @@ -153,7 +152,7 @@ def test_sparse_attention_2d(): g1 = qkv.grad g2 = qkv2.grad print( (g1-g2).abs().max()) - print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-5)).max()) + print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-3)).max()) import pdb;pdb.set_trace() diff --git a/utils.py b/utils.py index c5079ce515c1559616b3e6d24dbf365791b82cb9..0d4c51d64c77299f56b9e6861b376a4a37808518 100755 --- a/utils.py +++ b/utils.py @@ -248,8 +248,34 @@ def save_ds_checkpoint(iteration, model, lr_scheduler, args): sd['torch_rng_state'] = torch.get_rng_state() sd['cuda_rng_state'] = torch.cuda.get_rng_state() sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() + if args.no_save_optim: + save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd) + else: + model.save_checkpoint(args.save, str(iteration), client_state=sd) + +def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True): + + os.makedirs(save_dir, exist_ok=True) + + if tag is None: + tag = f"global_step{model.global_steps}" + + # Ensure tag is a string + tag = str(tag) + + # Ensure checkpoint tag is consistent across ranks + model._checkpoint_tag_validation(tag) + + if 
model.save_non_zero_checkpoint: + model._create_checkpoint_file(save_dir, tag, False) + model._save_checkpoint(save_dir, tag, client_state=client_state) + + # Save latest checkpoint tag + if save_latest: + with open(os.path.join(save_dir, 'latest'), 'w') as fd: + fd.write(tag) - model.save_checkpoint(args.save, str(iteration), client_state=sd) + return True def get_checkpoint_iteration(args): @@ -296,8 +322,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states= if args.deepspeed: - checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim) - if "client_lr_scheduler" in sd: + checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune) + if args.finetune: + model.module.module.init_plus_from_old() + if (args.finetune or args.no_load_optim) and model.zero_optimization(): + model.optimizer.refresh_fp32_params() + if "client_lr_scheduler" in sd and not args.finetune: lr_scheduler.load_state_dict(sd["client_lr_scheduler"]) print_rank_0("Load lr scheduler state") if checkpoint_name is None:
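Taken together with the freezing loop added to setup_model_and_optimizer above, the finetuning path these utils.py changes support amounts to: load the old weights non-strictly, initialise the new modules from them, and train only the newly added parameters. A generic PyTorch sketch of that freezing pattern, using invented module names rather than the repo's real ones:

import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for the real transformer: one pretrained table plus one new '_plus' table."""
    def __init__(self):
        super().__init__()
        self.position_embeddings = nn.Embedding(1089, 16)        # pretrained
        self.position_embeddings_plus = nn.Embedding(4096, 16)   # hypothetical new module

model = ToyModel()

# freeze everything, then re-enable only the newly added parameters,
# mirroring the `name.find('_plus') > 0` check in setup_model_and_optimizer
model.requires_grad_(False)
for name, param in model.named_parameters():
    if '_plus' in name:
        param.requires_grad_(True)

print([n for n, p in model.named_parameters() if p.requires_grad])
# -> only position_embeddings_plus.weight remains trainable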