Commit a2625727 authored by Ming Ding

del old sampling

parent e3f4751f
from vqvae.vqvae_zc import Encoder
from .sampling import *
import math
import sys
from copy import deepcopy
from torchvision.utils import save_image
def filling_sequence_cuda_2d(
model,
seq,
args,
mems=None,
invalid_slices=[],
**kwargs):
'''
seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
'''
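    # Overall flow (two-stage generation), as implemented below:
    #   1. Run standard autoregressive filling on seq0 (text + 32x32 base image).
    #   2. Decode the base image, upsample it, and re-encode it as 4096 tokens to
    #      initialize the 64x64 region (`blur64` below).
    #   3. Iteratively refine: each step runs one parallel forward pass, resamples
    #      all still-unfixed positions, and fixes a subset of them until no unfixed
    #      positions remain or `iterative_step` is reached.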
tokenizer = get_tokenizer()
invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
device = seq.device
assert args.sparse_config.sparse_type == 'cuda_2d'
std_config = deepcopy(args.sparse_config)
std_config.sparse_type = 'standard'
sparse_config = args.sparse_config
# split two parts
seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
# generate a batch of seq0
model.module.transformer.reset_sparse_config(std_config)
args.sparse_config = std_config
output0 = filling_sequence(model, seq0, args)
model.module.transformer.reset_sparse_config(sparse_config)
args.sparse_config = sparse_config
model.module.transformer.max_memory_length = 0
    # TODO: filter bad generations & select top N=2; currently a no-op
    output0 = output0
from torchvision import transforms
tr = transforms.Compose([
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
])
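    # Decode the generated 32x32-token images, bilinearly upsample them to 512 px,
    # and re-encode them with the image tokenizer; the resulting 64x64 token grid
    # (a blurred version of the base image) initializes the high-resolution region.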
    imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # decoded base images, treated as the reference for the 64x64 stage
    blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blurred image as init value
# pad seq to desired shape
n_pad = args.layout[1] - len(seq0)
batch_size = output0.shape[0]
assert n_pad > 0, "You should truncate long input before filling."
seq = torch.cat((
torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype)
.unsqueeze(0).expand(batch_size, n_pad),
output0,
seq1.unsqueeze(0).expand(batch_size, len(seq1))
), dim=1
) # [b, layout[-1]]
# init
step_cnt = 0
tokens = seq[:, :-1].clone()
unfixed = (seq < 0)
# tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
tokens[:, -4095:] = blur64[:, :-1]
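    # Causal (lower-triangular) mask over the first args.layout[1] positions, with
    # the [PAD] prefix masked out so real tokens never attend to padding.
    # Position ids are 0 for the padding, count up through the base part, and then
    # restart from 0 for the 64x64 region.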
attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
attention_mask[n_pad:, :n_pad] = 0
position_ids = torch.cat((
torch.zeros(n_pad, dtype=torch.long),
torch.arange(0, args.layout[1] - n_pad),
torch.arange(0, args.layout[2]-args.layout[1]))).to(device)
# iterate
imgs = []
# import pdb;pdb.set_trace()
while unfixed.sum() > 0:
print(unfixed.sum())
logits, *_dump = model(tokens, position_ids, attention_mask)
step_cnt += 1
# warmup
real_topk = 10
warmup_steps = 3
        iterative_step = warmup_steps + 6
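        # Temperature schedule: a very low temperature during the warmup steps, then
        # a higher (and, after warmup, entropy-dependent, see below) temperature for
        # the remaining refinement passes, up to `iterative_step` in total.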
if step_cnt <= warmup_steps:
real_temp = 0.1
elif step_cnt == warmup_steps + 1:
real_temp = 0.55
elif step_cnt > warmup_steps + 1:
real_temp = 0.45
# if 5 < step_cnt:
# real_topk = 200
# sampling
        for invalid_slice in invalid_slices: # forbid generating tokens outside the allowed range
logits[..., invalid_slice] = -float('Inf')
assert args.top_k > 0
# probs0 = F.softmax(logits/real_temp, dim=-1)
topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
ent = -(topraw * topraw.log()).sum(dim=-1)
# topsum = topraw.sum(dim=-1)
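        # After warmup, use a per-position temperature: positions whose top-5
        # prediction entropy exceeds 1.3 get temperature real_temp + 0.6, while
        # confident positions get 0.6.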
if step_cnt > warmup_steps:
# import pdb;pdb.set_trace()
real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
# import pdb;pdb.set_trace()
else:
real_temp2 = real_temp
# import pdb;pdb.set_trace()
probs = F.softmax(logits/real_temp2, dim=-1)
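        # Sample every position from the full softmax, then snap any sample whose
        # probability falls below the k-th largest (real_topk) back to the k-th top
        # token, i.e. an approximate top-k truncation applied after sampling.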
tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
edge_idx = tk_idx[:, :, -1:]
edge_value = tk_value[:, :, -1:]
edge_mask = probs.gather(dim=-1, index=prev) < edge_value
prev[edge_mask] = edge_idx[edge_mask]
prev.squeeze_(-1)
# tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
# prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
# prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
# update unfixed
choice = 1
if choice == 0 and warmup_steps < step_cnt:
# mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
# # import pdb;pdb.set_trace()
# dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
# new_fixed = unfixed.clone()
# moved_new_fixed = new_fixed[:, 2:]
# moved_new_fixed &= dprob
# moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
# moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
# # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
# moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
# moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
# # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
pass
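        # choice == 1: after warmup, fix positions of the 64x64 grid whose
        # (row mod ll, col mod rr) offsets lie on this step's anti-diagonal of the
        # ll x rr sub-block pattern, one diagonal per refinement step; the grid
        # would be fully fixed after ll + rr - 1 = 7 post-warmup steps (the loop
        # may exit earlier at `iterative_step`).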
elif choice == 1 and warmup_steps < step_cnt:
new_fixed = unfixed & False
ll, rr = 4, 4
for x in range(min(ll, step_cnt - warmup_steps)):
y = step_cnt - warmup_steps - x - 1
if y < rr:
print(x,y)
new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
new_fixed &= unfixed
else:
new_fixed = unfixed & False # TODO
new_fixed[:, -1] = True
# with open(f'bed{step_cnt}.txt', 'w') as fout:
# for i, prob in enumerate(topraw[0, -4096:]):
# s = ' '.join([str(x) for x in prob.tolist()])
# fout.write(f'{i} {s}\n')
unfixed &= new_fixed.logical_not()
# update seq and tokens
seq[new_fixed] = prev[new_fixed[:, 1:]]
tokens = seq[:, :-1].clone()
tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
if step_cnt == iterative_step:
seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
n_unfixed = unfixed.sum(dim=-1).tolist()
print(f'Exit with {n_unfixed} unfixed tokens.')
break
if args.debug:
from torchvision.utils import save_image
seqt = seq.clone()
seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
if args.debug:
imgs = torch.cat(imgs, dim=0)
save_image(imgs, f'steps{device}.jpg', normalize=True)
model.module.transformer.max_memory_length = args.max_memory_length
return seq
# -*- encoding: utf-8 -*-
'''
@File : sampling.py
@Time : 2021/01/13 19:52:12
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import numpy as np
import torch
import torch.nn.functional as F
from pretrain_gpt2 import get_masks_and_position_ids
from data_utils import get_tokenizer
from copy import deepcopy
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
# s1 = (logits-logits.max()).exp().sum()
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
# s2 = (logits-logits.max()).exp().sum()
# with open('lion.txt', 'a') as fout:
# fout.write(f'{s1} {s2}\n')
if top_p > 0.0:
# convert to 1D
logits = logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
# going back to 2D
logits = logits.view(1, -1).contiguous()
return logits
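# Hypothetical usage sketch (not part of the original code): for logits of shape
# [1, vocab_size],
#     filtered = top_k_logits(logits, top_k=100, top_p=0.9)
# sets everything outside the top-100 logits, and tokens beyond 0.9 cumulative
# probability, to -Inf before softmax/sampling. Note the top_p branch assumes a
# single sequence: it flattens to 1D and reshapes back to [1, vocab_size].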
def get_batch(context_tokens, device, args):
tokens = context_tokens
if len(tokens.shape) == 1:
tokens = tokens.unsqueeze(0).contiguous()
else:
tokens = tokens.view(tokens.shape[0], -1).contiguous()
tokens = tokens.to(device)
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
tokens, args=args)
return tokens, attention_mask, position_ids
def update_mems(hiddens, mems, max_memory_length=10000):
memory_length = mems[0].size(1) if mems else 0
query_length = hiddens[0].size(1)
new_memory_length = min(max_memory_length, memory_length + query_length)
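    # For each layer keep only the last `new_memory_length` hidden states: either a
    # suffix of the freshly computed hiddens, or the old memories concatenated with
    # the new hiddens.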
new_mems = []
with torch.no_grad():
for i in range(len(hiddens)):
if new_memory_length <= query_length:
new_mems.append(hiddens[i][:, -new_memory_length:])
else:
new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
return new_mems
def filling_sequence(
model,
seq,
args,
mems=None,
invalid_slices=[],
**kwargs):
'''
seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1]
    context_length: number of leading given (non-negative) tokens
'''
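    # Decoding loop outline: first scan the given (non-negative) prefix and run one
    # forward pass over it to build the memories, then for every remaining position
    # either append the provided token directly or sample a new one; a value of -N
    # in `seq` requests N parallel beams for the following block.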
tokenizer = get_tokenizer()
device = seq.device
assert len(seq.shape) == 1
out_seq_length = len(seq)
# building the initial tokens, attention_mask, and position_ids
context_length = 0
offset = 100000
invalid_slices = [slice(0, tokenizer.img_tokenizer.num_tokens)]
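    # Scan the given prefix; [BOI*]/[EOI*] markers switch which vocabulary range may
    # be generated next (image tokens inside an image span, text tokens outside),
    # and [ROI2] sets the position-id offset used below.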
while seq[context_length] >= 0:
# change what to generate
if seq[context_length] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
elif seq[context_length] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
invalid_slices = [
slice(0, tokenizer.img_tokenizer.num_tokens),
slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
if seq[context_length] == tokenizer['[ROI2]']:
offset = context_length
context_length += 1
tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
txt_len = seq.tolist().index(tokenizer['[BASE]'])
print('txt_len:', txt_len)
config = deepcopy(model.module.transformer.sparse_config)
ori_config = model.module.transformer.sparse_config
config.layout[0] = txt_len
model.module.transformer.reset_sparse_config(config)
counter = context_length - 1 # == len(tokens) - 1
index = 0 # len(mems)
if mems is None:
mems = []
score = [0] # sum log likelihood for beams
while counter < (out_seq_length - 1):
# Now, we want to generate seq[counter + 1]
# token[:, index: counter+1] are just added.
if seq[counter + 1] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
elif seq[counter + 1] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
invalid_slices = [
slice(0, tokenizer.img_tokenizer.num_tokens),
slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
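        # Three cases per step: (1) the very first iteration runs the model over the
        # whole given context and caches the memories; (2) if the next token is
        # provided (>= 0) it is appended directly, with no forward pass; (3) otherwise
        # only the newly appended tokens are fed, reusing the cached memories.
        # Cases (1) and (3) then sample the next token below.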
if index == 0: # first
position_ids[position_ids > offset] -= offset
logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
mems = update_mems(qkv, mems)
# tmp = -F.log_softmax(logits, dim=-1)
# tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
# for i in range(1,len(tmp)):
# print(i, tmp[i].item())
index = counter
# print(tmp[1:].mean(), file=sys.stderr)
elif seq[counter + 1] >= 0: # provided
if seq[counter + 1] == tokenizer['[ROI2]']:
offset = counter + 1
tokens, mems, score = shrink_beams(tokens, mems, 1, score)
nb = 1
counter += 1
tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1)
continue
else:
assert tokens.shape[1] == counter + 1
position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0)
position_ids[position_ids > offset] -= offset
            # TODO each time, the input fed to the model cannot be too long (window size), or it will have a discrepancy from sparse training, but this is not very important.
tokens, mems, score = shrink_beams(tokens, mems, -seq[counter + 1], score)
logits, *qkv = model(tokens[:, index: ],
position_ids,
0, # rebuild in transformers (sep version)
*mems)
mems = update_mems(qkv, mems)
index = counter
nb = -seq[counter + 1]
counter += 1
index += 1
logits = logits[:, -1] # [batch size, vocab size]
temp = args.temperature
real_topk = args.top_k
if counter <= context_length + 32:
real_topk = 80
# else:
# real_topk = args.top_k
# if counter == context_length + 32 + 12:
# import pdb;pdb.set_trace()
# TODO since the temperature is crucial, how can we find a good setting?
logits /= temp
        for invalid_slice in invalid_slices: # forbid generating tokens outside the allowed range
logits[..., invalid_slice] = -float('Inf')
# logits = top_k_logits(logits, top_k=real_topk, top_p=args.top_p)
probs = F.softmax(logits, dim=-1)
tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
# expand beams
if nb > 1 and tokens.shape[0] == 1: # 1->nb
tokens = tokens.expand(nb, -1).contiguous()
mems = [mem.expand(nb, -1, -1) for mem in mems]
prev = torch.multinomial(probs, num_samples=nb, replacement=True)
score = torch.log(torch.gather(probs, dim=1, index=prev)[0]).tolist()
else: # nb -> nb
assert tokens.shape[0] == nb
prev = torch.multinomial(probs, num_samples=1)
for j in range(0, prev.shape[0]):
if probs[j, prev[j,-1]] < tk_value[j, -1]:
prev[j, -1] = tk_idx[j,torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))]
# prev[j, -1] = tk_idx[j,torch.randint(0, tk_idx.shape[-1], (1,))]
score_plus = torch.log(torch.gather(probs, dim=1, index=prev)[:, 0])
for idx in range(nb):
score[idx] += score_plus[idx]
tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1)
output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous()
model.module.transformer.reset_sparse_config(ori_config)
return output_tokens_list
def shrink_beams(tokens, mems, nb, score):
# beam search is a failed attempt, will be removed soon...
if tokens.shape[0] == nb:
return tokens, mems, score
# shrink
maximum = max(score)
max_idx = score.index(maximum)
tokens = tokens[max_idx].unsqueeze(0)
score = [0]
new_mems = [mem[max_idx: max_idx + 1] for mem in mems]
return tokens, new_mems, score
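# Marks each run of -1 placeholders in `seq` with -nb so that filling_sequence
# expands nb beams for that block; nb alternates between nb and nb - 1 every
# `period` consecutive placeholders to interlace the beams.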
def add_interlacing_beam_marks(seq, nb=12, period=30000):
assert isinstance(seq, list) or len(seq.shape) == 1
blk_cnt = 0
for i in range(len(seq)):
if seq[i] == -1:
blk_cnt += 1
seq[i] = -nb
if blk_cnt == period:
nb += (nb % 2) * 2 - 1
blk_cnt = 0
else:
blk_cnt = 0
def inverse_prompt_score(model, seq, args):
tokenizer = get_tokenizer()
device = seq.device
assert len(seq.shape) == 2
botext = 2 + 1024 + 1
assert tokenizer['[ROI1]'] == seq[0][botext]
tokens, attention_mask, position_ids = get_batch(seq, device, args)
logits, *qkv = model(tokens, position_ids, attention_mask)
    mems = update_mems(qkv, [])  # no pre-existing memories for this single forward pass
logits[..., :tokenizer.img_tokenizer.num_tokens] = -float('Inf')
log_probs = torch.log(F.softmax(logits, dim=-1))
pred = log_probs[:, botext:-1, :]
target = tokens[:, botext+1:].unsqueeze(-1)
scores = torch.gather(pred, dim=2, index=target).squeeze(-1).sum(dim=-1)
return scores