# cuda_2d_sampling.py
# Iterative two-stage image-token sampling for the "cuda_2d" sparse attention mode.
import math
import sys
from copy import deepcopy

from torchvision.utils import save_image

from vqvae.vqvae_zc import Encoder

from .sampling import *
    def filling_sequence_cuda_2d(
            model, 
            seq, 
            args, 
            mems=None, 
            invalid_slices=[], 
            **kwargs):
        '''
            seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
        '''
        tokenizer = get_tokenizer()
        invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
        device = seq.device
        assert args.sparse_config.sparse_type == 'cuda_2d'
        std_config = deepcopy(args.sparse_config)
        std_config.sparse_type = 'standard'
        sparse_config = args.sparse_config
        # split two parts
        seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
        # generate a batch of seq0
        model.module.transformer.reset_sparse_config(std_config)
        args.sparse_config = std_config
        output0 = filling_sequence(model, seq0, args)
        model.module.transformer.reset_sparse_config(sparse_config)
        args.sparse_config = sparse_config
        model.module.transformer.max_memory_length = 0
    
    
        # filter bad generation & select top N=2, TODO
        output0 = output0
    
        from torchvision import transforms
        tr = transforms.Compose([
    
    Ming Ding's avatar
    Ming Ding committed
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), 
    
    Ming Ding's avatar
    Ming Ding committed
        ])
        imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
        blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blured image as init value
    
        # pad seq to desired shape
        n_pad = args.layout[1] - len(seq0)
        batch_size = output0.shape[0]
        assert n_pad > 0, "You should truncate long input before filling."
        seq = torch.cat((
            torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype)
                .unsqueeze(0).expand(batch_size, n_pad),
            output0,
            seq1.unsqueeze(0).expand(batch_size, len(seq1))    
            ), dim=1
        ) # [b, layout[-1]]
    
        # init 
        step_cnt = 0
        tokens = seq[:, :-1].clone()
        unfixed = (seq < 0)
        # tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
        tokens[:, -4095:] = blur64[:, :-1]
        attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
        attention_mask[n_pad:, :n_pad] = 0
        position_ids = torch.cat((
            torch.zeros(n_pad, dtype=torch.long),
            torch.arange(0, args.layout[1] - n_pad), 
            torch.arange(0, args.layout[2]-args.layout[1]))).to(device)
        # iterate
        imgs = []
        # import pdb;pdb.set_trace()
        while unfixed.sum() > 0:
            print(unfixed.sum())
            logits, *_dump = model(tokens, position_ids, attention_mask)
            step_cnt += 1
    
            # warmup 
    
    Ming Ding's avatar
    Ming Ding committed
            real_topk = 10
            warmup_steps = 3
            iterative_step= warmup_steps + 6
            if step_cnt <= warmup_steps:
                real_temp = 0.1
            elif step_cnt == warmup_steps + 1:
    
    Ming Ding's avatar
    Ming Ding committed
                real_temp = 0.55
    
    Ming Ding's avatar
    Ming Ding committed
            elif step_cnt > warmup_steps + 1:
    
    Ming Ding's avatar
    Ming Ding committed
                real_temp = 0.45
    
    Ming Ding's avatar
    Ming Ding committed
            # if  5 < step_cnt:
            #     real_topk = 200
    
    Ming Ding's avatar
    Ming Ding committed
            # sampling
            for invalid_slice in invalid_slices: # forbide to generate other tokens
                logits[..., invalid_slice] = -float('Inf')
            assert args.top_k > 0
    
    Ming Ding's avatar
    Ming Ding committed
            
    
    Ming Ding's avatar
    Ming Ding committed
            # probs0 = F.softmax(logits/real_temp, dim=-1)
            topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
            ent = -(topraw * topraw.log()).sum(dim=-1)
            # topsum = topraw.sum(dim=-1)
    
    Ming Ding's avatar
    Ming Ding committed
            if step_cnt > warmup_steps:
    
    Ming Ding's avatar
    Ming Ding committed
                # import pdb;pdb.set_trace()
    
    Ming Ding's avatar
    Ming Ding committed
                real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
    
    Ming Ding's avatar
    Ming Ding committed
                # import pdb;pdb.set_trace()
            else:
                real_temp2 = real_temp
            # import pdb;pdb.set_trace()
            probs = F.softmax(logits/real_temp2, dim=-1)
            tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
            prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
            edge_idx = tk_idx[:, :, -1:]
            edge_value = tk_value[:, :, -1:]
            edge_mask = probs.gather(dim=-1, index=prev) < edge_value
            prev[edge_mask] = edge_idx[edge_mask]
            prev.squeeze_(-1)
            # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
            # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
            # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
    
    Ming Ding's avatar
    Ming Ding committed
            # update unfixed
            choice = 1
    
    Ming Ding's avatar
    Ming Ding committed
            if choice == 0 and warmup_steps < step_cnt:
    
                # mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
                # # import pdb;pdb.set_trace()
                # dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
    
    Ming Ding's avatar
    Ming Ding committed
    
    
                # new_fixed = unfixed.clone()
                # moved_new_fixed = new_fixed[:, 2:]
                # moved_new_fixed &= dprob
                # moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
                # moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
                # # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
                # moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
                # moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
                # # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
                pass
    
    Ming Ding's avatar
    Ming Ding committed
            elif choice == 1 and warmup_steps < step_cnt:
    
    Ming Ding's avatar
    Ming Ding committed
                new_fixed = unfixed & False
    
    Ming Ding's avatar
    Ming Ding committed
                ll, rr = 4, 4
                for x in range(min(ll, step_cnt - warmup_steps)):
                    y = step_cnt - warmup_steps - x - 1
                    if y < rr:
                        print(x,y)
                        new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
    
    Ming Ding's avatar
    Ming Ding committed
                new_fixed &= unfixed
    
    Ming Ding's avatar
    Ming Ding committed
            else:
                new_fixed = unfixed & False # TODO
            new_fixed[:, -1] = True
    
    Ming Ding's avatar
    Ming Ding committed
    
    
    Ming Ding's avatar
    Ming Ding committed
            # with open(f'bed{step_cnt}.txt', 'w') as fout:
            #     for i, prob in enumerate(topraw[0, -4096:]):
            #         s = ' '.join([str(x) for x in prob.tolist()])
            #         fout.write(f'{i} {s}\n')
    
    Ming Ding's avatar
    Ming Ding committed
    
    
    Ming Ding's avatar
    Ming Ding committed
            unfixed &= new_fixed.logical_not()
            # update seq and tokens
            seq[new_fixed] = prev[new_fixed[:, 1:]]
            tokens = seq[:, :-1].clone()
            tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
    
            if step_cnt == iterative_step: 
                seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
                n_unfixed = unfixed.sum(dim=-1).tolist()
                print(f'Exit with {n_unfixed} unfixed tokens.')
                break
            if args.debug:
                from torchvision.utils import save_image
                seqt = seq.clone()
                seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
                imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
        if args.debug:
            imgs = torch.cat(imgs, dim=0)
            save_image(imgs, f'steps{device}.jpg', normalize=True)
        model.module.transformer.max_memory_length = args.max_memory_length
    
        return seq