diff --git a/arguments.py b/arguments.py index a9e722aa247c317239666c5f82906be55fbfb433..7ea04322c89471000927a29dc854f46b19ead1fd 100755 --- a/arguments.py +++ b/arguments.py @@ -311,12 +311,12 @@ def add_sparse_args(parser): def make_sparse_config(args): args.layout = [int(x) for x in args.layout.split(',')] sparse_config = argparse.Namespace(sparse_type=args.sparse_type) + sparse_config.layout = args.layout if args.sparse_type == 'standard': pass if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation': sparse_config.kernel_size = args.kernel_size sparse_config.kernel_size2 = args.kernel_size2 - sparse_config.layout = args.layout elif args.sparse_type == 'torch_1d': raise NotImplementedError args.sparse_config = sparse_config diff --git a/create_gt.py b/create_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..45cd3efdb6677a4ce919100d253e73edbac9118a --- /dev/null +++ b/create_gt.py @@ -0,0 +1,16 @@ +# %% +p = 'people.jpeg' +from data_utils.vqvae_tokenizer import VQVAETokenizer +model = VQVAETokenizer( + 'pretrained/vqvae/vqvae_hard_biggerset_011.pt' +) +img = model.read_img(p, img_size=512) +# %% +test_dir = 'tmp' +import os +import torch +from torchvision.utils import save_image +img = model.EncodeAsIds(img) +imgs = model.DecodeIds(torch.tensor(img)) +save_image(imgs, os.path.join(test_dir, 'show512_people.jpg'), normalize=True) +# %% diff --git a/data_utils/datasets.py b/data_utils/datasets.py index 7b1203839f85f14874d96647c4f7c0fd7be7d042..d12d8f7ab827230b66e2cf4be06fe46911982050 100755 --- a/data_utils/datasets.py +++ b/data_utils/datasets.py @@ -117,23 +117,21 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): } elif dataset_type == 'CompactBinaryDataset': - layout = args.layout + layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME DS_CLASS = BinaryDataset kwargs_to_dataset['length_per_sample'] = layout[-1] def process_fn(row): row = row.astype(np.int64) - # THIS IS Reverse order, TODO - lens = list(reversed([layout[i] - layout[i-1] for i in range(1, len(layout))])) - codes = [row[layout[0]: layout[0]+lens[0]]] - if len(lens) > 1: - codes.append(row[layout[0]+lens[0]: layout[0]+lens[0]+lens[1]]) + + codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))] + text = row[:layout[0]] text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1] n_pad = layout[0]-3-len(text) parts = [ np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64), - TextCodeTemplate(text, codes[-1]), - *reversed(codes[:-1]) + TextCodeTemplate(text, codes[1]), # FIXME + *codes[2:] # FIXME ] ret = np.concatenate(parts, axis=0) return {'text': ret, diff --git a/draw_diff.py b/draw_diff.py index b4baab18c797077231138786e377f66adee62f62..6c2d74b5cf57c469752a5fb3c5170574115e6c94 100644 --- a/draw_diff.py +++ b/draw_diff.py @@ -7,6 +7,15 @@ def loadbao(name): a, b = line.split() ret.append(abs(float(b))) return ret + +def loadlion(name): + ret1, ret2 = [], [] + with open(name, 'r') as fin: + for line in fin: + a, b = line.split() + ret1.append(abs(float(a))) + ret2.append(abs(float(b))) + return ret1, ret2 import torchvision import torchvision.transforms as transforms @@ -20,13 +29,13 @@ def sq(img, x, y, lx, ly): transform = transforms.Compose([ transforms.Resize(512), transforms.CenterCrop(512), - ]) -img = torchvision.io.read_image('bao.jpeg') + ]) +img = torchvision.io.read_image('cat2.jpeg') img = transform(img) / 255. 
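A minimal sketch of how the hard-coded 4-level layout in the CompactBinaryDataset branch above carves one flat sample row into its text and per-resolution code segments. The layout values mirror the ones in the diff (64 text slots, then 16x16, 32x32 and 64x64 image codes); the row itself is synthetic toy data, not a real binary sample.

import numpy as np

# Cumulative offsets copied from the CompactBinaryDataset branch above;
# the sample row here is synthetic (a real one comes from BinaryDataset).
layout = [64, 64 + 16**2, 64 + 16**2 + 32**2, 64 + 16**2 + 32**2 + 64**2]
row = np.arange(layout[-1], dtype=np.int64)

text = row[:layout[0]]                                           # 64 text token slots
codes = [row[layout[i - 1]:layout[i]] for i in range(1, len(layout))]

assert [len(c) for c in codes] == [16**2, 32**2, 64**2]          # 256, 1024, 4096 codes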
-a = np.array(loadbao('bao2.txt')) -b = np.array(loadbao('bao3.txt')) -for t in (b-a>1).nonzero()[0]: - x,y = t // 32, t % 32 - sq(img, x*16, y*16, 15, 15) -print(a.mean(), b.mean()) -torchvision.utils.save_image(img, 'example_bao.jpg') +# a,b = np.array(loadlion('bed6.txt')) +b = np.array(loadbao('bed6.txt')) +for t in (b<0.999).nonzero()[0]: + x,y = t // 64, t % 64 + sq(img, x*8, y*8, 7, 7) +print(b.sum()) +torchvision.utils.save_image(img, 'example_cat6.jpg') diff --git a/generate_samples.py b/generate_samples.py index 4ae8fa02da9c751920e649b057705aab2a5587c4..f0c5ba54e71a87d8443e08c342714345f1bb000d 100755 --- a/generate_samples.py +++ b/generate_samples.py @@ -325,13 +325,13 @@ def main(): torch.cuda.set_device(device) # Random seeds for reproducability. - set_random_seed(args.seed) # get the tokenizer tokenizer = prepare_tokenizer(args) # Model, optimizer, and learning rate. model = setup_model(args) + set_random_seed(args.seed) generate_images_continually(model, args) diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py index d5fe7f83935a99dcdbbbe5da57fe506c5ee700ab..69d3d7cb10d6a902ad2c2711d2e20abd0a61e5cc 100644 --- a/generation/cuda_2d_sampling.py +++ b/generation/cuda_2d_sampling.py @@ -1,3 +1,4 @@ +from vqvae.vqvae_zc import Encoder from .sampling import * import math import sys @@ -37,7 +38,7 @@ def filling_sequence_cuda_2d( from torchvision import transforms tr = transforms.Compose([ - transforms.Resize(512), + transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), ]) imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blured image as init value @@ -73,29 +74,73 @@ def filling_sequence_cuda_2d( print(unfixed.sum()) logits, *_dump = model(tokens, position_ids, attention_mask) step_cnt += 1 - last_logits = logits # warmup - real_topk = 5 - real_temp = 2 - min(1,((step_cnt) / iterative_step)) * 1.9 + real_topk = 200 + # real_temp = 0.7 #- min(1,((step_cnt) / iterative_step)) * .3 + # real_temp = args.temperature + if step_cnt <= 5: + real_temp = 0.1 + elif step_cnt == 6: + real_temp = 0.55 + elif step_cnt > 6: + real_temp = 0.45 + if 5 < step_cnt: + real_topk = 200 # sampling for invalid_slice in invalid_slices: # forbide to generate other tokens logits[..., invalid_slice] = -float('Inf') assert args.top_k > 0 - tk_value, tk_idx = torch.topk(logits, real_topk, dim=-1) - tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1]) - prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1) - prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1) + + probs0 = F.softmax(logits/real_temp, dim=-1) + topsum = torch.topk(probs0, 20, dim=-1)[0].sum(dim=-1) + if step_cnt >= 6: + real_temp2 = torch.tensor([[[real_temp]]], device=probs0.device).expand(*probs0.shape[:2], 1) * (topsum < 0.95).unsqueeze(-1) + 0.6 + # import pdb;pdb.set_trace() + else: + real_temp2 = real_temp + # import pdb;pdb.set_trace() + probs = F.softmax(logits/real_temp2, dim=-1) + tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1) + prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1) + edge_idx = tk_idx[:, :, -1:] + edge_value = tk_value[:, :, -1:] + edge_mask = probs.gather(dim=-1, index=prev) < edge_value + prev[edge_mask] = edge_idx[edge_mask] + prev.squeeze_(-1) + # tk_probs = (tk_value / 
real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1]) + # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1) + # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1) # update unfixed choice = 1 - if choice == 0 and step_cnt > 5: - mprob = tk_probs.max(dim=-1)[0].view(*(tk_value.shape[:2])) - dprob = (mprob[:, 1:] < 0.5) & ((mprob[:, :-1] > 0.8)| (unfixed[:, 1:-1].logical_not())) + if choice == 0 and 5 < step_cnt: + mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2])) + # import pdb;pdb.set_trace() + dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:]) + new_fixed = unfixed.clone() - new_fixed[:, 2:] &= dprob + moved_new_fixed = new_fixed[:, 2:] + moved_new_fixed &= dprob + moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not() + moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not() + # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not() + moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not() + moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not() + # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not() + elif choice == 1 and 5 < step_cnt: + new_fixed = unfixed & False + x = (step_cnt-5) // 4 + y = (step_cnt-5) % 4 + new_fixed[..., -4096:].view(batch_size, 16, 4, 16, 4)[:, :, x, :, y] = True + new_fixed &= unfixed else: new_fixed = unfixed & False # TODO new_fixed[:, -1] = True + + with open(f'bed{step_cnt}.txt', 'w') as fout: + for i, prob in enumerate(topsum[0, -4096:]): + fout.write(f'{i} {prob}\n') + unfixed &= new_fixed.logical_not() # update seq and tokens seq[new_fixed] = prev[new_fixed[:, 1:]] @@ -115,7 +160,6 @@ def filling_sequence_cuda_2d( if args.debug: imgs = torch.cat(imgs, dim=0) save_image(imgs, f'steps{device}.jpg', normalize=True) - model.module.transformer.max_memory_length = args.max_memory_length return seq \ No newline at end of file diff --git a/generation/sampling.py b/generation/sampling.py index 8d614724a10a3f1ea1733d4c3fe23a602b70ed11..b049ba8a7f3b6e5d33813a90e8db1d81e14ffd6b 100755 --- a/generation/sampling.py +++ b/generation/sampling.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from pretrain_gpt2 import get_masks_and_position_ids from data_utils import get_tokenizer +from copy import deepcopy def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): # This function has been mostly taken from huggingface conversational ai code at @@ -26,8 +27,12 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): if top_k > 0: # Remove all tokens with a probability less than the last token of the top-k + # s1 = (logits-logits.max()).exp().sum() indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value + logits[indices_to_remove] = filter_value + # s2 = (logits-logits.max()).exp().sum() + # with open('lion.txt', 'a') as fout: + # fout.write(f'{s1} {s2}\n') if top_p > 0.0: # convert to 1D @@ -107,6 +112,12 @@ def filling_sequence( offset = context_length context_length += 1 tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args) + txt_len = seq.tolist().index(tokenizer['[BASE]']) + print('txt_len:', txt_len) + config = deepcopy(model.module.transformer.sparse_config) + ori_config = model.module.transformer.sparse_config + 
config.layout[0] = txt_len + model.module.transformer.reset_sparse_config(config) counter = context_length - 1 # == len(tokens) - 1 index = 0 # len(mems) @@ -130,12 +141,12 @@ def filling_sequence( logits, *qkv = model(tokens, position_ids, attention_mask, *mems) mems = update_mems(qkv, mems) - tmp = -F.log_softmax(logits, dim=-1) - tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0] + # tmp = -F.log_softmax(logits, dim=-1) + # tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0] # for i in range(1,len(tmp)): # print(i, tmp[i].item()) index = counter - print(tmp[1:].mean(), file=sys.stderr) + # print(tmp[1:].mean(), file=sys.stderr) elif seq[counter + 1] >= 0: # provided if seq[counter + 1] == tokenizer['[ROI2]']: offset = counter + 1 @@ -165,29 +176,44 @@ def filling_sequence( logits = logits[:, -1] # [batch size, vocab size] temp = args.temperature + real_topk = args.top_k + if counter <= context_length + 32: + real_topk = 80 + # else: + # real_topk = args.top_k + # if counter == context_length + 32 + 12: + # import pdb;pdb.set_trace() # TODO since the temperature is crucial, how can we find a good setting? logits /= temp - for invalid_slice in invalid_slices: # forbide to generate other tokens + for invalid_slice in invalid_slices: # to generate other tokens logits[..., invalid_slice] = -float('Inf') - logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) - log_probs = F.softmax(logits, dim=-1) + # logits = top_k_logits(logits, top_k=real_topk, top_p=args.top_p) + probs = F.softmax(logits, dim=-1) + + tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1) # expand beams if nb > 1 and tokens.shape[0] == 1: # 1->nb tokens = tokens.expand(nb, -1).contiguous() mems = [mem.expand(nb, -1, -1) for mem in mems] - prev = torch.multinomial(log_probs, num_samples=nb, replacement=True) - score = torch.log(torch.gather(log_probs, dim=1, index=prev)[0]).tolist() + prev = torch.multinomial(probs, num_samples=nb, replacement=True) + score = torch.log(torch.gather(probs, dim=1, index=prev)[0]).tolist() else: # nb -> nb assert tokens.shape[0] == nb - prev = torch.multinomial(log_probs, num_samples=1) - score_plus = torch.log(torch.gather(log_probs, dim=1, index=prev)[:, 0]) + prev = torch.multinomial(probs, num_samples=1) + for j in range(0, prev.shape[0]): + if probs[j, prev[j,-1]] < tk_value[j, -1]: + prev[j, -1] = tk_idx[j,torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))] + # prev[j, -1] = tk_idx[j,torch.randint(0, tk_idx.shape[-1], (1,))] + + score_plus = torch.log(torch.gather(probs, dim=1, index=prev)[:, 0]) for idx in range(nb): score[idx] += score_plus[idx] tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1) output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous() + model.module.transformer.reset_sparse_config(ori_config) return output_tokens_list def shrink_beams(tokens, mems, nb, score): diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py index 53a92de78d6b0409e520dbb139774d18d37af6f7..37f3f9dbaeacff4ddcf1c500c75613bee878add6 100755 --- a/mpu/sparse_transformer.py +++ b/mpu/sparse_transformer.py @@ -75,12 +75,13 @@ class GPT2ParallelSelfAttention(torch.nn.Module): """ def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, - init_method, output_layer_init_method=None,sparse_config=None, + init_method, layer_id, output_layer_init_method=None,sparse_config=None, finetune=False): super(GPT2ParallelSelfAttention, self).__init__() # Set output layer 
initialization if not provided. if output_layer_init_method is None: output_layer_init_method = init_method + self.layer_id = layer_id # Per attention head and per partition values. world_size = get_model_parallel_world_size() self.hidden_size_per_partition = divide(hidden_size, world_size) @@ -177,7 +178,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module): key_layer = self._transpose_for_scores(mixed_key_layer) value_layer = self._transpose_for_scores(mixed_value_layer) - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) + context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, txt_len=layout[0] if not self.training else -1) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + \ @@ -193,7 +194,9 @@ class GPT2ParallelSelfAttention(torch.nn.Module): text_len=sparse_config.layout[0], kernel_size=sparse_config.kernel_size, kernel_size2=sparse_config.kernel_size2, - attention_dropout=dropout_fn + attention_dropout=dropout_fn, + text_start=(1-mask[...,-1,:]).sum().long().item()+1 if not self.training else -1, + layer_id=self.layer_id ) if sparse_config.sparse_type == 'cuda_2d': @@ -308,6 +311,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): output_dropout_prob, layernorm_epsilon, init_method, + layer_id, output_layer_init_method=None, sandwich_ln=True, sparse_config=argparse.Namespace(sparse_type='standard'), @@ -317,6 +321,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): # Set output layer initialization if not provided. if output_layer_init_method is None: output_layer_init_method = init_method + self.layer_id = layer_id # Layernorm on the input data. self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) @@ -328,6 +333,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): attention_dropout_prob, output_dropout_prob, init_method, + layer_id, output_layer_init_method=output_layer_init_method, sparse_config=sparse_config, finetune=finetune @@ -488,6 +494,7 @@ class GPT2ParallelTransformer(torch.nn.Module): output_dropout_prob, layernorm_epsilon, unscaled_init_method(init_method_std), + layer_id, output_layer_init_method=output_layer_init_method, sandwich_ln=sandwich_ln, sparse_config=sparse_config, @@ -624,7 +631,7 @@ def _chunk(x, w, times): return x.as_strided(size=chunk_size, stride=chunk_stride) -def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None): +def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None, layer_id = -1, txt_len=-1): # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. @@ -638,9 +645,17 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte attention_scores = torch.mul(attention_scores, attention_mask) - \ 10000.0 * (1.0 - attention_mask) - # Attention probabilities. 
[b, np, s, s] + # Attention probabilities [b, np, s, s] attention_probs = F.softmax(attention_scores, dim=-1) - + + if txt_len > 0: + t = key_layer.shape[-2] - txt_len - 1 + if t // 32 <= 32: + # line = attention_probs[..., :, 1:txt_len].max(dim=-1, keepdim=True)[0] + # tmask = attention_probs[..., :, 1:txt_len] >= line + attention_probs[..., :, 1:txt_len] *= 6 if txt_len <= 10 else 4 + attention_probs /= attention_probs.sum(dim=-1, keepdim=True)[0] + if attention_dropout is not None: with get_cuda_rng_tracker().fork(): attention_probs = attention_dropout(attention_probs) @@ -651,7 +666,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte return context_layer -def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): +def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, text_start = -1, layer_id=-1, **kwargs): ''' q0, k0, v0: [batch_size, 1088, hidden_size] q1, k1, v1: [batch_size, 4096, h2] @@ -673,6 +688,9 @@ def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, te attention_scores = torch.mul(attention_scores, attention_mask) - \ 10000.0 * (1.0 - attention_mask) attention_probs0 = F.softmax(attention_scores, dim=-1) + if text_start > 0: + attention_probs0[..., :, text_start:text_len-2] *= 1 + attention_probs0 /= attention_probs0.sum(dim=-1, keepdim=True)[0] # local attention for level 1 q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) @@ -701,10 +719,11 @@ def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, te # weighting for level 1 probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True) - context1_to_1 = context1_to_1.view(b, n_head * h, l1**2) + context1 = context1_to_1.view(b, n_head * h, l1**2) # weighting for cross attention probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0) v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0) context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False) context1_to_0 = context1_to_0.view(b, n_head * h, l1**2) - return context0.transpose(1, 2).reshape(b, s0, h0), (context1_to_0 + context1_to_1).transpose(-1, -2) \ No newline at end of file + context1 = context1 + context1_to_0 + return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2) \ No newline at end of file diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py index 87844669294b4455e8a1687d405ac83e977f9ca2..8fb6d3b355c095d1616765b6a60f6969f9620b76 100755 --- a/pretrain_gpt2.py +++ b/pretrain_gpt2.py @@ -179,7 +179,6 @@ def get_learning_rate_scheduler(optimizer, args): return lr_scheduler - def setup_model_and_optimizer(args): """Setup model and optimizer.""" diff --git a/random_display.py b/random_display.py index 32096ddbf43b589665943d94bab91b701bec4501..7c9d56f0f0a30c323c1527d8ed1591e5723459a4 100644 --- a/random_display.py +++ b/random_display.py @@ -5,21 +5,24 @@ import os import torch import random test_dir = 'tmp' -# bin_dir = 
'/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin' -bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing003/quanjing003.bin.part_0.cogdata' -bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False) +# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin' +bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata' +bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=16**2+64*64+32*32+64, dtype='int32', preload=False) args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None) tokenizer = get_tokenizer(args) bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(32)] for x in bin_ds: - end = x.tolist().index(-1) + if x[63] != -1: + end = 64 + else: + end = x.tolist().index(-1) print(tokenizer.DecodeIds(x[:end])[0]) from torchvision.utils import save_image -imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+64**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0) -save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True) -imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2:64+64**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0) +imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+16**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0) +save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True) +imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+16**2:64+16**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0) save_image(imgs, os.path.join(test_dir, 'testcase256.jpg'), normalize=True) -# imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0) -# save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True) \ No newline at end of file +imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+16**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0) +save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True) \ No newline at end of file diff --git a/scripts/cuda_2d_text2image.sh b/scripts/cuda_2d_text2image.sh index 3f03974a9f479821f415107e4cefb84efae97ed2..6a32e31627fcdc70ab9b62699312800ea6acd0e2 100755 --- a/scripts/cuda_2d_text2image.sh +++ b/scripts/cuda_2d_text2image.sh @@ -1,6 +1,6 @@ #!/bin/bash -CHECKPOINT_PATH=data/checkpoints/cogview-long +CHECKPOINT_PATH=data/checkpoints/cogview-base # CHECKPOINT_PATH=data/checkpoints/cogview-compare NLAYERS=48 NHIDDEN=2560 @@ -10,9 +10,9 @@ MASTER_PORT=$(shuf -n 1 -i 10000-65535) MPSIZE=1 #SAMPLING ARGS -TEMP=1.03 +TEMP=1. 
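A small sketch of the end-of-text lookup added to random_display.py above: the text field is a fixed 64-slot prefix padded with -1, so the decodable portion is either all 64 tokens or everything before the first -1. The sample array here is synthetic; a real row would come from BinaryDataset.

import numpy as np

# Synthetic row: 20 real text ids, 44 pad slots (-1), then the image codes.
sample = np.concatenate([
    np.arange(1, 21),
    np.full(44, -1),
    np.zeros(16**2 + 32**2 + 64**2),
]).astype(np.int64)

if sample[63] != -1:            # text fills all 64 slots, no pad marker present
    end = 64
else:
    end = sample.tolist().index(-1)

print(end)                      # 20 -> only sample[:end] goes to tokenizer.DecodeIds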
#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p -TOPK=100 +TOPK=200 TOPP=0 script_path=$(realpath $0) @@ -36,13 +36,14 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \ --max-position-embeddings-finetune $MAXSEQLEN \ --generation-task "cuda-2d generation" \ --input-source ./input.txt \ - --output-path samples_cuda_2d2 \ - --batch-size 3 \ + --output-path samples_cuda_2d3 \ + --batch-size 4 \ --max-inference-batch-size 4 \ --device 0 \ --finetune \ --no-load-optim \ --sparse-type cuda_2d \ + --debug \ $@ diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json index 1f0f35b84c9f2720dc39d1240cafd63a9e26b2f7..b9855836dd584b47b0e170e94e6c4d07cf9ba6f2 100755 --- a/scripts/ds_config_zero.json +++ b/scripts/ds_config_zero.json @@ -1,6 +1,6 @@ { - "train_micro_batch_size_per_gpu": 6, - "gradient_accumulation_steps": 5, + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 1, "steps_per_print": 1, "gradient_clipping": 0.1, "zero_optimization": { @@ -10,7 +10,8 @@ "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 100000000, - "allgather_bucket_size": 1000000000 + "allgather_bucket_size": 1000000000, + "load_from_fp32_weights": false }, "zero_allow_untested_optimizer": true, "fp16": { @@ -23,13 +24,13 @@ "optimizer": { "type": "Adam", "params": { - "lr": 0.0005, + "lr": 0.0002, "betas": [ 0.9, 0.95 ], "eps": 1e-8, - "weight_decay": 4e-2 + "weight_decay": 1e-4 } }, "activation_checkpointing": { diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh index a1245d3f2bd70fac8f5f1809ac6c7cdf79f0f68d..f099f8011dd4e174d0e4cd1d1dcc68afc641b392 100755 --- a/scripts/pretrain_multiple_nodes.sh +++ b/scripts/pretrain_multiple_nodes.sh @@ -16,12 +16,12 @@ HOST_FILE_PATH="hostfile" # OPTIONS_NCCL="" # HOST_FILE_PATH="hostfile_single" -small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata" -full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin" +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin" config_json="$script_dir/ds_config_zero.json" gpt_options=" \ - --experiment-name cogview-base-continue-long \ + --experiment-name cogview-base-long \ --img-tokenizer-num-tokens 8192 \ --dataset-type CompactBinaryDataset \ --model-parallel-size ${MP_SIZE} \ @@ -47,9 +47,9 @@ gpt_options=" \ --no-load-optim \ --no-save-optim \ --eval-interval 1000 \ - --save /root/checkpoints \ + --save $main_dir/data/checkpoints \ --fast-load \ - --load data/checkpoints/cogview-continue \ + --load data/checkpoints/cogview-base \ --finetune " diff --git a/scripts/text2image.sh b/scripts/text2image.sh index fdb3bf3a17ddac30b3954a50d9f219691ed13f12..4f2f5d55940a5f95a4b5e343c6b01da07d65139a 100755 --- a/scripts/text2image.sh +++ b/scripts/text2image.sh @@ -6,8 +6,8 @@ # NHIDDEN=1024 # NATT=16 -CHECKPOINT_PATH=data/checkpoints/cogview-continue -# CHECKPOINT_PATH=pretrained/cogview/cogview-base +# CHECKPOINT_PATH=data/checkpoints/cogview-base +CHECKPOINT_PATH=pretrained/cogview/cogview-base NLAYERS=48 NHIDDEN=2560 NATT=40 @@ -16,7 +16,7 @@ MASTER_PORT=$(shuf -n 1 -i 10000-65535) MPSIZE=1 #SAMPLING ARGS -TEMP=1. 
+TEMP=1 #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p TOPK=200 TOPP=0 @@ -46,6 +46,7 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \ --batch-size 8 \ --max-inference-batch-size 8 \ --device 0 \ + --debug \ $@ diff --git a/utils.py b/utils.py index 0d4c51d64c77299f56b9e6861b376a4a37808518..3256dce3e6fc1daf6df72aaae7f15d1ca4916de0 100755 --- a/utils.py +++ b/utils.py @@ -323,8 +323,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states= if args.deepspeed: checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune) - if args.finetune: - model.module.module.init_plus_from_old() + # if args.finetune: + # model.module.module.init_plus_from_old() if (args.finetune or args.no_load_optim) and model.zero_optimization(): model.optimizer.refresh_fp32_params() if "client_lr_scheduler" in sd and not args.finetune: @@ -461,4 +461,4 @@ def move_weights(our, oai, dst2src=False): load_weights(transformer_model.wpe, our.position_embeddings, dst2src) for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): - load_transformer_layer(our_layer, oai_layer, dst2src) + load_transformer_layer(our_layer, oai_layer, dst2src) \ No newline at end of file
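A standalone sketch of the clamp-to-top-k sampling that both generation/cuda_2d_sampling.py and generation/sampling.py switch to in this diff: draw from the full softmax, then, if the drawn token is less likely than the k-th ranked token, fall back onto the top-k edge instead of keeping the unlikely draw. Shapes, k and temperature below are illustrative only.

import torch
import torch.nn.functional as F

def sample_clamped_topk(logits, k=200, temperature=1.0):
    # logits: [batch, seq, vocab]; k and temperature are placeholder values.
    probs = F.softmax(logits / temperature, dim=-1)
    tk_value, tk_idx = torch.topk(probs, k, dim=-1)
    prev = torch.multinomial(probs.view(-1, probs.shape[-1]), num_samples=1)
    prev = prev.view(*probs.shape[:2], 1)
    edge_value = tk_value[..., -1:]        # probability of the k-th ranked token
    edge_idx = tk_idx[..., -1:]
    edge_mask = probs.gather(dim=-1, index=prev) < edge_value
    prev[edge_mask] = edge_idx[edge_mask]  # clamp out-of-top-k draws back to the edge
    return prev.squeeze(-1)

tokens = sample_clamped_topk(torch.randn(2, 8, 8192))
print(tokens.shape)                        # torch.Size([2, 8])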