From a2625727b4feed5a7474edc32038ddda4ef2b3bc Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Sun, 10 Oct 2021 17:27:41 +0000
Subject: [PATCH] del old sampling

---
 generation/cuda_2d_sampling.py | 172 ---------------------
 generation/sampling.py         | 263 ---------------------------------
 2 files changed, 435 deletions(-)
 delete mode 100644 generation/cuda_2d_sampling.py
 delete mode 100755 generation/sampling.py

diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
deleted file mode 100644
index 7cd0440..0000000
--- a/generation/cuda_2d_sampling.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from vqvae.vqvae_zc import Encoder
-from .sampling import *
-import math
-import sys
-from copy import deepcopy
-from torchvision.utils import save_image
-def filling_sequence_cuda_2d(
-        model,
-        seq,
-        args,
-        mems=None,
-        invalid_slices=[],
-        **kwargs):
-    '''
-    seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
-    '''
-    tokenizer = get_tokenizer()
-    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    device = seq.device
-    assert args.sparse_config.sparse_type == 'cuda_2d'
-    std_config = deepcopy(args.sparse_config)
-    std_config.sparse_type = 'standard'
-    sparse_config = args.sparse_config
-    # split two parts
-    seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
-    # generate a batch of seq0
-    model.module.transformer.reset_sparse_config(std_config)
-    args.sparse_config = std_config
-    output0 = filling_sequence(model, seq0, args)
-    model.module.transformer.reset_sparse_config(sparse_config)
-    args.sparse_config = sparse_config
-    model.module.transformer.max_memory_length = 0
-
-
-    # filter bad generation & select top N=2, TODO
-    output0 = output0
-
-    from torchvision import transforms
-    tr = transforms.Compose([
-        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
-    ])
-    imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
-    blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blurred image as init value
-
-    # pad seq to desired shape
-    n_pad = args.layout[1] - len(seq0)
-    batch_size = output0.shape[0]
-    assert n_pad > 0, "You should truncate long input before filling."
-    seq = torch.cat((
-        torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype)
-            .unsqueeze(0).expand(batch_size, n_pad),
-        output0,
-        seq1.unsqueeze(0).expand(batch_size, len(seq1))
-        ), dim=1
-    ) # [b, layout[-1]]
-
-    # init
-    step_cnt = 0
-    tokens = seq[:, :-1].clone()
-    unfixed = (seq < 0)
-    # tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
-    tokens[:, -4095:] = blur64[:, :-1]
-    attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
-    attention_mask[n_pad:, :n_pad] = 0
-    position_ids = torch.cat((
-        torch.zeros(n_pad, dtype=torch.long),
-        torch.arange(0, args.layout[1] - n_pad),
-        torch.arange(0, args.layout[2]-args.layout[1]))).to(device)
-    # iterate
-    imgs = []
-    # import pdb;pdb.set_trace()
-    while unfixed.sum() > 0:
-        print(unfixed.sum())
-        logits, *_dump = model(tokens, position_ids, attention_mask)
-        step_cnt += 1
-
-        # warmup
-        real_topk = 10
-        warmup_steps = 3
-        iterative_step = warmup_steps + 6
-        if step_cnt <= warmup_steps:
-            real_temp = 0.1
-        elif step_cnt == warmup_steps + 1:
-            real_temp = 0.55
-        elif step_cnt > warmup_steps + 1:
-            real_temp = 0.45
-        # if 5 < step_cnt:
-        #     real_topk = 200
-        # sampling
-        for invalid_slice in invalid_slices: # forbid generating other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        assert args.top_k > 0
-
-        # probs0 = F.softmax(logits/real_temp, dim=-1)
-        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
-        ent = -(topraw * topraw.log()).sum(dim=-1)
-        # topsum = topraw.sum(dim=-1)
-        if step_cnt > warmup_steps:
-            # import pdb;pdb.set_trace()
-            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
-            # import pdb;pdb.set_trace()
-        else:
-            real_temp2 = real_temp
-        # import pdb;pdb.set_trace()
-        probs = F.softmax(logits/real_temp2, dim=-1)
-        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
-        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
-        edge_idx = tk_idx[:, :, -1:]
-        edge_value = tk_value[:, :, -1:]
-        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
-        prev[edge_mask] = edge_idx[edge_mask]
-        prev.squeeze_(-1)
-        # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
-        # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
-        # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
-        # update unfixed
-        choice = 1
-        if choice == 0 and warmup_steps < step_cnt:
-            # mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            # # import pdb;pdb.set_trace()
-            # dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
-
-            # new_fixed = unfixed.clone()
-            # moved_new_fixed = new_fixed[:, 2:]
-            # moved_new_fixed &= dprob
-            # moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
-            # moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
-            # # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
-            # moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
-            # moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
-            # # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
-            pass
-        elif choice == 1 and warmup_steps < step_cnt:
-            new_fixed = unfixed & False
-            ll, rr = 4, 4
-            for x in range(min(ll, step_cnt - warmup_steps)):
-                y = step_cnt - warmup_steps - x - 1
-                if y < rr:
-                    print(x,y)
-                    new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
-            new_fixed &= unfixed
-        else:
-            new_fixed = unfixed & False # TODO
-        new_fixed[:, -1] = True
-
-        # with open(f'bed{step_cnt}.txt', 'w') as fout:
-        #     for i, prob in enumerate(topraw[0, -4096:]):
-        #         s = ' '.join([str(x) for x in prob.tolist()])
-        #         fout.write(f'{i} {s}\n')
-
-        unfixed &= new_fixed.logical_not()
-        # update seq and tokens
-        seq[new_fixed] = prev[new_fixed[:, 1:]]
-        tokens = seq[:, :-1].clone()
-        tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
-
-        if step_cnt == iterative_step:
-            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            n_unfixed = unfixed.sum(dim=-1).tolist()
-            print(f'Exit with {n_unfixed} unfixed tokens.')
-            break
-        if args.debug:
-            from torchvision.utils import save_image
-            seqt = seq.clone()
-            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
-    if args.debug:
-        imgs = torch.cat(imgs, dim=0)
-        save_image(imgs, f'steps{device}.jpg', normalize=True)
-    model.module.transformer.max_memory_length = args.max_memory_length
-
-    return seq
\ No newline at end of file
diff --git a/generation/sampling.py b/generation/sampling.py
deleted file mode 100755
index f6b543d..0000000
--- a/generation/sampling.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   sampling.py
-@Time    :   2021/01/13 19:52:12
-@Author  :   Ming Ding
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from pretrain_gpt2 import get_masks_and_position_ids
-from data_utils import get_tokenizer
-from copy import deepcopy
-
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        # s1 = (logits-logits.max()).exp().sum()
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-        # s2 = (logits-logits.max()).exp().sum()
-        # with open('lion.txt', 'a') as fout:
-        #     fout.write(f'{s1} {s2}\n')
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-def get_batch(context_tokens, device, args):
-    tokens = context_tokens
-    if len(tokens.shape) == 1:
-        tokens = tokens.unsqueeze(0).contiguous()
-    else:
-        tokens = tokens.view(tokens.shape[0], -1).contiguous()
-    tokens = tokens.to(device)
-
-    # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
-        tokens, args=args)
-    return tokens, attention_mask, position_ids
-
-def update_mems(hiddens, mems, max_memory_length=10000):
-    memory_length = mems[0].size(1) if mems else 0
-    query_length = hiddens[0].size(1)
-    new_memory_length = min(max_memory_length, memory_length + query_length)
-    new_mems = []
-    with torch.no_grad():
-        for i in range(len(hiddens)):
-            if new_memory_length <= query_length:
-                new_mems.append(hiddens[i][:, -new_memory_length:])
-            else:
-                new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
-    return new_mems
-
-def filling_sequence(
-        model,
-        seq,
-        args,
-        mems=None,
-        invalid_slices=[],
-        **kwargs):
-    '''
-        seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1]
-        context_length: first non(-1)s
-    '''
-    tokenizer = get_tokenizer()
-    device = seq.device
-    assert len(seq.shape) == 1
-    out_seq_length = len(seq)
-    # building the initial tokens, attention_mask, and position_ids
-    context_length = 0
-    offset = 100000
-
-    invalid_slices = [slice(0, tokenizer.img_tokenizer.num_tokens)]
-
-    while seq[context_length] >= 0:
-        # change what to generate
-        if seq[context_length] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
-            invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-        elif seq[context_length] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
-            invalid_slices = [
-                slice(0, tokenizer.img_tokenizer.num_tokens),
-                slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
-
-        if seq[context_length] == tokenizer['[ROI2]']:
-            offset = context_length
-        context_length += 1
-    tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
-    txt_len = seq.tolist().index(tokenizer['[BASE]'])
-    print('txt_len:', txt_len)
-    config = deepcopy(model.module.transformer.sparse_config)
-    ori_config = model.module.transformer.sparse_config
-    config.layout[0] = txt_len
-    model.module.transformer.reset_sparse_config(config)
-
-    counter = context_length - 1 # == len(tokens) - 1
-    index = 0 # len(mems)
-    if mems is None:
-        mems = []
-    score = [0] # sum log likelihood for beams
-
-    while counter < (out_seq_length - 1):
-        # Now, we want to generate seq[counter + 1]
-        # token[:, index: counter+1] are just added.
-
-        if seq[counter + 1] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
-            invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-        elif seq[counter + 1] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
-            invalid_slices = [
-                slice(0, tokenizer.img_tokenizer.num_tokens),
-                slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
-
-        if index == 0: # first
-            position_ids[position_ids > offset] -= offset
-            logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
-            mems = update_mems(qkv, mems)
-
-            # tmp = -F.log_softmax(logits, dim=-1)
-            # tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
-            # for i in range(1,len(tmp)):
-            #     print(i, tmp[i].item())
-            index = counter
-            # print(tmp[1:].mean(), file=sys.stderr)
-        elif seq[counter + 1] >= 0: # provided
-            if seq[counter + 1] == tokenizer['[ROI2]']:
-                offset = counter + 1
-                tokens, mems, score = shrink_beams(tokens, mems, 1, score)
-                nb = 1
-            counter += 1
-            tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1)
-            continue
-        else:
-            assert tokens.shape[1] == counter + 1
-            position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0)
-            position_ids[position_ids > offset] -= offset
-            # TODO each time, the fed input cannot be too long (window size), or it will have a discrepancy from sparse training, but this is not very important.
-            tokens, mems, score = shrink_beams(tokens, mems, -seq[counter + 1], score)
-            logits, *qkv = model(tokens[:, index: ],
-                position_ids,
-                0, # rebuild in transformers (sep version)
-                *mems)
-            mems = update_mems(qkv, mems)
-
-            index = counter
-        nb = -seq[counter + 1]
-        counter += 1
-        index += 1
-
-
-        logits = logits[:, -1] # [batch size, vocab size]
-
-        temp = args.temperature
-        real_topk = args.top_k
-        if counter <= context_length + 32:
-            real_topk = 80
-        # else:
-        #     real_topk = args.top_k
-        # if counter == context_length + 32 + 12:
-        #     import pdb;pdb.set_trace()
-        # TODO since the temperature is crucial, how can we find a good setting?
-        logits /= temp
-        for invalid_slice in invalid_slices: # forbid generating other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        # logits = top_k_logits(logits, top_k=real_topk, top_p=args.top_p)
-        probs = F.softmax(logits, dim=-1)
-
-        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
-
-        # expand beams
-        if nb > 1 and tokens.shape[0] == 1: # 1->nb
-            tokens = tokens.expand(nb, -1).contiguous()
-            mems = [mem.expand(nb, -1, -1) for mem in mems]
-            prev = torch.multinomial(probs, num_samples=nb, replacement=True)
-            score = torch.log(torch.gather(probs, dim=1, index=prev)[0]).tolist()
-        else: # nb -> nb
-            assert tokens.shape[0] == nb
-            prev = torch.multinomial(probs, num_samples=1)
-            for j in range(0, prev.shape[0]):
-                if probs[j, prev[j,-1]] < tk_value[j, -1]:
-                    prev[j, -1] = tk_idx[j,torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))]
-                    # prev[j, -1] = tk_idx[j,torch.randint(0, tk_idx.shape[-1], (1,))]
-
-            score_plus = torch.log(torch.gather(probs, dim=1, index=prev)[:, 0])
-            for idx in range(nb):
-                score[idx] += score_plus[idx]
-
-        tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1)
-
-    output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous()
-    model.module.transformer.reset_sparse_config(ori_config)
-    return output_tokens_list
-
-def shrink_beams(tokens, mems, nb, score):
-    # beam search is a failed attempt, will be removed soon...
-    if tokens.shape[0] == nb:
-        return tokens, mems, score
-    # shrink
-    maximum = max(score)
-    max_idx = score.index(maximum)
-    tokens = tokens[max_idx].unsqueeze(0)
-    score = [0]
-    new_mems = [mem[max_idx: max_idx + 1] for mem in mems]
-    return tokens, new_mems, score
-
-def add_interlacing_beam_marks(seq, nb=12, period=30000):
-    assert isinstance(seq, list) or len(seq.shape) == 1
-    blk_cnt = 0
-    for i in range(len(seq)):
-        if seq[i] == -1:
-            blk_cnt += 1
-            seq[i] = -nb
-            if blk_cnt == period:
-                nb += (nb % 2) * 2 - 1
-                blk_cnt = 0
-        else:
-            blk_cnt = 0
-
-
-def inverse_prompt_score(model, seq, args):
-    tokenizer = get_tokenizer()
-    device = seq.device
-    assert len(seq.shape) == 2
-
-    botext = 2 + 1024 + 1
-    assert tokenizer['[ROI1]'] == seq[0][botext]
-
-    tokens, attention_mask, position_ids = get_batch(seq, device, args)
-    logits, *qkv = model(tokens, position_ids, attention_mask)
-    mems = update_mems(qkv, mems)
-
-    logits[..., :tokenizer.img_tokenizer.num_tokens] = -float('Inf')
-    log_probs = torch.log(F.softmax(logits, dim=-1))
-
-    pred = log_probs[:, botext:-1, :]
-    target = tokens[:, botext+1:].unsqueeze(-1)
-    scores = torch.gather(pred, dim=2, index=target).squeeze(-1).sum(dim=-1)
-    return scores
-    
\ No newline at end of file
-- 
GitLab