From a2625727b4feed5a7474edc32038ddda4ef2b3bc Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Sun, 10 Oct 2021 17:27:41 +0000
Subject: [PATCH] del old sampling

---
 generation/cuda_2d_sampling.py | 172 ---------------------
 generation/sampling.py         | 263 ---------------------------------
 2 files changed, 435 deletions(-)
 delete mode 100644 generation/cuda_2d_sampling.py
 delete mode 100755 generation/sampling.py

diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
deleted file mode 100644
index 7cd0440..0000000
--- a/generation/cuda_2d_sampling.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from vqvae.vqvae_zc import Encoder
-from .sampling import *
-import math
-import sys
-from copy import deepcopy
-from torchvision.utils import save_image
-def filling_sequence_cuda_2d(
-        model, 
-        seq, 
-        args, 
-        mems=None, 
-        invalid_slices=[], 
-        **kwargs):
-    '''
-        seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
-    '''
-    tokenizer = get_tokenizer()
-    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    device = seq.device
-    assert args.sparse_config.sparse_type == 'cuda_2d'
-    std_config = deepcopy(args.sparse_config)
-    std_config.sparse_type = 'standard'
-    sparse_config = args.sparse_config
-    # split two parts
-    seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
-    # generate a batch of seq0
-    model.module.transformer.reset_sparse_config(std_config)
-    args.sparse_config = std_config
-    output0 = filling_sequence(model, seq0, args)
-    model.module.transformer.reset_sparse_config(sparse_config)
-    args.sparse_config = sparse_config
-    model.module.transformer.max_memory_length = 0
-
-
-    # filter bad generation & select top N=2, TODO
-    output0 = output0
-
-    from torchvision import transforms
-    tr = transforms.Compose([
-        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), 
-    ])
-    imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
-    blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blured image as init value
-
-    # pad seq to desired shape
-    n_pad = args.layout[1] - len(seq0)
-    batch_size = output0.shape[0]
-    assert n_pad > 0, "You should truncate long input before filling."
-    seq = torch.cat((
-        torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype)
-            .unsqueeze(0).expand(batch_size, n_pad),
-        output0,
-        seq1.unsqueeze(0).expand(batch_size, len(seq1))    
-        ), dim=1
-    ) # [b, layout[-1]]
-
-    # init 
-    step_cnt = 0
-    tokens = seq[:, :-1].clone()
-    unfixed = (seq < 0)
-    # tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
-    tokens[:, -4095:] = blur64[:, :-1]
-    attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
-    attention_mask[n_pad:, :n_pad] = 0
-    position_ids = torch.cat((
-        torch.zeros(n_pad, dtype=torch.long),
-        torch.arange(0, args.layout[1] - n_pad), 
-        torch.arange(0, args.layout[2]-args.layout[1]))).to(device)
-    # iterate
-    imgs = []
-    # import pdb;pdb.set_trace()
-    while unfixed.sum() > 0:
-        print(unfixed.sum())
-        logits, *_dump = model(tokens, position_ids, attention_mask)
-        step_cnt += 1
-
-        # warmup 
-        real_topk = 10
-        warmup_steps = 3
-        iterative_step= warmup_steps + 6
-        if step_cnt <= warmup_steps:
-            real_temp = 0.1
-        elif step_cnt == warmup_steps + 1:
-            real_temp = 0.55
-        elif step_cnt > warmup_steps + 1:
-            real_temp = 0.45
-        # if  5 < step_cnt:
-        #     real_topk = 200
-        # sampling
-        for invalid_slice in invalid_slices: # forbide to generate other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        assert args.top_k > 0
-        
-        # probs0 = F.softmax(logits/real_temp, dim=-1)
-        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
-        ent = -(topraw * topraw.log()).sum(dim=-1)
-        # topsum = topraw.sum(dim=-1)
-        if step_cnt > warmup_steps:
-            # import pdb;pdb.set_trace()
-            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
-            # import pdb;pdb.set_trace()
-        else:
-            real_temp2 = real_temp
-        # import pdb;pdb.set_trace()
-        probs = F.softmax(logits/real_temp2, dim=-1)
-        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
-        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
-        edge_idx = tk_idx[:, :, -1:]
-        edge_value = tk_value[:, :, -1:]
-        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
-        prev[edge_mask] = edge_idx[edge_mask]
-        prev.squeeze_(-1)
-        # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
-        # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
-        # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
-        # update unfixed
-        choice = 1
-        if choice == 0 and warmup_steps < step_cnt:
-            # mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            # # import pdb;pdb.set_trace()
-            # dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
-
-            # new_fixed = unfixed.clone()
-            # moved_new_fixed = new_fixed[:, 2:]
-            # moved_new_fixed &= dprob
-            # moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
-            # moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
-            # # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
-            # moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
-            # moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
-            # # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
-            pass
-        elif choice == 1 and warmup_steps < step_cnt:
-            new_fixed = unfixed & False
-            ll, rr = 4, 4
-            for x in range(min(ll, step_cnt - warmup_steps)):
-                y = step_cnt - warmup_steps - x - 1
-                if y < rr:
-                    print(x,y)
-                    new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
-            new_fixed &= unfixed
-        else:
-            new_fixed = unfixed & False # TODO
-        new_fixed[:, -1] = True
-
-        # with open(f'bed{step_cnt}.txt', 'w') as fout:
-        #     for i, prob in enumerate(topraw[0, -4096:]):
-        #         s = ' '.join([str(x) for x in prob.tolist()])
-        #         fout.write(f'{i} {s}\n')
-
-        unfixed &= new_fixed.logical_not()
-        # update seq and tokens
-        seq[new_fixed] = prev[new_fixed[:, 1:]]
-        tokens = seq[:, :-1].clone()
-        tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
-
-        if step_cnt == iterative_step: 
-            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            n_unfixed = unfixed.sum(dim=-1).tolist()
-            print(f'Exit with {n_unfixed} unfixed tokens.')
-            break
-        if args.debug:
-            from torchvision.utils import save_image
-            seqt = seq.clone()
-            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
-    if args.debug:
-        imgs = torch.cat(imgs, dim=0)
-        save_image(imgs, f'steps{device}.jpg', normalize=True)
-    model.module.transformer.max_memory_length = args.max_memory_length
-
-    return seq
\ No newline at end of file
diff --git a/generation/sampling.py b/generation/sampling.py
deleted file mode 100755
index f6b543d..0000000
--- a/generation/sampling.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   sampling.py
-@Time    :   2021/01/13 19:52:12
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from pretrain_gpt2 import get_masks_and_position_ids
-from data_utils import get_tokenizer
-from copy import deepcopy
-
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        # s1 = (logits-logits.max()).exp().sum()
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value      
-        # s2 = (logits-logits.max()).exp().sum()
-        # with open('lion.txt', 'a') as fout:
-        #     fout.write(f'{s1} {s2}\n')
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-def get_batch(context_tokens, device, args):
-    tokens = context_tokens
-    if len(tokens.shape) == 1:
-        tokens = tokens.unsqueeze(0).contiguous()
-    else:
-        tokens = tokens.view(tokens.shape[0], -1).contiguous()
-    tokens = tokens.to(device)
-
-    # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
-        tokens, args=args)
-    return tokens, attention_mask, position_ids
-
-def update_mems(hiddens, mems, max_memory_length=10000):
-    memory_length = mems[0].size(1) if mems else 0
-    query_length = hiddens[0].size(1)
-    new_memory_length = min(max_memory_length, memory_length + query_length)
-    new_mems = []
-    with torch.no_grad():
-        for i in range(len(hiddens)):
-            if new_memory_length <= query_length:
-                new_mems.append(hiddens[i][:, -new_memory_length:])
-            else:
-                new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
-    return new_mems
-
-def filling_sequence(
-        model, 
-        seq, 
-        args, 
-        mems=None, 
-        invalid_slices=[], 
-        **kwargs):
-    '''
-        seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1]
-        context_length: first non(-1)s
-    '''
-    tokenizer = get_tokenizer()
-    device = seq.device
-    assert len(seq.shape) == 1
-    out_seq_length = len(seq)
-    # building the initial tokens, attention_mask, and position_ids
-    context_length = 0
-    offset = 100000
-
-    invalid_slices = [slice(0, tokenizer.img_tokenizer.num_tokens)]
-
-    while seq[context_length] >= 0:
-        # change what to generate
-        if seq[context_length] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
-            invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-        elif seq[context_length] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
-            invalid_slices = [
-                slice(0, tokenizer.img_tokenizer.num_tokens),
-                slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
-
-        if seq[context_length] == tokenizer['[ROI2]']:
-            offset = context_length
-        context_length += 1
-    tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
-    txt_len = seq.tolist().index(tokenizer['[BASE]'])
-    print('txt_len:', txt_len)
-    config = deepcopy(model.module.transformer.sparse_config)
-    ori_config = model.module.transformer.sparse_config
-    config.layout[0] = txt_len
-    model.module.transformer.reset_sparse_config(config)
-
-    counter = context_length - 1 # == len(tokens) - 1
-    index = 0 # len(mems)
-    if mems is None:
-        mems = []
-    score = [0] # sum log likelihood for beams
-    
-    while counter < (out_seq_length - 1):
-        # Now, we want to generate seq[counter + 1]
-        # token[:, index: counter+1] are just added.
-
-        if seq[counter + 1] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
-            invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-        elif seq[counter + 1] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
-            invalid_slices = [
-                slice(0, tokenizer.img_tokenizer.num_tokens),
-                slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
-
-        if index == 0: # first 
-            position_ids[position_ids > offset] -= offset
-            logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
-            mems = update_mems(qkv, mems)
-
-            # tmp = -F.log_softmax(logits, dim=-1)
-            # tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
-            # for i in range(1,len(tmp)):
-            #     print(i, tmp[i].item())
-            index = counter
-            # print(tmp[1:].mean(), file=sys.stderr)
-        elif seq[counter + 1] >= 0: # provided
-            if seq[counter + 1] == tokenizer['[ROI2]']:
-                offset = counter + 1
-            tokens, mems, score = shrink_beams(tokens, mems, 1, score)
-            nb = 1
-            counter += 1
-            tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1)
-            continue
-        else:
-            assert tokens.shape[1] == counter + 1 
-            position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0)
-            position_ids[position_ids > offset] -= offset
-            # TODO each time, the feed input cannot be too long (window size), or it will have a discrepcy from sparse training, but this is not very important. 
-            tokens, mems, score = shrink_beams(tokens, mems, -seq[counter + 1], score)
-            logits, *qkv = model(tokens[:, index: ], 
-                position_ids,
-                0, # rebuild in transformers (sep version)
-                *mems)
-            mems = update_mems(qkv, mems)
-
-            index = counter
-        nb = -seq[counter + 1]
-        counter += 1
-        index += 1
-
-
-        logits = logits[:, -1] # [batch size, vocab size]
-
-        temp = args.temperature
-        real_topk = args.top_k
-        if counter <= context_length + 32:
-            real_topk = 80
-        # else:
-            # real_topk = args.top_k
-        # if counter == context_length + 32 + 12:
-        #     import pdb;pdb.set_trace()
-        # TODO since the temperature is crucial, how can we find a good setting?
-        logits /= temp
-        for invalid_slice in invalid_slices: #   to generate other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        # logits = top_k_logits(logits, top_k=real_topk, top_p=args.top_p)
-        probs = F.softmax(logits, dim=-1)
-
-        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
-
-        # expand beams
-        if nb > 1 and tokens.shape[0] == 1: # 1->nb
-            tokens = tokens.expand(nb, -1).contiguous()
-            mems = [mem.expand(nb, -1, -1) for mem in mems]
-            prev = torch.multinomial(probs, num_samples=nb, replacement=True)
-            score = torch.log(torch.gather(probs, dim=1, index=prev)[0]).tolist()
-        else: # nb -> nb
-            assert tokens.shape[0] == nb
-            prev = torch.multinomial(probs, num_samples=1)
-            for j in range(0, prev.shape[0]):
-                if probs[j, prev[j,-1]] < tk_value[j, -1]:
-                    prev[j, -1] = tk_idx[j,torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))]
-                    # prev[j, -1] = tk_idx[j,torch.randint(0, tk_idx.shape[-1], (1,))]
-
-            score_plus = torch.log(torch.gather(probs, dim=1, index=prev)[:, 0])
-            for idx in range(nb):
-                score[idx] += score_plus[idx]
-        
-        tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1)
-
-    output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous()
-    model.module.transformer.reset_sparse_config(ori_config)
-    return output_tokens_list
-
-def shrink_beams(tokens, mems, nb, score):
-    # beam search is a failed attempt, will be removed soon...
-    if tokens.shape[0] == nb:
-        return tokens, mems, score
-    # shrink
-    maximum = max(score)
-    max_idx = score.index(maximum)
-    tokens = tokens[max_idx].unsqueeze(0)
-    score = [0]
-    new_mems = [mem[max_idx: max_idx + 1] for mem in mems]
-    return tokens, new_mems, score
-
-def add_interlacing_beam_marks(seq, nb=12, period=30000):
-    assert isinstance(seq, list) or len(seq.shape) == 1
-    blk_cnt = 0
-    for i in range(len(seq)):
-        if seq[i] == -1:
-            blk_cnt += 1
-            seq[i] = -nb
-            if blk_cnt == period:
-                nb += (nb % 2) * 2 - 1
-                blk_cnt = 0
-        else:
-            blk_cnt = 0
-
-
-def inverse_prompt_score(model, seq, args):
-    tokenizer = get_tokenizer()
-    device = seq.device
-    assert len(seq.shape) == 2
-
-    botext = 2 + 1024 + 1
-    assert tokenizer['[ROI1]'] == seq[0][botext]
-
-    tokens, attention_mask, position_ids = get_batch(seq, device, args)
-    logits, *qkv = model(tokens, position_ids, attention_mask)
-    mems = update_mems(qkv, mems)
-
-    logits[..., :tokenizer.img_tokenizer.num_tokens] = -float('Inf')
-    log_probs = torch.log(F.softmax(logits, dim=-1))
-
-    pred = log_probs[:, botext:-1, :] 
-    target = tokens[:, botext+1:].unsqueeze(-1) 
-    scores = torch.gather(pred, dim=2, index=target).squeeze(-1).sum(dim=-1)
-    return scores
-            
\ No newline at end of file
-- 
GitLab