diff --git a/arguments.py b/arguments.py index 067a9a7cc579a49bc0374963f46e836b1636076c..b06c987ea51f430abef05a49caa2df6734fa8a98 100755 --- a/arguments.py +++ b/arguments.py @@ -218,7 +218,8 @@ def add_text_generate_args(parser): 'super-resolution', 'low-level super-resolution', 'post-selection', - 'raw' + 'raw', + 'cuda-2d generation' ], help='what type of inference task to use') group.add_argument('--input-source', type=str, default='interactive', @@ -300,7 +301,7 @@ def add_sparse_args(parser): group.add_argument("--key-window-times", type=int, default=6) group.add_argument("--num-pivot", type=int, default=768) # for cuda_2d - group.add_argument("--kernel-size", type=int, default=9) + group.add_argument("--kernel-size", type=int, default=11) group.add_argument("--kernel-size2", type=int, default=7) group.add_argument("--layout", type=str, default='0,64,1088,5184') return parser @@ -309,7 +310,7 @@ def make_sparse_config(args): sparse_config = argparse.Namespace(sparse_type=args.sparse_type) if args.sparse_type == 'standard': pass - elif args.sparse_type == 'cuda_2d': + if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation': sparse_config.kernel_size = args.kernel_size sparse_config.kernel_size2 = args.kernel_size2 sparse_config.layout = [int(x) for x in args.layout.split(',')] diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py index 4e4cb0b5ca4d5ba382777eba71a71fb6ab7d248d..08d6a59d9176f778819c5bf70703254ea5abafe0 100755 --- a/data_utils/configure_data.py +++ b/data_utils/configure_data.py @@ -37,12 +37,13 @@ def make_data_loader(dataset, batch_size, num_iters, args): drop_last = distributed # the GPUs in the same model parallel group receive the same data if distributed: + gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1) batch_sampler = DistributedBatchSampler(sampler, batch_size, drop_last, rank, world_size, - gradient_accumulation_steps=args.gradient_accumulation_steps) + gradient_accumulation_steps=gradient_accumulation_steps) else: batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, diff --git a/draw_diff.py b/draw_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd12a980c6e3b2ea859726d9ad1202331ed22b4 --- /dev/null +++ b/draw_diff.py @@ -0,0 +1,31 @@ +import numpy as np +import torch +def loadbao(name): + ret = [] + with open(name, 'r') as fin: + for line in fin: + a, b = line.split() + ret.append(abs(float(b))) + return ret +import torchvision +import torchvision.transforms as transforms + +def sq(img, x, y, lx, ly): + assert len(img.shape) == 3 + img[:,x:x+lx,y] = torch.tensor([0,1,0]).unsqueeze(-1) + img[:,x:x+lx,y+ly] = torch.tensor([0,1,0]).unsqueeze(-1) + img[:,x,y:y+ly] = torch.tensor([0,1,0]).unsqueeze(-1) + img[:,x+lx,y:y+ly] = torch.tensor([0,1,0]).unsqueeze(-1) + +transform = transforms.Compose([ + transforms.Resize(512), + transforms.CenterCrop(512), + ]) +img = torchvision.io.read_image('bao.jpeg') +img = transform(img) / 255. 
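+# Compare the per-patch values logged in bao.txt and bao2.txt (presumably
+# per-token losses from two runs) and draw a green outline around each
+# 16x16-pixel patch whose first value exceeds the second by more than 2;
+# flat indices map onto the 32x32 patch grid of the 512x512 image.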
+a = np.array(loadbao('bao.txt')) +b = np.array(loadbao('bao2.txt')) +for t in (a-b>2).nonzero()[0]: + x,y = t // 32, t % 32 + sq(img, x*16, y*16, 15, 15) +torchvision.utils.save_image(img, 'example_bao.jpg') diff --git a/env/baai_setup_connection.py b/env/baai_setup_connection.py new file mode 100644 index 0000000000000000000000000000000000000000..420e90c136e0831f7ea8028326aebb94d9e0266c --- /dev/null +++ b/env/baai_setup_connection.py @@ -0,0 +1,17 @@ +import json +import os +import sys + +with open('/home/hostfile.json', 'r') as fin: + t = json.load(fin) +input_txt_path = os.path.join(os.path.dirname(__file__), 'input.txt') +with open(input_txt_path, 'w') as fout: + ip_list = [] + for x in t: + fout.write(x['ip']) + fout.write(' ') + ip_list.append(x['ip']) +sys.path.append(os.path.dirname(__file__)) +from setup_connection import main +main(ip_list, 22) + diff --git a/env/setup_connection.py b/env/setup_connection.py index 71a41340acb5738fd4e3dea11b9af01d59528a43..a17cee03841a0e56345187b31d1b8fef6829e631 100644 --- a/env/setup_connection.py +++ b/env/setup_connection.py @@ -13,19 +13,20 @@ import math import random import base64 -if __name__ == "__main__": +def main(ip_list, port=2222): ssh_config = '' - line_format = 'Host node{}\n\tUser root\n\tPort 2222\n\tHostname {}\n' - for i, ip in enumerate(sys.argv[1:]): - ssh_config += line_format.format(i, ip) + line_format = 'Host node{}\n\tUser root\n\tPort {}\n\tHostname {}\n' + for i, ip in enumerate(ip_list): + ssh_config += line_format.format(i, port, ip) ret = os.system(f'echo \"{ssh_config}\" > ~/.ssh/config && chmod 600 ~/.ssh/config') assert ret == 0 hostfile_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'hostfile') with open(hostfile_path, 'w') as fout: - for i, ip in enumerate(sys.argv[1:]): + for i, ip in enumerate(ip_list): fout.write(f'node{i} slots=8\n') print(f'Successfully generated hostfile \'{hostfile_path}\'!') - +if __name__ == "__main__": + main(sys.argv[1:]) \ No newline at end of file diff --git a/generate_samples.py b/generate_samples.py index b31630b3208e10a78061b50085bc7ecf37927dca..aff3e469b9f458bb93e63166994a663f31532933 100755 --- a/generate_samples.py +++ b/generate_samples.py @@ -41,14 +41,13 @@ from pretrain_gpt2 import get_model import math from copy import deepcopy from tqdm import tqdm -from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score +from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local from torchvision.utils import save_image import torch.distributed as dist def setup_model(args): """Setup model and optimizer.""" - model = get_model(args) if args.load is not None: @@ -163,12 +162,13 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template= assert num < mbz or num % mbz == 0 output_tokens_list = [] for tim in range(max(num // mbz, 1)): - import line_profiler - from mpu.sparse_transformer import standard_attention + # import line_profiler + # from mpu.sparse_transformer import standard_attention # profile = line_profiler.LineProfiler(model.module.forward) # profile = line_profiler.LineProfiler(standard_attention) # profile.enable() - output_tokens_list.append(filling_sequence(model, seq.clone(), args)) + fill_fn = filling_sequence_local if args.generation_task == 'cuda-2d generation' else filling_sequence + output_tokens_list.append(fill_fn(model, seq.clone(), args)) + # torch.cuda.empty_cache() # profile.disable() # stop profiling
# import sys @@ -182,8 +182,11 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template= for seq in output_tokens_list: decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist()) for i in range(len(decoded_imgs)): - if decoded_imgs[i].shape[-1] == 128: - decoded_imgs[i] = torch.nn.functional.interpolate(decoded_imgs[i], size=(256, 256)) + if decoded_imgs[i].shape[-1] < 512: + decoded_imgs[i] = torch.nn.functional.interpolate(decoded_imgs[i], size=(512, 512)) + # decoded_imgs[i].view(3, 32, 16, 32, 16)[:, :, :4, :, :4] = 0 + # decoded_imgs[i].view(3, 32, 16, 32, 16)[0, :, :4, :, :4] = 1 + # decoded_imgs[i].view(3, 32, 16, 32, 16)[1, :12, :4, :16, :4] = 1 if args.debug: imgs.extend(decoded_imgs) else: @@ -209,7 +212,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template= def generate_images_continually(model, args): if args.generation_task == 'text2image': - query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' + query_template = '[ROI1] {} [BASE] [BOI1] [Image200]bao.jpeg' elif args.generation_task == 'image2text': query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] [MASK]*20' elif args.generation_task == 'low-level super-resolution': @@ -218,6 +221,8 @@ def generate_images_continually(model, args): query_template = '[ROI1] {} [BASE] [BOI1] [Image]{}' elif args.generation_task == 'post-selection': query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] {}' + elif args.generation_task == 'cuda-2d generation': + query_template = '[CLS] {} [BASE] [Image200]bao.jpeg [MASK]*4096' else: raise NotImplementedError for raw_text, seq, output_path in get_context(args, query_template): diff --git a/generation/__init__.py b/generation/__init__.py index 8620c9c18c17f4c46a59f42abf6adc16785a9ab2..b73ab2313cef87d5306e653894797d211f4a2c36 100755 --- a/generation/__init__.py +++ b/generation/__init__.py @@ -1,2 +1,3 @@ from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score -from .magnify import magnify \ No newline at end of file +from .magnify import magnify +from .local_sampling import filling_sequence_local \ No newline at end of file diff --git a/generation/local_sampling.py b/generation/local_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..9fba71aceb1c710076302889928f49f5bbafb1d1 --- /dev/null +++ b/generation/local_sampling.py @@ -0,0 +1,147 @@ +from .sampling import * +import math +import sys +from copy import deepcopy + +def make_local_mask(sparse_config): + layout = sparse_config.layout + k1, k2 = sparse_config.kernel_size, sparse_config.kernel_size2 + k1h = k1*2-1 + h = w = int(math.sqrt(layout[2] - layout[1]) + 1e-3) + m = torch.zeros(layout[-1]+1, layout[-1], dtype=torch.bool, device='cuda') + for i in range(layout[1]): + m[i, :i] = True + m[layout[1]:, :layout[0]] = True + for i in tqdm(range(layout[1], layout[2])): + # m[i, layout[1]:i] = True + x = (i - layout[1]) // w + y = (i - layout[1]) % w + lx = max(0, x - k1h // 2) + ly = max(0, y - k1 // 2) + rx = min(h-1, x + k1h // 2) + ry = min(w-1, y + k1 // 2) + m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = True + m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = True + m[i, i] = False + for i in tqdm(range(layout[2], layout[3])): + x = (i - layout[2]) // (2*w) + y = (i - layout[2]) % (2*w) + lx = max(0, x - k1h // 2) + ly = max(0, y - k1 // 2) + rx = min(2*h-1, x + k1h // 2) + ry = min(2*w-1, y + k1 // 2) + m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = True + m[i, 
layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = True + x = x // 2 + y = y // 2 + lx = max(0, x - k2 // 2) + ly = max(0, y - k2 // 2) + rx = min(h-1, x + k2 // 2) + ry = min(w-1, y + k2 // 2) + m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = True + m[i, i] = False + return m.unsqueeze(-1) + +def update_mems_local(hiddens, mems, start, end, mem_bag, mask): + if isinstance(hiddens, list): + hiddens = torch.stack(hiddens) + mem_bag[:, :, start:end] = hiddens.to('cuda') + # first level + del mems + mems = mem_bag.masked_select(mask[end]).view(*mem_bag.shape[:2], -1, mem_bag.shape[3]).to(hiddens.device) + return mems + + +def filling_sequence_local( + model, + seq, + args, + mems=None, + invalid_slices=[], + **kwargs): + ''' + seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1] + context_length: first non(-1)s + ''' + loss_sum, loss_n = 0, 0.0001 + tokenizer = get_tokenizer() + device = seq.device + assert len(seq.shape) == 1 + assert args.sparse_config.sparse_type == 'standard' + sparse_config = deepcopy(args.sparse_config) + # with open('tmp_save.bin', 'rb') as fout: + # seq = torch.load(fout)[1] + # seq = torch.cat((seq, torch.tensor([-1], device=seq.device))) + # import pdb; pdb.set_trace() + # sparse_config.layout[0] = seq.tolist().index(-1) + sparse_config.layout[0] = seq.tolist().index(tokenizer['[BASE]']) + n_pad = sparse_config.layout[1] - sparse_config.layout[0] + assert n_pad > 0 # TODO long trunc + seq = torch.cat((seq[:sparse_config.layout[0]], torch.tensor([tokenizer['[POS8]']]* n_pad, device=seq.device, dtype=seq.dtype), seq[sparse_config.layout[0]:])) + out_seq_length = len(seq) + # building the initial tokens, attention_mask, and position_ids + context_length = sparse_config.layout[1] + 1 + + tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args) + tokens = tokens.expand(-min(seq), *tokens.shape[1:]) + + counter = context_length - 1 # == len(tokens) - 1 + index = 0 # len(mems) + if mems is None: + mems = [] + mem_bag = torch.zeros(args.num_layers, tokens.shape[0], out_seq_length-1, 2*args.hidden_size, device='cuda') + local_mask = make_local_mask(sparse_config) + + while counter < (out_seq_length - 1): + if counter % 100 == 0: + print(counter, loss_sum / loss_n, file=sys.stderr) + # Now, we want to generate seq[counter + 1] + # token[:, index: counter+1] are just added. 
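+        # First pass: feed the whole context and cache every layer's (k, v)
+        # into mem_bag. Later passes feed only the newly added tokens, and
+        # update_mems_local re-selects from mem_bag just the cached positions
+        # visible to the next token under the precomputed local mask.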
+ # import pdb;pdb.set_trace() + + if index == 0: # first + logits, *qkv = model(tokens, position_ids, attention_mask, *mems) + mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask) + index = counter + else: + assert tokens.shape[1] == counter + 1 + position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0) + logits, *qkv = model(tokens[:, index: ], + position_ids, + 0, # rebuild in transformers (sep version) + *mems + ) + mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask) + index = counter + counter += 1 + index += 1 + + if seq[counter] >= 0: # provided + tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1) + loss_n += 1 + loss_this = F.log_softmax(logits, dim=-1)[:, -1, seq[counter]].mean() + print(counter-64, loss_this.item()) + loss_sum -= loss_this + continue + + logits = logits[:, -1] # [batch size, vocab size] + temp = args.temperature + logits /= temp + for invalid_slice in invalid_slices: # forbid generating other tokens + logits[..., invalid_slice] = -float('Inf') + logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) + log_probs = F.softmax(logits, dim=-1) + + # expand beams + prev = torch.multinomial(log_probs, num_samples=1) + tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1) + + output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous() + output_tokens_list = torch.cat( + ( + output_tokens_list[:, :sparse_config.layout[0]], + output_tokens_list[:, sparse_config.layout[1]+1:sparse_config.layout[2]+1], + torch.tensor([[tokenizer['[EOI1]']]], dtype=tokens.dtype, device=tokens.device).expand(output_tokens_list.shape[0], 1), + output_tokens_list[:, sparse_config.layout[2]+1:] + ), dim=1) + return output_tokens_list \ No newline at end of file diff --git a/generation/sampling.py b/generation/sampling.py index 8cf03263e78ce4d80ddc2dfa5832eb22976b4997..4c3acd2501e50f74589884963a87a76840d019dd 100755 --- a/generation/sampling.py +++ b/generation/sampling.py @@ -20,7 +20,6 @@ import torch.nn.functional as F from pretrain_gpt2 import get_masks_and_position_ids from data_utils import get_tokenizer - def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): # This function has been mostly taken from huggingface conversational ai code at # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 @@ -61,6 +60,19 @@ def get_batch(context_tokens, device, args): tokens, args=args) return tokens, attention_mask, position_ids +def update_mems(hiddens, mems, max_memory_length=10000): + memory_length = mems[0].size(1) if mems else 0 + query_length = hiddens[0].size(1) + new_memory_length = min(max_memory_length, memory_length + query_length) + new_mems = [] + with torch.no_grad(): + for i in range(len(hiddens)): + if new_memory_length <= query_length: + new_mems.append(hiddens[i][:, -new_memory_length:]) + else: + new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1)) + return new_mems + def filling_sequence( model, seq, @@ -115,8 +127,15 @@ def filling_sequence( if index == 0: # first position_ids[position_ids > offset] -= offset - logits, *mems = model(tokens, position_ids, attention_mask, *mems) + logits, *qkv = model(tokens, position_ids, attention_mask, *mems) + mems = update_mems(qkv, mems) + + tmp = -F.log_softmax(logits, dim=-1) + tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0] + for i in range(1,len(tmp)):
+ print(i, tmp[i].item()) index = counter + print(tmp[1:].mean(), file=sys.stderr) elif seq[counter + 1] >= 0: # provided if seq[counter + 1] == tokenizer['[ROI2]']: offset = counter + 1 @@ -131,15 +150,18 @@ def filling_sequence( position_ids[position_ids > offset] -= offset # TODO each time, the feed input cannot be too long (window size), or it will have a discrepancy from sparse training, but this is not very important. tokens, mems, score = shrink_beams(tokens, mems, -seq[counter + 1], score) - logits, *mems = model(tokens[:, index: ], + logits, *qkv = model(tokens[:, index: ], position_ids, 0, # rebuild in transformers (sep version) *mems) + mems = update_mems(qkv, mems) + index = counter nb = -seq[counter + 1] counter += 1 index += 1 + logits = logits[:, -1] # [batch size, vocab size] temp = args.temperature @@ -180,7 +202,7 @@ def shrink_beams(tokens, mems, nb, score): new_mems = [mem[max_idx: max_idx + 1] for mem in mems] return tokens, new_mems, score -def add_interlacing_beam_marks(seq, nb=12, period=3000): +def add_interlacing_beam_marks(seq, nb=12, period=30000): assert isinstance(seq, list) or len(seq.shape) == 1 blk_cnt = 0 for i in range(len(seq)): @@ -203,7 +225,9 @@ def inverse_prompt_score(model, seq, args): assert tokenizer['[ROI1]'] == seq[0][botext] tokens, attention_mask, position_ids = get_batch(seq, device, args) - logits, *mems = model(tokens, position_ids, attention_mask) + logits, *qkv = model(tokens, position_ids, attention_mask) + mems = update_mems(qkv, []) + logits[..., :tokenizer.img_tokenizer.num_tokens] = -float('Inf') log_probs = torch.log(F.softmax(logits, dim=-1)) diff --git a/mpu/local_attention_function.py b/mpu/local_attention_function.py index 5ba073ced5728ff51fcb03115f76afcbaae42451..1a5af91d91310bb2d84112ad5349cf09c2fa2950 100644 --- a/mpu/local_attention_function.py +++ b/mpu/local_attention_function.py @@ -28,6 +28,7 @@ class similarFunction(Function): x_ori, x_loc = ctx.saved_tensors kH, kW = ctx.kHW casual_mask = ctx.casual_mask + grad_outputs = grad_outputs.contiguous() grad_ori = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, True, casual_mask) grad_loc = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, False, casual_mask) @@ -50,6 +51,7 @@ class weightingFunction(Function): x_ori, x_weight = ctx.saved_tensors kH, kW = ctx.kHW casual_mask = ctx.casual_mask + grad_outputs = grad_outputs.contiguous() grad_ori = weighting_backward_ori(x_ori, x_weight, grad_outputs, kH, kW, casual_mask) grad_weight = weighting_backward_weight(x_ori, x_weight, grad_outputs, kH, kW, casual_mask) @@ -139,11 +141,3 @@ class TorchLocalAttention(nn.Module): return out - -if __name__ == '__main__': - b, c, h, w = 8, 3, 32, 32 - kH, kW = 5, 5 - x = torch.rand(b, c, h, w).cuda() - m = LocalAttention(c, c, kH, kW) - m.cuda() - y = m(x) \ No newline at end of file diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py index 33b5d42758385c222e1bd30bf53279de21e73eb2..7adc0dfbb13e25aba14aa64fa1f0285a457d2bbf 100755 --- a/mpu/sparse_transformer.py +++ b/mpu/sparse_transformer.py @@ -75,7 +75,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module): """ def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, - init_method, output_layer_init_method=None): + init_method, output_layer_init_method=None, sparse_config=None): super(GPT2ParallelSelfAttention, self).__init__() # Set output layer initialization if not provided. 
if output_layer_init_method is None: @@ -111,6 +111,11 @@ class GPT2ParallelSelfAttention(torch.nn.Module): get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint + # self.offset_bias = torch.nn.Parameter( + # torch.ones(num_attention_heads, sparse_config.kernel_size**2//2+1 +sparse_config.kernel_size2**2)) + + self.sparse_config = sparse_config + def _transpose_for_scores(self, tensor): """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. @@ -122,20 +127,21 @@ class GPT2ParallelSelfAttention(torch.nn.Module): return tensor.permute(0, 2, 1, 3) - def forward(self, hidden_states, mask, sparse_config, mem=None): + def forward(self, hidden_states, mask, mem=None): + # import pdb;pdb.set_trace() + sparse_config = self.sparse_config # hidden_states: [b, s, h] # ltor_mask: [1, 1, s, s] # Attention heads. [b, s, hp] query_length = hidden_states.size(1) - # if mem is None: mixed_raw_layer = self.query_key_value(hidden_states) (mixed_query_layer, mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) - if mem is not None: + if mem is not None and len(mem) > 0: memk, memv = split_tensor_along_last_dim(mem, 2) mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1) mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1) @@ -152,7 +158,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module): if sparse_config.sparse_type == 'standard': context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) else: - context_layer = sparse_attention(query_layer, key_layer, value_layer, sparse_config.pivot_idx, + context_layer = sparse_attention_1d(query_layer, key_layer, value_layer, sparse_config.pivot_idx, mask, sparse_config.query_window, sparse_config.key_window_times, dropout_fn) # inference: context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx) @@ -163,12 +169,14 @@ class GPT2ParallelSelfAttention(torch.nn.Module): context_layer = context_layer.view(*new_context_layer_shape) elif sparse_config.sparse_type == 'cuda_2d': - context_layer = sparse_attention_2d(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition, - sparse_config.layout, mask, sparse_config.kernel_size, sparse_config.kernel_size2, attention_dropout=dropout_fn) + context_layer = sparse_attention_2dfull(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition, + sparse_config.layout, mask, sparse_config.kernel_size, + kernel_size2=sparse_config.kernel_size2, + attention_dropout=dropout_fn + ) # Output. [b, s, h] output = self.dense(context_layer) - if self.training: output = self.output_dropout(output) @@ -228,6 +236,7 @@ class GPT2ParallelMLP(torch.nn.Module): def forward(self, hidden_states): # [b, s, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = gelu(intermediate_parallel) @@ -292,7 +301,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): attention_dropout_prob, output_dropout_prob, init_method, - output_layer_init_method=output_layer_init_method) + output_layer_init_method=output_layer_init_method, + sparse_config=sparse_config + ) # Layernorm on the input data. self.post_attention_layernorm = LayerNorm(hidden_size, @@ -320,7 +331,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): # Layer norm at the beginning of the transformer layer. layernorm_output1 = self.input_layernorm(hidden_states) # Self attention. 
- attention_output, qkv = self.attention(layernorm_output1, ltor_mask, self.sparse_config, mem) + attention_output, qkv = self.attention(layernorm_output1, ltor_mask, mem) # Third LayerNorm if self.sandwich_ln: @@ -501,7 +512,12 @@ class GPT2ParallelTransformer(torch.nn.Module): x_, mask, mems_ = inputs[0], inputs[1], inputs[2:] for i, layer in enumerate(layers_): - mem_i_ = mems_[i] if mems_ else None + if mems_: + mem_i_ = mems_[i] + elif self.max_memory_length > 0: + mem_i_ = [] + else: + mem_i_ = None x_, qkv = layer(x_, mask, mem=mem_i_) if self.max_memory_length > 0: mem_layers.append(qkv) @@ -527,30 +543,25 @@ class GPT2ParallelTransformer(torch.nn.Module): for i, layer in enumerate(self.layers): args = [hidden_states, attention_mask_saved] - mem_i = mems[i] if mems else None + if mems: + mem_i = mems[i] + elif self.max_memory_length > 0: + mem_i = [] + else: + mem_i = None hidden_states, qkv = layer(*args, mem=mem_i) if self.max_memory_length > 0: mem_layers.append(qkv) # Final layer norm. output = self.final_layernorm(hidden_states) - if self.max_memory_length > 0: # TODO cache - mem_layers = self.update_mems(mem_layers, mems) + # if self.max_memory_length > 0: # TODO cache + # if self.sparse_config.sparse_type != 'cuda_2d': + # mem_layers = self.update_mems(mem_layers, mems) + # else: + # pass # handle update outside the model, because mems is not the full cached qkv. return (output, *mem_layers) - - def update_mems(self, hiddens, mems): - memory_length = mems[0].size(1) if mems else 0 - query_length = hiddens[0].size(1) - new_memory_length = min(self.max_memory_length, memory_length + query_length) - new_mems = [] - with torch.no_grad(): - for i in range(len(hiddens)): - if new_memory_length <= query_length: - new_mems.append(hiddens[i][:, -new_memory_length:]) - else: - new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1)) - return new_mems def _chunk(x, w, times): @@ -600,7 +611,6 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte # [b, np, s, hn] context_layer = torch.matmul(attention_probs, value_layer) - return context_layer def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None): @@ -682,10 +692,10 @@ def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=1 def transpose_and_split(x, layout, n_head): x = x.transpose(1, 2) - x = x.reshape(x.shape[0] * n_head, x.shape[1] // n_head, x.shape[2]) + x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2]) x_text = x[..., :layout[0]] - x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), -1).contiguous() - x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), -1).contiguous() + x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous() + x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous() return x, x_text, x0, x1 def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): @@ -704,30 +714,35 @@ def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_s q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext] k_all, k_text, k0, k1 = 
transpose_and_split(k, layout, n_head) v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) - # import pdb; pdb.set_trace() # all to text scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) # 0 to 0 - scores_0_to_0 = f_similar(q0, k0, kernel_size, kernel_size, True) + scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True) # 1 to 1 - scores_1_to_1 = f_similar(q1, k1, kernel_size, kernel_size, True) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) # 1 to 0 scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] # softmax - probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] + # if 'offset_bias' in kwargs: + # p1, p2 = kernel_size**2//2 + 1, kernel_size2**2 + # offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2) + # scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1] + # scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1] + # scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:] scores_0 = torch.cat( (scores_all_to_text[:, layout[1]:layout[2]], scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), dim=-1) - probs_0 = F.softmax(scores_0, dim=-1) # scores_1 = torch.cat( (scores_all_to_text[:, layout[2]:layout[3]], scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), dim=-1) + probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] + probs_0 = F.softmax(scores_0, dim=-1) # probs_1 = F.softmax(scores_1, dim=-1) if attention_dropout is not None: @@ -747,11 +762,11 @@ def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_s probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) - context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size, kernel_size, True) + context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True) context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) - context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size, kernel_size, True) + context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True) context_all_to_01 =torch.cat( ( @@ -761,3 +776,78 @@ def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_s ), dim=-1).view(b, hn, -1).transpose(1, 2) return context_all_to_text + context_all_to_01 + +def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): + ''' + q, k, v: [batch_size, 64+1024+4096, hidden_size] + n_head: int + layout: [endoftext/startofpad, startof0, startof1, endofall] + attention_mask_text2d: [batch_size, sq_len, endoftext] + ''' + from .local_attention_function import f_similar, f_weighting + b, sq_len, hn = 
q.shape + alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) + + q = q / math.sqrt(hn // n_head) # normalization + + q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext] + k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head) + v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) + # import pdb; pdb.set_trace() + # all to text + scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) + scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) + # 0 to 0 + if not hasattr(sparse_attention_2dfull, 'attention_mask0'): + sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_() + attention_mask0 = sparse_attention_2dfull.attention_mask0 + scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0) + # 1 to 1 + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) + # 1 to 0 + scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] + # softmax + + scores_0 = torch.cat( + (scores_all_to_text[:, layout[1]:layout[2]], + scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), + dim=-1) + scores_1 = torch.cat( + (scores_all_to_text[:, layout[2]:layout[3]], + scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), + scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), + dim=-1) + probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] + probs_0 = F.softmax(scores_0, dim=-1) # + probs_1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + probs_0 = attention_dropout(probs_0) + probs_1 = attention_dropout(probs_1) + # weighting + pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype) + probs_all_to_text = torch.cat(( + probs_text, + pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]), + probs_0[:, :, :layout[0]], + probs_1[:, :, :layout[0]] + ), dim=1) + + context_all_to_text = torch.einsum('bhij,bhcj->bihc', + probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), + v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) + + context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0)) + + context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) + + context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True) + + context_all_to_01 =torch.cat( + ( + pad.expand(b*n_head, hn//n_head, layout[1]), + context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]), + (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2]) + ), dim=-1).view(b, hn, -1).transpose(1, 2) + return context_all_to_text + context_all_to_01 diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py index f9d83e5184c2f9037a337228f86e9231e0dbba9d..784153e8b65c70c7004525dfdf208fd69c9cf35a 100755 --- a/pretrain_gpt2.py +++ 
b/pretrain_gpt2.py @@ -55,12 +55,10 @@ from data_utils import make_loaders, get_tokenizer, detect_new_datasets import stat -def get_model(args): +def get_model(args, sparse_config=None): """Build the model.""" print_rank_0('building CogView2 model ...') - # print(args.vocab_size) - # ml = max(args.max_position_embeddings, args.max_position_embeddings_finetune) ml = args.max_position_embeddings model = GPT2Model(num_layers=args.num_layers, vocab_size=args.vocab_size, @@ -74,7 +72,7 @@ def get_model(args): checkpoint_activations=args.checkpoint_activations, checkpoint_num_layers=args.checkpoint_num_layers, parallel_output=True, - sparse_config=args.sparse_config, + sparse_config=sparse_config if sparse_config is not None else args.sparse_config, sandwich_ln=args.sandwich_ln ) @@ -218,8 +216,9 @@ def get_masks_and_position_ids(data, # single direction, [PAD]s are at the end of the seq, so doesn't matter. if args.sparse_config.sparse_type == 'cuda_2d': layout = args.sparse_config.layout - unpad_indices = (data[:, :layout[1]] >= 0) * 10000. - starts = (torch.arange(layout[1], device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1] + unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000. + unpad_indices[:, -1] = 9000. + starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1] layout[0] = starts.max().item() attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device) for i in range(batch_size): @@ -229,7 +228,22 @@ def get_masks_and_position_ids(data, elif args.sparse_config.sparse_type == 'standard': attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device) attention_mask.tril_() - attention_mask = attention_mask.unsqueeze(1) + # attention_mask = torch.zeros((seq_length, seq_length), device=data.device) + # h = w = 32 + # k1=9 + # layout = [10, 64, 64+h*w, 64+h*w*5] + # for i in range(layout[1]): + # attention_mask[i, :i+1] = 1 + # for i in range(layout[1], layout[2]): + # x = (i - layout[1]) // w + # y = (i - layout[1]) % w + # lx = max(0, x - k1 // 2) + # ly = max(0, y - k1 // 2) + # rx = min(h-1, x + k1 // 2) + # ry = min(w-1, y + k1 // 2) + # attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1 + # attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 + # attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) elif args.sparse_config.sparse_type == 'torch_1d': raise NotImplementedError @@ -283,6 +297,18 @@ def get_batch(data_iterator, args, timers): # Unpack. tokens_ = data_b['text'].long() loss_mask = data_b['loss_mask'].float() + + #FIXME change order + if args.sparse_config.sparse_type == 'cuda_2d': + assert args.sparse_config.layout[-1] == 64+32**2+64**2 + tokens_new = tokens_.clone() + tokens_new[:, 64:64+32**2] = tokens_[:, 64+64**2:] + tokens_new[:, 64+32**2:] = tokens_[:, 64:64+64**2] + tokens_ = tokens_new + # only 32# + # tokens_ = tokens_[:, :64+32**2] + # loss_mask = loss_mask[:, :64+32**2] + if args.dataset_type == 'BinaryDataset': labels = tokens_.contiguous() loss_mask = loss_mask.contiguous() @@ -326,7 +352,7 @@ def forward_step(data_iterator, model, args, timers, mems): img_txt_sep = tokenizer.img_tokenizer.num_tokens img_indices_bool = (tokens.detach() < img_txt_sep) & (loss_mask > 0) txt_indices_bool = (~img_indices_bool) & (loss_mask > 0) - + # Forward model. 
logits, *mems = model(tokens, position_ids, attention_mask, *mems) losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), @@ -339,6 +365,15 @@ def forward_step(data_iterator, model, args, timers, mems): loss = torch.sum(losses) / loss_mask.sum() # ===================== Log partial losses ======================== # + if args.sparse_config.sparse_type == 'cuda_2d': + img_indices_bool2 = img_indices_bool.clone() + img_indices_bool2[:, :args.sparse_config.layout[2]] = False + img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1) + torch.distributed.all_reduce(img_loss2.data) + img_loss2.data = img_loss2.data / args.world_size + img_indices_bool[:, args.sparse_config.layout[2]:] = False + else: + img_loss2 = 0 img_indices_bool = img_indices_bool.view(-1) txt_indices_bool = txt_indices_bool.view(-1) img_loss = losses[img_indices_bool].detach().sum() / max(img_indices_bool.sum(), 1) @@ -351,8 +386,10 @@ def forward_step(data_iterator, model, args, timers, mems): txt_loss.data = txt_loss.data / args.world_size # ===================== END OF BLOCK ======================= # - - return loss, mems, img_loss, txt_loss + # import pdb;pdb.set_trace() + # with open('tmp_save.bin', 'wb') as fout: + # torch.save(tokens, fout) + return loss, mems, img_loss, txt_loss, img_loss2 def backward_step(optimizer, model, lm_loss, args, timers): @@ -423,12 +460,12 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, while True: # Forward model for one step. timers('forward').start() - lm_loss, mems, img_loss, txt_loss = forward_step(data_iterator, model, args, timers, mems) + lm_loss, mems, img_loss, txt_loss, img_loss2 = forward_step(data_iterator, model, args, timers, mems) timers('forward').stop() if (img_loss + txt_loss).isnan().any() or (img_loss + txt_loss).isinf().any(): print('Skipping backward and optimizer step for nan or inf in the forward pass!') - return (img_loss + txt_loss), 1, mems, img_loss, txt_loss + return (img_loss + txt_loss), 1, mems, img_loss, txt_loss, img_loss2 # Calculate gradients, reduce across processes, and clip. timers('backward').start() @@ -459,15 +496,17 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, timers('optimizer').stop() if complete: break - return lm_loss_reduced, skipped_iter, mems, img_loss, txt_loss + return lm_loss_reduced, skipped_iter, mems, img_loss, txt_loss, img_loss2 -def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, img_loss, txt_loss): +def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, img_loss, txt_loss, img_loss2): log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step) log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time) log_string += ' learning rate {:.3E} |'.format(lr) log_string += ' lm loss {:.6E} |'.format(loss) log_string += ' img loss {:.6E} |'.format(img_loss) + if args.sparse_config.sparse_type == 'cuda_2d': + log_string += ' img loss2 {:.6E} |'.format(img_loss2) log_string += ' unscaled txt loss {:.6E} |'.format(txt_loss) if args.fp16: log_string += ' loss scale {:.1f} |'.format( @@ -514,13 +553,13 @@ def train(model, optimizer, lr_scheduler, if args.iteration % 100 == 0: new_loaders = detect_new_datasets(args) if new_loaders is not None: - print(f'Loatding new datasets ... Now we train models on {args.train_data}.') + print(f'Loading new datasets ... 
Now we train models on {args.train_data}.') train_data_iterator = iter(new_loaders[0]) val_data_iterator = iter(new_loaders[1]) # TODO close the original - lm_loss, skipped_iter, mems, img_loss, txt_loss = train_step(train_data_iterator, + lm_loss, skipped_iter, mems, img_loss, txt_loss, img_loss2 = train_step(train_data_iterator, model, optimizer, lr_scheduler, @@ -544,7 +583,7 @@ def train(model, optimizer, lr_scheduler, elapsed_time = timers('interval time').elapsed() report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss, elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args, - avg_img_loss, avg_txt_loss) + avg_img_loss, avg_txt_loss, img_loss2) total_lm_loss = 0.0 total_img_loss = 0.0 total_txt_loss = 0.0 @@ -588,6 +627,7 @@ def evaluate(data_iterator, model, args, timers, verbose=False): total_lm_loss = 0 mems = [] + # with open('grad_scale_fp32.txt', 'w') as fout: with torch.no_grad(): iteration = 0 while iteration < args.eval_iters: @@ -595,8 +635,21 @@ def evaluate(data_iterator, model, args, timers, verbose=False): if verbose and iteration % args.log_interval == 0: print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters)) # Forward evaluation. - lm_loss, mems, img_loss, txt_loss = forward_step(data_iterator, model, args, timers, mems=mems) - + lm_loss, mems, img_loss, txt_loss, img_loss2 = forward_step(data_iterator, model, args, timers, mems=mems) + + # (lm_loss).backward() + # for name, param in model.named_parameters(): + # v_max = param.data.abs().max().item() + # v_mean = param.data.abs().mean().item() + # fout.write(f'name: {name}, v_max: {v_max}, v_mean: {v_mean}\n') + # if param.grad is not None: + # g_max = param.grad.max().item() + # g_zero_bool = param.grad.abs() == 0 + # g_zero = (g_zero_bool.sum() / param.grad.numel()).item() + # g_nz_mean = param.grad[~g_zero_bool].abs().mean().item() + # fout.write(f'g_max: {g_max}, g_zero: {g_zero}, g_nz_mean: {g_nz_mean}\n') + # fout.flush() + # import pdb;pdb.set_trace() '''when contiguous memory optimizations are enabled, the buffers allocated by the optimizations are deallocated during backward pass in the absence of backward pass the buffers should be reset after each @@ -825,6 +878,16 @@ def main(): evaluate_and_print_results(prefix, test_data_iterator, model, args, timers, True) +def seed_torch(seed=1029): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
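+    # The flags below trade speed for reproducibility: disable autotuning,
+    # force deterministic kernels, and switch cuDNN off entirely; presumably
+    # debugging-only settings.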
+ torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enabled = False if __name__ == "__main__": + torch.backends.cuda.matmul.allow_tf32 = False main() diff --git a/random_display.py b/random_display.py new file mode 100644 index 0000000000000000000000000000000000000000..c59c51adaf56937df7c1bcef072759890006ca21 --- /dev/null +++ b/random_display.py @@ -0,0 +1,25 @@ +from data_utils.datasets import BinaryDataset +from data_utils import get_tokenizer +import argparse +import os +import torch +import random +test_dir = 'tmp' +# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin' +bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing005/quanjing005.bin.part_0.cogdata' +bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False) +args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None) +tokenizer = get_tokenizer(args) + +bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(16)] +for x in bin_ds: + end = x.tolist().index(-1) + print(tokenizer.DecodeIds(x[:end])[0]) + +from torchvision.utils import save_image +imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+64**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0) +save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True) +imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2:64+64**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0) +save_image(imgs, os.path.join(test_dir, 'testcase256.jpg'), normalize=True) +# imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0) +# save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True) \ No newline at end of file diff --git a/scripts/cuda_2d_text2image.sh b/scripts/cuda_2d_text2image.sh new file mode 100755 index 0000000000000000000000000000000000000000..46783ccda9f37af232fd7f46de97c96463d5fccf --- /dev/null +++ b/scripts/cuda_2d_text2image.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +CHECKPOINT_PATH=data/checkpoints/cogview-fixgrad-small08-25-09-38 +# CHECKPOINT_PATH=data/checkpoints/cogview-compare +NLAYERS=16 +NHIDDEN=1024 +NATT=16 +MAXSEQLEN=5184 +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +MPSIZE=1 + +#SAMPLING ARGS +TEMP=1.05 +#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p +TOPK=100 +TOPP=0 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +MASTER_PORT=${MASTER_PORT} python generate_samples.py \ + --deepspeed \ + --model-parallel-size $MPSIZE \ + --num-layers $NLAYERS \ + --hidden-size $NHIDDEN \ + --load $CHECKPOINT_PATH \ + --num-attention-heads $NATT \ + --max-position-embeddings 5184 \ + --fp16 \ + --temperature $TEMP \ + --top_k $TOPK \ + --top_p $TOPP \ + --sandwich-ln \ + --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ + --sparse-type standard \ + --max-position-embeddings-finetune $MAXSEQLEN \ + --generation-task "cuda-2d generation" \ + --input-source ./input.txt \ + --output-path samples_text2image \ + --batch-size 2 \ + --max-inference-batch-size 4 \ + --device 0 \ + --sparse-type standard \ + $@ + + diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json index 5a1ea682f08d854ba98468980ad27e46c7879aee..1f0f35b84c9f2720dc39d1240cafd63a9e26b2f7 100755 --- 
a/scripts/ds_config_zero.json +++ b/scripts/ds_config_zero.json @@ -1,6 +1,6 @@ { - "train_micro_batch_size_per_gpu": 4, - "gradient_accumulation_steps": 1, + "train_micro_batch_size_per_gpu": 6, + "gradient_accumulation_steps": 5, "steps_per_print": 1, "gradient_clipping": 0.1, "zero_optimization": { @@ -23,7 +23,7 @@ "optimizer": { "type": "Adam", "params": { - "lr": 0.00005, + "lr": 0.0005, "betas": [ 0.9, 0.95 diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh index 4af4c0560f9f4a899b8444e55f1447462100d446..a03cf273fcb587611f27c318d469a96d62af2366 100755 --- a/scripts/pretrain_multiple_nodes.sh +++ b/scripts/pretrain_multiple_nodes.sh @@ -2,7 +2,7 @@ # Change for multinode config -NUM_WORKERS=2 +NUM_WORKERS=10 NUM_GPUS_PER_WORKER=8 MP_SIZE=1 @@ -11,35 +11,44 @@ script_dir=$(dirname $script_path) main_dir=$(dirname $script_dir) # OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" -OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=ib0 NCCL_NET_GDR_LEVEL=2" -HOST_FILE_PATH="hostfile" +OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile2" # OPTIONS_NCCL="" # HOST_FILE_PATH="hostfile_single" +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata" +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin" -config_json="$script_dir/ds_config_zero.json" +config_json="$script_dir/ds_config.json" gpt_options=" \ - --experiment-name cogview-ali_fashion_tutorial-12-1024-16 \ + --experiment-name cogview-fixgrad-small-test \ --img-tokenizer-num-tokens 8192 \ - --dataset-type TokenizedDataset \ + --dataset-type BinaryDataset \ --model-parallel-size ${MP_SIZE} \ - --num-layers 12 \ + --num-layers 16 \ --hidden-size 1024 \ --num-attention-heads 16 \ --save $main_dir/data/checkpoints \ - --train-iters 40000 \ + --train-iters 300000 \ --resume-dataloader \ - --train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \ + --train-data ${full_data} \ --split 949,50,1 \ --distributed-backend nccl \ --lr-decay-style cosine \ --warmup .1 \ --checkpoint-activations \ --deepspeed-activation-checkpointing \ - --max-position-embeddings 1089 \ + --max-position-embeddings 5184 \ --max-memory-length 0 \ + --sandwich-ln \ + --txt-loss-scale 10 \ + --sparse-type cuda_2d \ --fp16 \ + --save-interval 2000 \ + --load data/checkpoints/cogview-compare " + # + gpt_options="${gpt_options} diff --git a/scripts/pretrain_single_node.sh b/scripts/pretrain_single_node.sh index a55298542516da980a31f96ee4dcb8f0ecca3a1a..84436360d7ec7fbdaa64fe7f7b69d086a436cbe4 100755 --- a/scripts/pretrain_single_node.sh +++ b/scripts/pretrain_single_node.sh @@ -15,13 +15,13 @@ OPTIONS_NCCL="NCCL_DEBUG=info" HOST_FILE_PATH="hostfile_single" -config_json="$script_dir/ds_config.json" +config_json="$script_dir/ds_config_zero.json" gpt_options=" \ --experiment-name cogview-testlocal \ --img-tokenizer-num-tokens 8192 \ --dataset-type BinaryDataset \ --model-parallel-size ${MP_SIZE} \ - --num-layers 12 \ + --num-layers 48 \ --hidden-size 2560 \ --num-attention-heads 40 \ --save $main_dir/data/checkpoints \ diff --git a/scripts/testnan.sh b/scripts/testnan.sh new file mode 100755 index 0000000000000000000000000000000000000000..a263aa1747a80ac99eb8efa1f57584cff67826b8 --- /dev/null +++ b/scripts/testnan.sh @@ -0,0 +1,59 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=1 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0" +OPTIONS_NCCL="NCCL_DEBUG=info" +HOST_FILE_PATH="hostfile_single" + +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata" +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin" + +config_json="$script_dir/ds_config.json" +gpt_options=" \ + --experiment-name cogview-testlocal \ + --img-tokenizer-num-tokens 8192 \ + --dataset-type BinaryDataset \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 16 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --save $main_dir/data/checkpoints \ + --train-iters 100000 \ + --resume-dataloader \ + --test-data ${full_data} \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .1 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --max-position-embeddings 5184 \ + --max-memory-length 0 \ + --txt-loss-scale 2 \ + --sandwich-ln \ + --sparse-type cuda_2d \ + --save-interval 2500 \ + --load data/checkpoints/cogview-fixgrad-small08-25-09-38 +" + # --fp16 \ + +gpt_options="${gpt_options} + +" + # --deepspeed \ + # --deepspeed_config ${config_json} \ + +run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/text2image.sh b/scripts/text2image.sh index d3a6ae7ef5a7967201882e8546db7a746e2f0d9d..38509fbd8c2fd8ad07e6d59249a827bc9a0b4e8e 100755 --- a/scripts/text2image.sh +++ b/scripts/text2image.sh @@ -42,8 +42,8 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \ --generation-task text2image \ --input-source ./input.txt \ --output-path samples_text2image \ - --batch-size 8 \ - --max-inference-batch-size 8 \ + --batch-size 4 \ + --max-inference-batch-size 4 \ --device 0 \ $@ diff --git a/test_sparse_attention.py b/test_sparse_attention.py index a8c22f8c9778360632ae00fd2531655ab8e2b5a0..abac1ed99ef230c2054eb997fe815c5cd119c9e0 100644 --- a/test_sparse_attention.py +++ b/test_sparse_attention.py @@ -4,7 +4,7 @@ from tqdm import tqdm import torch import numpy as np -from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d +from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d, sparse_attention_2dfull def test_sparse_attention_1d(): s, w, times = 4096 + 128, 128, 2 @@ -78,11 +78,12 @@ def test_sparse_attention_1d(): def test_sparse_attention_2d(): dtype = torch.float device = 'cuda' - b, n_head, hn = 1, 40, 2560 + b, n_head, hn = 2, 16, 1024 h = w = 32 - layout = [10, 10, 10+h*w, 10+h*w*5] + layout = [10, 64, 64+h*w, 64+h*w*5] k1 = 9 k2 = 7 + k1h = k1*2-1 qkv = torch.rand(3, b, layout[-1], hn, dtype=dtype, device=device) qkv2 = qkv.clone() @@ -95,20 +96,22 @@ def test_sparse_attention_2d(): m[i, :i+1] = 1 m[layout[1]:, :layout[0]] = 1 for i in tqdm(range(layout[1], layout[2])): - x = (i - layout[1]) // w - y = (i - layout[1]) % w - lx = max(0, x - k1 // 2) - ly = max(0, y - k1 // 2) - rx = min(h-1, x + k1 // 2) - ry = min(w-1, y + k1 // 2) - m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1 - m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 + m[i, 
layout[1]:i+1] = 1 + # for i in tqdm(range(layout[1], layout[2])): + # x = (i - layout[1]) // w + # y = (i - layout[1]) % w + # lx = max(0, x - k1 // 2) + # ly = max(0, y - k1 // 2) + # rx = min(h-1, x + k1 // 2) + # ry = min(w-1, y + k1 // 2) + # m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1 + # m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 for i in tqdm(range(layout[2], layout[3])): x = (i - layout[2]) // (2*w) y = (i - layout[2]) % (2*w) - lx = max(0, x - k1 // 2) + lx = max(0, x - k1h // 2) ly = max(0, y - k1 // 2) - rx = min(2*h-1, x + k1 // 2) + rx = min(2*h-1, x + k1h // 2) ry = min(2*w-1, y + k1 // 2) m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1 m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1 @@ -121,7 +124,7 @@ def test_sparse_attention_2d(): ry = min(w-1, y + k2 // 2) m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1 - # mask[1:] = mask[0] + mask[1:] = mask[0] # mask[1][layout[1]:, layout[0]-1] = 0 print('finish making mask...') @@ -134,22 +137,25 @@ def test_sparse_attention_2d(): torch.cuda.synchronize() t1 = time.time() - r2 = sparse_attention_2d(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2) + r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2) torch.cuda.synchronize() t2 = time.time() print('times: standard ', t1-t0, ' sparse ', t2-t1) print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max()) print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max()) qkv.retain_grad() - l2 = r2[:,layout[1]:].mean() - l1 = r1[:,layout[1]:].mean() + l2 = r2[:,layout[1]:].sum() + l1 = r1[:,layout[1]:].sum() + l2.backward() l1.backward() g1 = qkv.grad g2 = qkv2.grad print( (g1-g2).abs().max()) - # import pdb;pdb.set_trace() + print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-5)).max()) + + import pdb;pdb.set_trace() def seed_torch(seed=1029): diff --git a/utils.py b/utils.py index 257568e740f908e549e3b908765e582500ac16e0..c5079ce515c1559616b3e6d24dbf365791b82cb9 100755 --- a/utils.py +++ b/utils.py @@ -373,7 +373,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states= if mpu.get_data_parallel_rank() == 0: print(' successfully loaded {}'.format(checkpoint_name)) - + del sd return iteration
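For reference, the token layout that the cuda_2d changes above assume, spelled out as a minimal Python sketch (the numbers come from the --layout default in arguments.py, the slicing in random_display.py, and the assert in get_batch in pretrain_gpt2.py):

layout = [0, 64, 1088, 5184]  # [endoftext/startofpad, startof0, startof1, endofall]
n_text = layout[1] - layout[0]      # 64 text positions
n_level0 = layout[2] - layout[1]    # 1024 = 32*32 coarse image tokens
n_level1 = layout[3] - layout[2]    # 4096 = 64*64 fine image tokens
assert n_level0 == 32 ** 2 and n_level1 == 64 ** 2
assert layout[-1] == 64 + 32 ** 2 + 64 ** 2  # same check as get_batch in pretrain_gpt2.py
# BinaryDataset samples are stored as [text, 64x64 tokens, 32x32 tokens] (see
# random_display.py); get_batch swaps the two image levels so the coarse
# 32x32 block comes first, which is the order this layout describes.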