from vqvae.vqvae_zc import Encoder
from .sampling import *
import math
import sys
from copy import deepcopy
from torchvision.utils import save_image


def filling_sequence_cuda_2d(
        model,
        seq,
        args,
        mems=None,
        invalid_slices=[],
        **kwargs):
    '''
    seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
    Two-stage generation: first fill the 1024-token (32x32) base image with
    standard auto-regressive decoding, then iteratively refine the 4096-token
    (64x64) image under the cuda_2d sparse attention.
    '''
    tokenizer = get_tokenizer()
    # Only image tokens may be generated; mask out everything above the
    # image-token range (this overrides the invalid_slices argument).
    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
    device = seq.device
    assert args.sparse_config.sparse_type == 'cuda_2d'
    std_config = deepcopy(args.sparse_config)
    std_config.sparse_type = 'standard'
    sparse_config = args.sparse_config

    # Split into the text + 32x32 part and the 64x64 part.
    seq0, seq1 = seq[:-4097], seq[-4097:]  # +1 for [EOI1]

    # Generate a batch of seq0 with standard attention.
    model.module.transformer.reset_sparse_config(std_config)
    args.sparse_config = std_config
    output0 = filling_sequence(model, seq0, args)
    model.module.transformer.reset_sparse_config(sparse_config)
    args.sparse_config = sparse_config
    model.module.transformer.max_memory_length = 0

    # TODO: filter bad generations & select the top N=2.

    # Upsample the 32x32 results and re-encode them as 64x64 tokens, to be
    # used as the initial value of the unknown positions.
    from torchvision import transforms
    tr = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    ])
    imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0]  # ground truth
    blur64 = tokenizer.img_tokenizer.EncodeAsIds(
        torch.cat(imgs, dim=0).to(device), add_normalization=True)  # blurred image as init value

    # Pad seq to the desired shape.
    n_pad = args.layout[1] - len(seq0)
    batch_size = output0.shape[0]
    assert n_pad > 0, "You should truncate long input before filling."
    seq = torch.cat((
        torch.tensor([tokenizer['[PAD]']] * n_pad, device=seq.device, dtype=seq.dtype)
             .unsqueeze(0).expand(batch_size, n_pad),
        output0,
        seq1.unsqueeze(0).expand(batch_size, len(seq1))
    ), dim=1)  # [b, layout[-1]]

    # Init: start from the blurred 64x64 tokens; positions with seq < 0 are
    # still "unfixed" and will be resampled until they are accepted.
    step_cnt = 0
    tokens = seq[:, :-1].clone()
    unfixed = (seq < 0)
    tokens[:, -4095:] = blur64[:, :-1]
    attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
    attention_mask[n_pad:, :n_pad] = 0  # padding is invisible to real tokens
    position_ids = torch.cat((
        torch.zeros(n_pad, dtype=torch.long),
        torch.arange(0, args.layout[1] - n_pad),
        torch.arange(0, args.layout[2] - args.layout[1]))).to(device)

    # Iterate: resample all unfixed positions in parallel, then fix a subset
    # per step until none are left (or the iteration budget runs out).
    imgs = []
    while unfixed.sum() > 0:
        print(f'{unfixed.sum().item()} tokens still unfixed')
        logits, *_dump = model(tokens, position_ids, attention_mask)
        step_cnt += 1

        # Warmup schedule: very low temperature first, then relax.
        real_topk = 10
        warmup_steps = 3
        iterative_step = warmup_steps + 6
        if step_cnt <= warmup_steps:
            real_temp = 0.1
        elif step_cnt == warmup_steps + 1:
            real_temp = 0.55
        else:
            real_temp = 0.45

        # Sampling: forbid generating non-image tokens.
        for invalid_slice in invalid_slices:
            logits[..., invalid_slice] = -float('Inf')
        assert args.top_k > 0

        # Per-position temperature: positions whose top-5 distribution has
        # high entropy (> 1.3) sample at real_temp + 0.6, the rest at 0.6.
        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
        ent = -(topraw * topraw.log()).sum(dim=-1)
        if step_cnt > warmup_steps:
            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(
                *logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
        else:
            real_temp2 = real_temp

        probs = F.softmax(logits / real_temp2, dim=-1)
        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
        prev = torch.multinomial(probs.view(-1, logits.shape[-1]),
                                 num_samples=1).view(*logits.shape[:2], 1)
        # Top-k filtering after the fact: samples that fall outside the
        # top-k are clamped to the k-th most likely token.
        edge_idx = tk_idx[:, :, -1:]
        edge_value = tk_value[:, :, -1:]
        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
        prev[edge_mask] = edge_idx[edge_mask]
        prev.squeeze_(-1)

        # Decide which positions become fixed at this step.
        choice = 1
        if choice == 0 and warmup_steps < step_cnt:
            # (Disabled alternative: pick positions to fix by per-position
            # confidence, with spacing constraints so that horizontally or
            # vertically adjacent tokens are not fixed in the same step.)
            pass
        elif choice == 1 and warmup_steps < step_cnt:
            # View the 64x64 grid as 4x4 cells and fix, at each step, all
            # positions whose (row mod 4, col mod 4) offset lies on the
            # current anti-diagonal x + y = step_cnt - warmup_steps - 1.
            new_fixed = unfixed & False  # all-False mask of the same shape
            ll, rr = 4, 4
            for x in range(min(ll, step_cnt - warmup_steps)):
                y = step_cnt - warmup_steps - x - 1
                if y < rr:
                    new_fixed[..., -4096:].view(
                        batch_size, 64 // ll, ll, 64 // rr, rr)[:, :, x, :, y] = True
            new_fixed &= unfixed
        else:
            new_fixed = unfixed & False  # TODO
        new_fixed[:, -1] = True

        unfixed &= new_fixed.logical_not()

        # Update seq and tokens: accept the newly fixed samples, and feed the
        # fresh samples of the still-unfixed positions into the next step.
        seq[new_fixed] = prev[new_fixed[:, 1:]]
        tokens = seq[:, :-1].clone()
        tokens[:, 1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]

        if step_cnt == iterative_step:
            # Iteration budget exhausted: accept whatever is currently sampled.
            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]]
            n_unfixed = unfixed.sum(dim=-1).tolist()
            print(f'Exit with {n_unfixed} unfixed tokens.')
            break

        if args.debug:
            seqt = seq.clone()
            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]]
            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])

    if args.debug:
        imgs = torch.cat(imgs, dim=0)
        save_image(imgs, f'steps{device}.jpg', normalize=True)
    model.module.transformer.max_memory_length = args.max_memory_length  # restore
    return seq
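
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# `model` is the distributed-wrapped transformer used above (it must expose
# `model.module.transformer`), `args` carries `sparse_config`, `layout`,
# `top_k`, `max_memory_length` and `debug`, and `text_ids` already holds the
# prompt tokens up to and including [BOI1]; -1 marks positions to be generated,
# matching the layout in filling_sequence_cuda_2d's docstring. The helper name
# `_example_fill` is hypothetical.
# ---------------------------------------------------------------------------
def _example_fill(model, args, text_ids):
    tokenizer = get_tokenizer()
    seq = torch.tensor(
        text_ids
        + [-1] * 1024               # 32x32 base image, generated first
        + [tokenizer['[EOI1]']]
        + [-1] * 4096,              # 64x64 image, refined iteratively
        dtype=torch.long, device='cuda')
    return filling_sequence_cuda_2d(model, seq, args)  # [batch, layout[-1]]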