from .sampling import *
import math
import sys
from copy import deepcopy
from torchvision.utils import save_image
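# Two-stage sampling for the CUDA-2D (super-resolution) attention mode: first
# generate the 32x32-token base image with standard sparse attention, then
# iteratively refine the 64x64-token image, committing a few positions per step.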
def filling_sequence_cuda_2d(
        model,
        seq,
        args,
        mems=None,
        invalid_slices=[],
        **kwargs):
    '''
        seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
    '''
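    # Negative entries in `seq` mark positions to be generated; everything else
    # is treated as fixed input (see `unfixed = (seq < 0)` below).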
    tokenizer = get_tokenizer()
    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
    device = seq.device
    assert args.sparse_config.sparse_type == 'cuda_2d'
    std_config = deepcopy(args.sparse_config)
    std_config.sparse_type = 'standard'
    sparse_config = args.sparse_config
    # split two parts
    seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
    # generate a batch of seq0
    model.module.transformer.reset_sparse_config(std_config)
    args.sparse_config = std_config
    output0 = filling_sequence(model, seq0, args)
    model.module.transformer.reset_sparse_config(sparse_config)
    args.sparse_config = sparse_config
    model.module.transformer.max_memory_length = 0
    # filter bad generation & select top N=2, TODO
    output0 = output0
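    # Upsample the 32x32 base result to 512px and re-encode it with the image
    # tokenizer, yielding a blurry 64x64 token map as the refinement init.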
    from torchvision import transforms
    tr = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    ])
    imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
    blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blurred image as init value
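    # `args.layout` holds the section boundaries of the full sequence (inferred
    # from its usage here): layout[1] ends the text + 32x32 part, layout[-1] is
    # the total length including the 64x64 part.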
    # pad seq to desired shape
    n_pad = args.layout[1] - len(seq0)
    batch_size = output0.shape[0]
    assert n_pad > 0, "You should truncate long input before filling."
    seq = torch.cat((
        torch.tensor([tokenizer['[PAD]']] * n_pad, device=seq.device, dtype=seq.dtype)
            .unsqueeze(0).expand(batch_size, n_pad),
        output0,
        seq1.unsqueeze(0).expand(batch_size, len(seq1))
        ), dim=1
    ) # [b, layout[-1]]
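    # `tokens` is the model input (seq without its last element); unknown image
    # positions start from the blurred initialization rather than random ids.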
    # init
    step_cnt = 0
    tokens = seq[:, :-1].clone()
    unfixed = (seq < 0)
    # tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
    tokens[:, -4095:] = blur64[:, :-1]
    attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
    attention_mask[n_pad:, :n_pad] = 0
    position_ids = torch.cat((
        torch.zeros(n_pad, dtype=torch.long),
        torch.arange(0, args.layout[1] - n_pad),
        torch.arange(0, args.layout[2] - args.layout[1]))).to(device)
    # iterate
    imgs = []
    while unfixed.sum() > 0:
        print(unfixed.sum())
        logits, *_dump = model(tokens, position_ids, attention_mask)
        step_cnt += 1
        # warmup
        real_topk = 10
        warmup_steps = 3
        iterative_step = warmup_steps + 6
        if step_cnt <= warmup_steps:
            real_temp = 0.1 # low temperature during warmup; the value persists afterwards
        # sampling
        for invalid_slice in invalid_slices: # forbid generating non-image tokens
            logits[..., invalid_slice] = -float('Inf')
        assert args.top_k > 0
        # probs0 = F.softmax(logits/real_temp, dim=-1)
        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
        ent = -(topraw * topraw.log()).sum(dim=-1) # entropy of the renormalized top-5 probs per position
        # topsum = topraw.sum(dim=-1)
        if step_cnt > warmup_steps:
            # entropy-adaptive temperature: 0.6 at confident positions,
            # 0.6 + real_temp where the top-5 entropy exceeds 1.3
            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
        else:
            real_temp2 = real_temp
        probs = F.softmax(logits / real_temp2, dim=-1)
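        # Top-k via rejection: sample from the full distribution, then snap any
        # draw less likely than the k-th-best token back to that k-th token.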
        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
        edge_idx = tk_idx[:, :, -1:]
        edge_value = tk_value[:, :, -1:]
        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
        prev[edge_mask] = edge_idx[edge_mask]
        prev.squeeze_(-1)
        # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
        # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]), 1)
        # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
        # an alternative, confidence-based update rule, kept for reference:
        # mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
        # dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:, -1].unsqueeze(-1).expand_as(mprob[:, 1:])
        # new_fixed = unfixed.clone()
        # moved_new_fixed = new_fixed[:, 2:]
        # moved_new_fixed &= dprob
        # moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
        # moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
        # # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
        # moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
        # moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
        # # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
        # update unfixed: after warmup, commit one anti-diagonal of every
        # 4x4 tile of the 64x64 token grid per step
        if step_cnt > warmup_steps:
            new_fixed = unfixed & False # all-False mask of the same shape
            ll, rr = 4, 4
            for x in range(min(ll, step_cnt - warmup_steps)):
                y = step_cnt - warmup_steps - x - 1
                if y < rr:
                    print(x, y)
                    new_fixed[..., -4096:].view(batch_size, 64 // ll, ll, 64 // rr, rr)[:, :, x, :, y] = True
        else:
            new_fixed = unfixed & False # TODO
        new_fixed[:, -1] = True
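        # The final position is always committed; any positions left over are
        # flushed by the `iterative_step` cap below.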
        # with open(f'bed{step_cnt}.txt', 'w') as fout:
        #     for i, prob in enumerate(topraw[0, -4096:]):
        #         s = ' '.join([str(x) for x in prob.tolist()])
        #         fout.write(f'{i} {s}\n')
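        # Commit newly fixed positions into `seq`; still-unfixed positions in
        # `tokens` carry this step's samples into the next forward pass.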
        unfixed &= new_fixed.logical_not()
        # update seq and tokens
        seq[new_fixed] = prev[new_fixed[:, 1:]]
        tokens = seq[:, :-1].clone()
        tokens[:, 1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
        if step_cnt == iterative_step:
            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
            n_unfixed = unfixed.sum(dim=-1).tolist()
            print(f'Exit with {n_unfixed} unfixed tokens.')
            break
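        # Optionally snapshot the current image at every step; all snapshots are
        # written out together after the loop.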
        if args.debug:
            seqt = seq.clone()
            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # fill still-unfixed positions for visualization
            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
    if args.debug:
        imgs = torch.cat(imgs, dim=0)
        save_image(imgs, f'steps{device}.jpg', normalize=True)
    model.module.transformer.max_memory_length = args.max_memory_length
    return seq
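# Hypothetical usage sketch (illustrative only; the real driver is the
# generation script, which prepares `model`, `args`, and the token layout):
#   seq = build_input_ids(text, layout=args.layout) # text ids + -1 placeholders, per the docstring
#   output = filling_sequence_cuda_2d(model, seq.cuda(), args)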