Commit 3cd2ae83 authored by Ming Ding

cuda_2d_sr

parent 6db42241
Showing changes with 716 additions and 451 deletions
......@@ -147,6 +147,8 @@ def add_training_args(parser):
group.add_argument('--warmup', type=float, default=0.01,
help='percentage of data to warmup on (.01 = 1% of all '
'training iters). Default 0.01')
group.add_argument('--restart-iter', type=int, default=0,
help='restart with warmup from this iteration.')
# model checkpointing
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
......@@ -301,19 +303,20 @@ def add_sparse_args(parser):
group.add_argument("--key-window-times", type=int, default=6)
group.add_argument("--num-pivot", type=int, default=768)
# for cuda_2d
group.add_argument("--kernel-size", type=int, default=11)
group.add_argument("--kernel-size", type=int, default=9)
group.add_argument("--kernel-size2", type=int, default=7)
group.add_argument("--layout", type=str, default='0,64,1088,5184')
group.add_argument("--layout", type=str, default='64,1088,5184')
return parser
def make_sparse_config(args):
args.layout = [int(x) for x in args.layout.split(',')]
sparse_config = argparse.Namespace(sparse_type=args.sparse_type)
if args.sparse_type == 'standard':
pass
if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation':
sparse_config.kernel_size = args.kernel_size
sparse_config.kernel_size2 = args.kernel_size2
sparse_config.layout = [int(x) for x in args.layout.split(',')]
sparse_config.layout = args.layout
elif args.sparse_type == 'torch_1d':
raise NotImplementedError
args.sparse_config = sparse_config
......
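For orientation, a small sketch (hypothetical, not part of the commit) of what the new three-part --layout default encodes; the values are cumulative token offsets for the text, 32x32 and 64x64 regions, parsed once at the top of make_sparse_config.
layout = [int(x) for x in '64,1088,5184'.split(',')]    # [64, 1088, 5184]
text_len = layout[0]                                     # 64 text slots (incl. special tokens)
level0_len = layout[1] - layout[0]                       # 1024 = 32 * 32 image tokens
level1_len = layout[2] - layout[1]                       # 4096 = 64 * 64 image tokens
assert level0_len == 32 * 32 and level1_len == 64 * 64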
......@@ -80,16 +80,16 @@ class BinaryDataset(Dataset):
def __getitem__(self, index):
return self.process_fn(self.bin[index])
def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
kwargs_to_dataset = {}
tokenizer = get_tokenizer()
if args.finetune and args.max_position_embeddings_finetune > args.max_position_embeddings:
ml = args.max_position_embeddings_finetune
if args.layout[-1] > args.max_position_embeddings:
ml = args.layout[-1]
else:
ml = args.max_position_embeddings
def pad_to_len(ret):
if len(ret) < ml: # pad
return np.concatenate((ret,
np.array([tokenizer['[PAD]']] * (ml - len(ret)))),
......@@ -117,14 +117,27 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
}
elif dataset_type == 'CompactBinaryDataset':
layout = args.layout
DS_CLASS = BinaryDataset
kwargs_to_dataset['length_per_sample'] = layout[-1]
def process_fn(row):
text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024
text = text[text>-1]
ret = TextCodeTemplate(text, code)
ret, attention_mask_sep = pad_to_len(ret)
row = row.astype(np.int64)
# NOTE: the stored row keeps the code levels in reverse order (64x64 code before 32x32 code), hence the reversal below. TODO
lens = list(reversed([layout[i] - layout[i-1] for i in range(1, len(layout))]))
codes = [row[layout[0]: layout[0]+lens[0]]]
if len(lens) > 1:
codes.append(row[layout[0]+lens[0]: layout[0]+lens[0]+lens[1]])
text = row[:layout[0]]
text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
n_pad = layout[0]-3-len(text)
parts = [
np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
TextCodeTemplate(text, codes[-1]),
*reversed(codes[:-1])
]
ret = np.concatenate(parts, axis=0)
return {'text': ret,
'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
'loss_mask': np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS]
}
elif dataset_type == 'BinaryDataset':
DS_CLASS = BinaryDataset
......@@ -134,5 +147,5 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
'loss_mask': loss_mask
}
return DS_CLASS(path, process_fn)
return DS_CLASS(path, process_fn, **kwargs_to_dataset)
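A hypothetical sanity check (not in the commit) of how the CompactBinaryDataset process_fn above re-orders a stored row, assuming the default layout [64, 1088, 5184]:
layout = [64, 1088, 5184]
lens = list(reversed([layout[i] - layout[i - 1] for i in range(1, len(layout))]))  # [4096, 1024]
# the binary row stores the 64x64 code before the 32x32 code:
#     row = [ 64 text ids | 4096-token 64x64 code | 1024-token 32x32 code ]
# process_fn re-assembles this as
#     ret = [ [PAD]*n_pad | specials + text + 32x32 code | 64x64 code ]
# so that len(ret) == layout[-1] and only the first n_pad+1 positions ([PAD]s plus [CLS]) are masked out of the loss.
assert sum(lens) + layout[0] == layout[-1] == 5184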
......@@ -50,12 +50,15 @@ class VQVAETokenizer(object):
self.device = device
self.image_tokens = model.quantize_t.n_embed
self.num_tokens = model.quantize_t.n_embed
self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
def __len__(self):
return self.num_tokens
def EncodeAsIds(self, img):
def EncodeAsIds(self, img, add_normalization=False):
assert len(img.shape) == 4 # [b, c, h, w]
if add_normalization:
img = self.tr_normalize(img)
return img2code(self.model, img)
def DecodeIds(self, code, shape=None):
......@@ -78,7 +81,6 @@ class VQVAETokenizer(object):
img = tr(Image.open(path))
if img.shape[0] == 4:
img = img[:-1]
tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
img = tr_normalize(img)
img = self.tr_normalize(img)
img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w]
return img
\ No newline at end of file
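A minimal usage sketch (hypothetical, not part of the commit), assuming `tokenizer` is the object returned by get_tokenizer(): now that the normalization transform lives on the tokenizer, callers holding raw [0, 1] image batches can request it explicitly instead of normalizing by hand.
import torch
imgs = torch.rand(4, 3, 512, 512).cuda()                                   # un-normalized [0, 1] batch
codes = tokenizer.img_tokenizer.EncodeAsIds(imgs, add_normalization=True)  # normalize, then quantize
# images loaded through the tokenizer's own reader (shown above) are already normalized,
# so the default add_normalization=False keeps their behaviour unchanged.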
......@@ -23,9 +23,10 @@ transform = transforms.Compose([
])
img = torchvision.io.read_image('bao.jpeg')
img = transform(img) / 255.
a = np.array(loadbao('bao.txt'))
b = np.array(loadbao('bao2.txt'))
for t in (a-b>2).nonzero()[0]:
a = np.array(loadbao('bao2.txt'))
b = np.array(loadbao('bao3.txt'))
for t in (b-a>1).nonzero()[0]:
x,y = t // 32, t % 32
sq(img, x*16, y*16, 15, 15)
print(a.mean(), b.mean())
torchvision.utils.save_image(img, 'example_bao.jpg')
......@@ -41,7 +41,7 @@ from pretrain_gpt2 import get_model
import math
from copy import deepcopy
from tqdm import tqdm
from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local
from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local, filling_sequence_cuda_2d
from torchvision.utils import save_image
import torch.distributed as dist
......@@ -167,7 +167,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
# profile = line_profiler.LineProfiler(model.module.forward)
# profile = line_profiler.LineProfiler(standard_attention)
# profile.enable()
fill_fn = filling_sequence_local if args.generation_task == 'cuda-2d generation' else filling_sequence
fill_fn = filling_sequence_cuda_2d if args.generation_task == 'cuda-2d generation' else filling_sequence
output_tokens_list.append(fill_fn(model, seq.clone(), args))
# torch.cuda.empty_cache()
# profile.disable() # stop profiling
......@@ -212,7 +212,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
def generate_images_continually(model, args):
if args.generation_task == 'text2image':
query_template = '[ROI1] {} [BASE] [BOI1] [Image200]bao.jpeg'
query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024'
elif args.generation_task == 'image2text':
query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] [MASK]*20'
elif args.generation_task == 'low-level super-resolution':
......@@ -222,7 +222,7 @@ def generate_images_continually(model, args):
elif args.generation_task == 'post-selection':
query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] {}'
elif args.generation_task == 'cuda-2d generation':
query_template = '[CLS] {} [BASE] [Image200]bao.jpeg [MASK]*4096'
query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1] [MASK]*4096'
else:
raise NotImplementedError
for raw_text, seq, output_path in get_context(args, query_template):
......
from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
from .magnify import magnify
from .local_sampling import filling_sequence_local
\ No newline at end of file
from .local_sampling import filling_sequence_local
from .cuda_2d_sampling import filling_sequence_cuda_2d
\ No newline at end of file
from .sampling import *
import math
import sys
from copy import deepcopy
from torchvision.utils import save_image
def filling_sequence_cuda_2d(
model,
seq,
args,
mems=None,
invalid_slices=[],
iterative_step=20,
**kwargs):
'''
seq: [id[ROI1], 10000, 20000, ..., id[BASE], id[BOI1], 1024 x (-1 or known token), id[EOI1], 4096 x -1 ...]
'''
tokenizer = get_tokenizer()
invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
device = seq.device
assert args.sparse_config.sparse_type == 'cuda_2d'
std_config = deepcopy(args.sparse_config)
std_config.sparse_type = 'standard'
sparse_config = args.sparse_config
# split two parts
seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
# generate a batch of seq0
model.module.transformer.reset_sparse_config(std_config)
args.sparse_config = std_config
output0 = filling_sequence(model, seq0, args)
model.module.transformer.reset_sparse_config(sparse_config)
args.sparse_config = sparse_config
model.module.transformer.max_memory_length = 0
# filter bad generation & select top N=2, TODO
output0 = output0
from torchvision import transforms
tr = transforms.Compose([
transforms.Resize(512),
])
imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blurred image as init value
# pad seq to desired shape
n_pad = args.layout[1] - len(seq0)
batch_size = output0.shape[0]
assert n_pad > 0, "You should truncate long input before filling."
seq = torch.cat((
torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype)
.unsqueeze(0).expand(batch_size, n_pad),
output0,
seq1.unsqueeze(0).expand(batch_size, len(seq1))
), dim=1
) # [b, layout[-1]]
# init
step_cnt = 0
tokens = seq[:, :-1].clone()
unfixed = (seq < 0)
# tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
tokens[:, -4095:] = blur64[:, :-1]
attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
attention_mask[n_pad:, :n_pad] = 0
position_ids = torch.cat((
torch.zeros(n_pad, dtype=torch.long),
torch.arange(0, args.layout[1] - n_pad),
torch.arange(0, args.layout[2]-args.layout[1]))).to(device)
# iterate
imgs = []
# import pdb;pdb.set_trace()
while unfixed.sum() > 0:
print(unfixed.sum())
logits, *_dump = model(tokens, position_ids, attention_mask)
step_cnt += 1
last_logits = logits
# warmup
real_topk = 5
real_temp = 2 - min(1,((step_cnt) / iterative_step)) * 1.9
# sampling
for invalid_slice in invalid_slices: # forbid generating other tokens
logits[..., invalid_slice] = -float('Inf')
assert args.top_k > 0
tk_value, tk_idx = torch.topk(logits, real_topk, dim=-1)
tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
# update unfixed
choice = 1
if choice == 0 and step_cnt > 5:
mprob = tk_probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
dprob = (mprob[:, 1:] < 0.5) & ((mprob[:, :-1] > 0.8)| (unfixed[:, 1:-1].logical_not()))
new_fixed = unfixed.clone()
new_fixed[:, 2:] &= dprob
else:
new_fixed = unfixed & False # TODO
new_fixed[:, -1] = True
unfixed &= new_fixed.logical_not()
# update seq and tokens
seq[new_fixed] = prev[new_fixed[:, 1:]]
tokens = seq[:, :-1].clone()
tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
if step_cnt == iterative_step:
seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
n_unfixed = unfixed.sum(dim=-1).tolist()
print(f'Exit with {n_unfixed} unfixed tokens.')
break
if args.debug:
from torchvision.utils import save_image
seqt = seq.clone()
seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
if args.debug:
imgs = torch.cat(imgs, dim=0)
save_image(imgs, f'steps{device}.jpg', normalize=True)
model.module.transformer.max_memory_length = args.max_memory_length
return seq
\ No newline at end of file
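A compressed re-statement (hypothetical sketch, not a drop-in for the code above) of the sampling rule inside the refinement loop: top-k stays at 5 while the temperature anneals from 2.0 towards 0.1 over `iterative_step` rounds, so early passes explore and later passes commit.
import torch

def annealed_topk_sample(logits, step_cnt, iterative_step, topk=5):
    # temperature schedule used above: 2.0 - min(1, step/iterative_step) * 1.9
    temp = 2 - min(1, step_cnt / iterative_step) * 1.9
    tk_value, tk_idx = torch.topk(logits, topk, dim=-1)            # [b, s, topk]
    tk_probs = (tk_value / temp).softmax(dim=-1).view(-1, topk)
    choice = torch.multinomial(tk_probs, num_samples=1)            # sample within the top-k
    return torch.gather(tk_idx.view(-1, topk), -1, choice).view(logits.shape[:-1])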
......@@ -132,8 +132,8 @@ def filling_sequence(
tmp = -F.log_softmax(logits, dim=-1)
tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
for i in range(1,len(tmp)):
print(i, tmp[i].item())
# for i in range(1,len(tmp)):
# print(i, tmp[i].item())
index = counter
print(tmp[1:].mean(), file=sys.stderr)
elif seq[counter + 1] >= 0: # provided
......
......@@ -17,13 +17,16 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler
import math
from utils import print_rank_0
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve."""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5):
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5, restart_iter=0):
self.restart_iter = restart_iter
assert warmup_iter <= num_iters
self.optimizer = optimizer
self.start_lr = start_lr
......@@ -38,13 +41,16 @@ class AnnealingLR(_LRScheduler):
def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * self.num_iters / self.warmup_iter
real_num_iters = self.num_iters - self.restart_iter
real_end_iter = self.end_iter - self.restart_iter
# print_rank_0(f'real_num_iters: {real_num_iters}')
if self.warmup_iter > 0 and real_num_iters <= self.warmup_iter:
return float(self.start_lr) * real_num_iters / self.warmup_iter
else:
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter)
return self.start_lr*((real_end_iter-(real_num_iters-self.warmup_iter))/real_end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter)
decay_step_ratio = min(1.0, (real_num_iters - self.warmup_iter) / real_end_iter)
return self.start_lr / self.decay_ratio * (
(math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
......@@ -73,8 +79,9 @@ class AnnealingLR(_LRScheduler):
return sd
def load_state_dict(self, sd):
import pdb;pdb.set_trace()
# self.start_lr = sd['start_lr']
self.warmup_iter = sd['warmup_iter']
# self.warmup_iter = sd['warmup_iter']
self.num_iters = sd['num_iters']
# self.end_iter = sd['end_iter']
self.decay_style = sd['decay_style']
......
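A rough numeric sketch (hypothetical, mirroring the get_lr arithmetic above; the hyper-parameter values are placeholders) of how --restart-iter re-enters warmup while the absolute iteration counter keeps running.
import math

def sketch_lr(num_iters, start_lr=1e-4, warmup_iter=2000, end_iter=300000,
              decay_ratio=0.5, restart_iter=0):
    real_num_iters = num_iters - restart_iter
    real_end_iter = end_iter - restart_iter
    if warmup_iter > 0 and real_num_iters <= warmup_iter:
        return start_lr * real_num_iters / warmup_iter              # linear re-warmup
    ratio = min(1.0, (real_num_iters - warmup_iter) / real_end_iter)
    return start_lr / decay_ratio * ((math.cos(math.pi * ratio) + 1) * (decay_ratio - 1) / 2 + 1)

# restarting at iteration 199000 puts iteration 199500 back inside warmup:
print(sketch_lr(199500, restart_iter=199000))                       # 2.5e-05, i.e. 25% of start_lr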
......@@ -41,15 +41,14 @@ def gpt2_get_params_for_weight_decay_optimization(module):
if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
if p is not None and p.requires_grad])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
if p is not None and n != 'bias' and p.requires_grad])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
if p is not None and n == 'bias' and p.requires_grad])
return weight_decay_params, no_weight_decay_params
......@@ -74,7 +73,8 @@ class GPT2Model(torch.nn.Module):
sandwich_ln,
checkpoint_num_layers=1,
parallel_output=True,
sparse_config=argparse.Namespace(sparse_type='standard')
sparse_config=argparse.Namespace(sparse_type='standard'),
finetune=False
):
super(GPT2Model, self).__init__()
......@@ -99,7 +99,8 @@ class GPT2Model(torch.nn.Module):
checkpoint_activations,
checkpoint_num_layers,
sandwich_ln=sandwich_ln,
sparse_config=sparse_config
sparse_config=sparse_config,
finetune=finetune
)
def forward(self, input_ids, position_ids, attention_mask, *mems):
......@@ -120,3 +121,6 @@ class GPT2Model(torch.nn.Module):
return (logits_parallel, *hidden_layers)
return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)
def init_plus_from_old(self):
self.transformer.init_plus_from_old()
def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
''' Sparse Attention
Args:
q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window
pivot_idx: [b, num_pivots]
pivot_attention_mask: [b, s, num_pivots]
query_window: size of each query block (w).
key_window_times: key_window = query_window * key_window_times
'''
b, n_head, s, hn = q.shape
b, n_piv = pivot_idx.shape
w = query_window
pivot_idx_dummy = pivot_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
# ===================== Pivot Attention ======================== #
pivot_k, pivot_v = torch.gather(k, 2, pivot_idx_dummy), torch.gather(v, 2, pivot_idx_dummy)
attention_scores = torch.matmul(q, pivot_k.transpose(-1, -2))
pivot_attention_mask = pivot_attention_mask.unsqueeze(1)
attention_scores_pivot = torch.mul(attention_scores, pivot_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - pivot_attention_mask)
attention_scores_pivot = attention_scores_pivot + math.log(s // n_piv)
# ===================== Window Attention ======================= #
window_k = _chunk(k, query_window, key_window_times)
window_v = _chunk(v, query_window, key_window_times)
# window_k: [b, n_head, ceil(s / w), w * times, hn]
if s % w == 0: # training # TODO args check
assert k.shape[2] == s
assert window_k.shape[2] == s // w
window_q = q.view(b, n_head, s // w, w, hn)
attention_scores = torch.matmul(window_q, window_k.transpose(-1, -2))
window_attention_mask = torch.ones((w, w * key_window_times), dtype=attention_scores.dtype, device=q.device).tril_(diagonal=w * (key_window_times - 1))
attention_scores_window = torch.mul(attention_scores, window_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - window_attention_mask)
for t in range(1, key_window_times):
attention_scores_window[:, :, t - 1, :, :w * key_window_times - w * t] -= 10000.0
else:
raise ValueError('The seq_len must be exactly divisible by window_size.')
# ===================== Joint Softmax ======================= #
attention_scores_window = attention_scores_window.view(b, n_head, s, w * key_window_times)
attention_scores = torch.cat((attention_scores_pivot, attention_scores_window), dim=-1)
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
if attention_dropout is not None:
with get_cuda_rng_tracker().fork():
attention_probs = attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs[..., :-w * key_window_times], pivot_v) + torch.einsum('bcgwk,bcgkh->bcgwh', attention_probs[..., -w * key_window_times:].view(b, n_head, s // w, w, w * key_window_times), window_v).view(b, n_head, s, hn)
return context_layer
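# Hypothetical shape trace (illustration only, not part of the commit) for
# sparse_attention_1d with s = 256, query_window w = 128, key_window_times = 2,
# and n_piv = 64 pivots:
#   pivot_k, pivot_v         -> [b, n_head, 64, hn]     gathered pivot keys/values
#   attention_scores_pivot   -> [b, n_head, 256, 64]
#   window_k (after _chunk)  -> [b, n_head, 2, 256, hn] each query block attends to w*times keys
#   window_q                 -> [b, n_head, 2, 128, hn]
#   attention_scores_window  -> [b, n_head, 256, 256]   after the final view
# so the joint softmax runs over 64 pivot slots + 256 window slots per query.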
def transpose_and_split(x, layout, n_head):
x = x.transpose(1, 2)
x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2])
x_text = x[..., :layout[0]]
x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous()
x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous()
return x, x_text, x0, x1
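# Hypothetical shape trace (illustration only, not part of the commit) for
# transpose_and_split with layout = [64, 64, 1088, 5184], n_head = 16 and
# x of shape [b, 5184, hidden]:
#   x      -> [b*16, hidden//16, 5184]    heads folded into the batch dimension
#   x_text -> [b*16, hidden//16, 64]      text slice x[..., :layout[0]]
#   x0     -> [b*16, hidden//16, 32, 32]  the 1088-64 = 1024 level-0 tokens as a 32x32 map
#   x1     -> [b*16, hidden//16, 64, 64]  the 5184-1088 = 4096 level-1 tokens as a 64x64 map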
def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
'''
q, k, v: [batch_size, 64+1024+4096, hidden_size]
n_head: int
layout: [endoftext/startofpad, startof0, startof1, endofall]
attention_mask_text2d: [batch_size, sq_len, endoftext]
'''
from .local_attention_function import f_similar, f_weighting
b, sq_len, hn = q.shape
alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
q = q / math.sqrt(hn // n_head) # normalization
q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
# import pdb; pdb.set_trace()
# all to text
scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
# 0 to 0
scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True)
# 1 to 1
scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
# 1 to 0
scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
# softmax
# if 'offset_bias' in kwargs:
# p1, p2 = kernel_size**2//2 + 1, kernel_size2**2
# offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2)
# scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1]
# scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1]
# scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:]
scores_0 = torch.cat(
(scores_all_to_text[:, layout[1]:layout[2]],
scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])),
dim=-1)
scores_1 = torch.cat(
(scores_all_to_text[:, layout[2]:layout[3]],
scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
dim=-1)
probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
probs_0 = F.softmax(scores_0, dim=-1) #
probs_1 = F.softmax(scores_1, dim=-1)
if attention_dropout is not None:
with get_cuda_rng_tracker().fork():
probs_0 = attention_dropout(probs_0)
probs_1 = attention_dropout(probs_1)
# weighting
pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
probs_all_to_text = torch.cat((
probs_text,
pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
probs_0[:, :, :layout[0]],
probs_1[:, :, :layout[0]]
), dim=1)
context_all_to_text = torch.einsum('bhij,bhcj->bihc',
probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]),
v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True)
context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
context_all_to_01 = torch.cat(
(
pad.expand(b*n_head, hn//n_head, layout[1]),
context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
(context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
), dim=-1).view(b, hn, -1).transpose(1, 2)
return context_all_to_text + context_all_to_01
def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
'''
q, k, v: [batch_size, 64+1024+4096, hidden_size]
n_head: int
layout: [endoftext/startofpad, startof0, startof1, endofall]
attention_mask_text2d: [batch_size, sq_len, endoftext]
'''
from .local_attention_function import f_similar, f_weighting
b, sq_len, hn = q.shape
alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
q = q / math.sqrt(hn // n_head) # normalization
q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
# import pdb; pdb.set_trace()
# all to text
scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
# 0 to 0
if not hasattr(sparse_attention_2dfull, 'attention_mask0'):
sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_()
attention_mask0 = sparse_attention_2dfull.attention_mask0
scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0)
# 1 to 1
scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
# 1 to 0
scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
# softmax
scores_0 = torch.cat(
(scores_all_to_text[:, layout[1]:layout[2]],
scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])),
dim=-1)
scores_1 = torch.cat(
(scores_all_to_text[:, layout[2]:layout[3]],
scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
dim=-1)
probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
probs_0 = F.softmax(scores_0, dim=-1) #
probs_1 = F.softmax(scores_1, dim=-1)
if attention_dropout is not None:
with get_cuda_rng_tracker().fork():
probs_0 = attention_dropout(probs_0)
probs_1 = attention_dropout(probs_1)
# weighting
pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
probs_all_to_text = torch.cat((
probs_text,
pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
probs_0[:, :, :layout[0]],
probs_1[:, :, :layout[0]]
), dim=1)
context_all_to_text = torch.einsum('bhij,bhcj->bihc',
probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]),
v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0))
context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
context_all_to_01 = torch.cat(
(
pad.expand(b*n_head, hn//n_head, layout[1]),
context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
(context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
), dim=-1).view(b, hn, -1).transpose(1, 2)
return context_all_to_text + context_all_to_01
if args.sparse_config.sparse_type == 'cuda_2d':
layout = args.sparse_config.layout
unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000.
unpad_indices[:, -1] = 9000.
starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1]
layout[0] = starts.max().item()
attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device)
for i in range(batch_size):
attention_mask[i, :, starts[i]:layout[1]] = 0
attention_mask[:, :layout[0]].tril_()
attention_mask = attention_mask.unsqueeze(1)
elif args.sparse_config.sparse_type == 'standard':
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
attention_mask.tril_()
# attention_mask = torch.zeros((seq_length, seq_length), device=data.device)
# h = w = 32
# k1=9
# layout = [10, 64, 64+h*w, 64+h*w*5]
# for i in range(layout[1]):
# attention_mask[i, :i+1] = 1
# for i in range(layout[1], layout[2]):
# x = (i - layout[1]) // w
# y = (i - layout[1]) % w
# lx = max(0, x - k1 // 2)
# ly = max(0, y - k1 // 2)
# rx = min(h-1, x + k1 // 2)
# ry = min(w-1, y + k1 // 2)
# attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
# attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
# attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
\ No newline at end of file
......@@ -73,7 +73,8 @@ def get_model(args, sparse_config=None):
checkpoint_num_layers=args.checkpoint_num_layers,
parallel_output=True,
sparse_config=sparse_config if sparse_config is not None else args.sparse_config,
sandwich_ln=args.sandwich_ln
sandwich_ln=args.sandwich_ln,
finetune=args.finetune
)
if mpu.get_data_parallel_rank() == 0:
......@@ -163,7 +164,7 @@ def get_learning_rate_scheduler(optimizer, args):
num_iters = args.lr_decay_iters
else:
num_iters = args.train_iters
num_iters = max(1, num_iters)
num_iters = max(1, num_iters - args.restart_iter)
init_step = -1
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(optimizer,
......@@ -172,7 +173,9 @@ def get_learning_rate_scheduler(optimizer, args):
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
decay_ratio=args.lr_decay_ratio)
decay_ratio=args.lr_decay_ratio,
restart_iter=args.restart_iter
)
return lr_scheduler
......@@ -182,6 +185,12 @@ def setup_model_and_optimizer(args):
model = get_model(args)
if args.finetune:
model.requires_grad_(False)
for name, param in model.named_parameters():
if name.find('_plus') > 0:
param.requires_grad_(True)
param_groups = get_optimizer_param_groups(model)
if args.train_data is not None:
......@@ -213,38 +222,18 @@ def get_masks_and_position_ids(data,
# Attention mask (lower triangular).
if attention_mask is None:
# single direction, [PAD]s are at the end of the seq, so it doesn't matter.
if args.sparse_config.sparse_type == 'cuda_2d':
layout = args.sparse_config.layout
unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000.
unpad_indices[:, -1] = 9000.
starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1]
layout[0] = starts.max().item()
attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device)
# single direction, [PAD]s are at the start of the seq.
assert loss_mask is not None
# the dataset's loss_mask has n_pad+1 leading zeros ([PAD]s plus [CLS]); after the [1:] shift it has exactly n_pad, so it doubles as the pad mask for attention, reuse.
attention_mask = loss_mask[:, :args.layout[1]].unsqueeze(-2).expand(batch_size, args.layout[1], args.layout[1]).tril()
for i in range(batch_size):
attention_mask[i, :, starts[i]:layout[1]] = 0
attention_mask[:, :layout[0]].tril_()
attention_mask[i].fill_diagonal_(1)
attention_mask = attention_mask.unsqueeze(1)
elif args.sparse_config.sparse_type == 'standard':
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
attention_mask.tril_()
# attention_mask = torch.zeros((seq_length, seq_length), device=data.device)
# h = w = 32
# k1=9
# layout = [10, 64, 64+h*w, 64+h*w*5]
# for i in range(layout[1]):
# attention_mask[i, :i+1] = 1
# for i in range(layout[1], layout[2]):
# x = (i - layout[1]) // w
# y = (i - layout[1]) % w
# lx = max(0, x - k1 // 2)
# ly = max(0, y - k1 // 2)
# rx = min(h-1, x + k1 // 2)
# ry = min(w-1, y + k1 // 2)
# attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
# attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
# attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
elif args.sparse_config.sparse_type == 'torch_1d':
else:
raise NotImplementedError
# Loss mask.
......@@ -252,26 +241,19 @@ def get_masks_and_position_ids(data,
loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device)
# Position ids.
if args is not None and args.finetune and args.max_position_embeddings < args.max_position_embeddings_finetune:
# for each sample, find [ROI2] and split
# ([ROI1] text... [BOI1] img... [EOI1] [ROI2]<pos_id==1089> ...)
start_token = get_tokenizer()['[ROI2]']
tmp = torch.nonzero(data == start_token, as_tuple=False)
start_token_poses = [100000] * batch_size
for x, y in tmp:
start_token_poses[x] = min(start_token_poses[x], y)
assert 100000 not in start_token_poses, 'Some samples do not have [ROI2]!'
if args.sparse_config.sparse_type == 'cuda_2d':
assert loss_mask is not None
layout = args.layout
assert seq_length == layout[-1]
n_pads = seq_length - loss_mask.sum(dim=-1).long()
position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
device=data.device)
for i in range(batch_size):
sep = start_token_poses[i]
torch.arange(start=0, end=sep, out=position_ids[i, :sep],
torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]],
dtype=torch.long, device=data.device)
second_pos = 0 # reuse
torch.arange(start=second_pos, end=second_pos + seq_length - sep,
out=position_ids[i, sep:],
torch.arange(layout[2] - layout[1],
out=position_ids[i, layout[1]:],
dtype=torch.long, device=data.device)
position_ids[position_ids >= args.max_position_embeddings] = args.max_position_embeddings - 1
else:
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
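A toy illustration (hypothetical, not part of the commit) of the pad-aware position ids built in the cuda_2d branch above, assuming a mini layout [4, 8, 12] and 2 leading [PAD]s: positions stay 0 on the pads, count up through the text + level-0 part, then restart at 0 for the level-1 part.
import torch
layout = [4, 8, 12]                                              # toy stand-in for [64, 1088, 5184]
n_pad = 2
position_ids = torch.zeros(layout[-1], dtype=torch.long)
position_ids[n_pad:layout[1]] = torch.arange(layout[1] - n_pad)  # text + 32x32 region
position_ids[layout[1]:] = torch.arange(layout[2] - layout[1])   # 64x64 region restarts at 0
print(position_ids.tolist())                                     # [0, 0, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3]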
......@@ -298,28 +280,9 @@ def get_batch(data_iterator, args, timers):
tokens_ = data_b['text'].long()
loss_mask = data_b['loss_mask'].float()
#FIXME change order
if args.sparse_config.sparse_type == 'cuda_2d':
assert args.sparse_config.layout[-1] == 64+32**2+64**2
tokens_new = tokens_.clone()
tokens_new[:, 64:64+32**2] = tokens_[:, 64+64**2:]
tokens_new[:, 64+32**2:] = tokens_[:, 64:64+64**2]
tokens_ = tokens_new
# only 32#
# tokens_ = tokens_[:, :64+32**2]
# loss_mask = loss_mask[:, :64+32**2]
if args.dataset_type == 'BinaryDataset':
labels = tokens_.contiguous()
loss_mask = loss_mask.contiguous()
tokenizer = get_tokenizer()
cls_token = torch.zeros(tokens_.shape[0], 1, dtype=tokens_.dtype, device=tokens_.device) + tokenizer['[CLS]']
tokens = torch.cat((cls_token, tokens_[:, :-1]), dim=1)
tokens[:, 64] = tokenizer['[BASE]']
else:
labels = tokens_[:, 1:].contiguous()
loss_mask = loss_mask[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
labels = tokens_[:, 1:].contiguous()
loss_mask = loss_mask[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
attention_mask = None
......@@ -344,7 +307,6 @@ def forward_step(data_iterator, model, args, timers, mems):
timers('batch generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator, args, timers)
timers('batch generator').stop()
# split img & txt positions, [PAD] not included # TODO check enough
......@@ -352,7 +314,6 @@ def forward_step(data_iterator, model, args, timers, mems):
img_txt_sep = tokenizer.img_tokenizer.num_tokens
img_indices_bool = (tokens.detach() < img_txt_sep) & (loss_mask > 0)
txt_indices_bool = (~img_indices_bool) & (loss_mask > 0)
# Forward model.
logits, *mems = model(tokens, position_ids, attention_mask, *mems)
losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
......@@ -363,15 +324,14 @@ def forward_step(data_iterator, model, args, timers, mems):
losses = losses.view(-1) * loss_mask
loss = torch.sum(losses) / loss_mask.sum()
# ===================== Log partial losses ======================== #
if args.sparse_config.sparse_type == 'cuda_2d':
img_indices_bool2 = img_indices_bool.clone()
img_indices_bool2[:, :args.sparse_config.layout[2]] = False
img_indices_bool2[:, :args.sparse_config.layout[1]] = False
img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1)
torch.distributed.all_reduce(img_loss2.data)
img_loss2.data = img_loss2.data / args.world_size
img_indices_bool[:, args.sparse_config.layout[2]:] = False
img_indices_bool[:, args.sparse_config.layout[1]:] = False
else:
img_loss2 = 0
img_indices_bool = img_indices_bool.view(-1)
......@@ -386,9 +346,6 @@ def forward_step(data_iterator, model, args, timers, mems):
txt_loss.data = txt_loss.data / args.world_size
# ===================== END OF BLOCK ======================= #
# import pdb;pdb.set_trace()
# with open('tmp_save.bin', 'wb') as fout:
# torch.save(tokens, fout)
return loss, mems, img_loss, txt_loss, img_loss2
......@@ -471,7 +428,6 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
timers('backward').start()
lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
timers('backward').stop()
# Update parameters.
skipped_iter, complete = 0, False
timers('optimizer').start()
......@@ -674,7 +630,14 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False, step=None, summary_writer=None):
"""Helper function to evaluate and dump results on screen."""
# import line_profiler
# profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
# profile.enable()
# torch.cuda.empty_cache()
lm_loss = evaluate(data_iterator, model, args, timers, verbose)
# profile.disable() # stop profiling
# import sys
# profile.print_stats(sys.stdout)
lm_ppl = math.exp(min(20, lm_loss))
report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
......
......@@ -6,12 +6,12 @@ import torch
import random
test_dir = 'tmp'
# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin'
bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing005/quanjing005.bin.part_0.cogdata'
bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing003/quanjing003.bin.part_0.cogdata'
bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False)
args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None)
tokenizer = get_tokenizer(args)
bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(16)]
bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(32)]
for x in bin_ds:
end = x.tolist().index(-1)
print(tokenizer.DecodeIds(x[:end])[0])
......
#!/bin/bash
CHECKPOINT_PATH=data/checkpoints/cogview-fixgrad-small08-25-09-38
CHECKPOINT_PATH=data/checkpoints/cogview-long
# CHECKPOINT_PATH=data/checkpoints/cogview-compare
NLAYERS=16
NHIDDEN=1024
NATT=16
NLAYERS=48
NHIDDEN=2560
NATT=40
MAXSEQLEN=5184
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
MPSIZE=1
#SAMPLING ARGS
TEMP=1.05
TEMP=1.03
#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
TOPK=100
TOPP=0
......@@ -25,22 +25,24 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--max-position-embeddings 5184 \
--max-position-embeddings 1089 \
--fp16 \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP \
--sandwich-ln \
--img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
--sparse-type standard \
--sparse-type cuda_2d \
--max-position-embeddings-finetune $MAXSEQLEN \
--generation-task "cuda-2d generation" \
--input-source ./input.txt \
--output-path samples_text2image \
--batch-size 2 \
--output-path samples_cuda_2d2 \
--batch-size 3 \
--max-inference-batch-size 4 \
--device 0 \
--sparse-type standard \
--finetune \
--no-load-optim \
--sparse-type cuda_2d \
$@
......@@ -2,7 +2,7 @@
# Change for multinode config
NUM_WORKERS=10
NUM_WORKERS=19
NUM_GPUS_PER_WORKER=8
MP_SIZE=1
......@@ -12,23 +12,22 @@ main_dir=$(dirname $script_dir)
# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile2"
HOST_FILE_PATH="hostfile"
# OPTIONS_NCCL=""
# HOST_FILE_PATH="hostfile_single"
small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata"
full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin"
config_json="$script_dir/ds_config.json"
config_json="$script_dir/ds_config_zero.json"
gpt_options=" \
--experiment-name cogview-fixgrad-small-test \
--experiment-name cogview-base-continue-long \
--img-tokenizer-num-tokens 8192 \
--dataset-type BinaryDataset \
--dataset-type CompactBinaryDataset \
--model-parallel-size ${MP_SIZE} \
--num-layers 16 \
--hidden-size 1024 \
--num-attention-heads 16 \
--save $main_dir/data/checkpoints \
--num-layers 48 \
--hidden-size 2560 \
--num-attention-heads 40 \
--train-iters 300000 \
--resume-dataloader \
--train-data ${full_data} \
......@@ -38,24 +37,35 @@ gpt_options=" \
--warmup .1 \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--max-position-embeddings 5184 \
--max-position-embeddings 1089 \
--max-memory-length 0 \
--sandwich-ln \
--txt-loss-scale 10 \
--txt-loss-scale 0.1 \
--sparse-type cuda_2d \
--fp16 \
--save-interval 2000 \
--load data/checkpoints/cogview-compare
--no-load-optim \
--no-save-optim \
--eval-interval 1000 \
--save /root/checkpoints \
--fast-load \
--load data/checkpoints/cogview-continue \
--finetune
"
#
# --finetune
# --save $main_dir/data/checkpoints \
# --restart-iter 199000
gpt_options="${gpt_options}
--deepspeed \
--deepspeed_config ${config_json} \
--deepspeed \
--deepspeed_config ${config_json} \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}"
echo ${run_cmd}
......
......@@ -21,11 +21,11 @@ config_json="$script_dir/ds_config.json"
gpt_options=" \
--experiment-name cogview-testlocal \
--img-tokenizer-num-tokens 8192 \
--dataset-type BinaryDataset \
--dataset-type CompactBinaryDataset \
--model-parallel-size ${MP_SIZE} \
--num-layers 16 \
--hidden-size 1024 \
--num-attention-heads 16 \
--num-layers 48 \
--hidden-size 2560 \
--num-attention-heads 40 \
--save $main_dir/data/checkpoints \
--train-iters 100000 \
--resume-dataloader \
......@@ -36,15 +36,19 @@ gpt_options=" \
--warmup .1 \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--max-position-embeddings 5184 \
--max-position-embeddings 1089 \
--max-memory-length 0 \
--txt-loss-scale 2 \
--txt-loss-scale 1 \
--sandwich-ln \
--sparse-type cuda_2d \
--sparse-type standard \
--save-interval 2500 \
--load data/checkpoints/cogview-fixgrad-small08-25-09-38
--fp16 \
--eval-iters 1000 \
--load pretrained/cogview/cogview-base
"
# --fp16 \
#
# --load data/checkpoints/cogview-fixgrad-small08-25-09-38
gpt_options="${gpt_options}
......
......@@ -6,7 +6,8 @@
# NHIDDEN=1024
# NATT=16
CHECKPOINT_PATH=pretrained/cogview/cogview-base
CHECKPOINT_PATH=data/checkpoints/cogview-continue
# CHECKPOINT_PATH=pretrained/cogview/cogview-base
NLAYERS=48
NHIDDEN=2560
NATT=40
......@@ -42,8 +43,8 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
--generation-task text2image \
--input-source ./input.txt \
--output-path samples_text2image \
--batch-size 4 \
--max-inference-batch-size 4 \
--batch-size 8 \
--max-inference-batch-size 8 \
--device 0 \
$@
......
......@@ -4,8 +4,7 @@ from tqdm import tqdm
import torch
import numpy as np
from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d, sparse_attention_2dfull
from mpu.sparse_transformer import standard_attention, sparse_attention_2d_light
def test_sparse_attention_1d():
s, w, times = 4096 + 128, 128, 2
num_pivot = 768
......@@ -80,7 +79,7 @@ def test_sparse_attention_2d():
device = 'cuda'
b, n_head, hn = 2, 16, 1024
h = w = 32
layout = [10, 64, 64+h*w, 64+h*w*5]
layout = [64, 64+h*w, 64+h*w*5]
k1 = 9
k2 = 7
k1h = k1*2-1
......@@ -94,9 +93,6 @@ def test_sparse_attention_2d():
m = mask[0]
for i in range(layout[1]):
m[i, :i+1] = 1
m[layout[1]:, :layout[0]] = 1
for i in tqdm(range(layout[1], layout[2])):
m[i, layout[1]:i+1] = 1
# for i in tqdm(range(layout[1], layout[2])):
# x = (i - layout[1]) // w
# y = (i - layout[1]) % w
......@@ -106,15 +102,15 @@ def test_sparse_attention_2d():
# ry = min(w-1, y + k1 // 2)
# m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
# m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
for i in tqdm(range(layout[2], layout[3])):
x = (i - layout[2]) // (2*w)
y = (i - layout[2]) % (2*w)
for i in tqdm(range(layout[1], layout[2])):
x = (i - layout[1]) // (2*w)
y = (i - layout[1]) % (2*w)
lx = max(0, x - k1h // 2)
ly = max(0, y - k1 // 2)
rx = min(2*h-1, x + k1h // 2)
ry = min(2*w-1, y + k1 // 2)
m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1
m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1
m[i, layout[1]:layout[2]].view(h*2, w*2)[lx:x, ly:ry+1] = 1
m[i, layout[1]:layout[2]].view(h*2, w*2)[x, ly:y+1] = 1
x = x // 2
y = y // 2
......@@ -122,7 +118,7 @@ def test_sparse_attention_2d():
ly = max(0, y - k2 // 2)
rx = min(h-1, x + k2 // 2)
ry = min(w-1, y + k2 // 2)
m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1
m[i, layout[0]:layout[1]].view(h, w)[lx:rx+1, ly:ry+1] = 1
mask[1:] = mask[0]
# mask[1][layout[1]:, layout[0]-1] = 0
......@@ -133,15 +129,18 @@ def test_sparse_attention_2d():
torch.cuda.synchronize()
t0 = time.time()
qkv_tmp = qkv.view(3, b, layout[-1], n_head, hn//n_head).permute(0, 1, 3, 2, 4).contiguous()
r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[3], hn)
r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[2], hn)
torch.cuda.synchronize()
t1 = time.time()
r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
# r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
qkv20, qkv21 = qkv2[:, :, :layout[1]], qkv2[:, :, layout[1]:]
r20, r21 = sparse_attention_2d_light(*qkv20, *qkv21, mask[...,:layout[1],:layout[1]].unsqueeze(1), n_head, layout[0],kernel_size=k1, kernel_size2=k2)
r2 = torch.cat((r20, r21), dim=1)
torch.cuda.synchronize()
t2 = time.time()
print('times: standard ', t1-t0, ' sparse ', t2-t1)
print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max())
print(( (r1[:,:layout[1]]-r2[:,:layout[1]]).abs() / (r1[:,:layout[1]].abs()+r2[:,:layout[1]].abs())).max())
print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max())
qkv.retain_grad()
l2 = r2[:,layout[1]:].sum()
......@@ -153,7 +152,7 @@ def test_sparse_attention_2d():
g1 = qkv.grad
g2 = qkv2.grad
print( (g1-g2).abs().max())
print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-5)).max())
print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-3)).max())
import pdb;pdb.set_trace()
......
......@@ -248,8 +248,34 @@ def save_ds_checkpoint(iteration, model, lr_scheduler, args):
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
if args.no_save_optim:
save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
else:
model.save_checkpoint(args.save, str(iteration), client_state=sd)
def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
os.makedirs(save_dir, exist_ok=True)
if tag is None:
tag = f"global_step{model.global_steps}"
# Ensure tag is a string
tag = str(tag)
# Ensure checkpoint tag is consistent across ranks
model._checkpoint_tag_validation(tag)
if model.save_non_zero_checkpoint:
model._create_checkpoint_file(save_dir, tag, False)
model._save_checkpoint(save_dir, tag, client_state=client_state)
# Save latest checkpoint tag
if save_latest:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
model.save_checkpoint(args.save, str(iteration), client_state=sd)
return True
def get_checkpoint_iteration(args):
......@@ -296,8 +322,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=
if args.deepspeed:
checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim)
if "client_lr_scheduler" in sd:
checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune)
if args.finetune:
model.module.module.init_plus_from_old()
if (args.finetune or args.no_load_optim) and model.zero_optimization():
model.optimizer.refresh_fp32_params()
if "client_lr_scheduler" in sd and not args.finetune:
lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
print_rank_0("Load lr scheduler state")
if checkpoint_name is None:
......