From 6b89c2bd0c0e78e111fa23b9a7c259a5b5e977ba Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Thu, 23 Sep 2021 07:24:16 +0000 Subject: [PATCH] reduce iters --- generation/cuda_2d_sampling.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py index 30ac2fd..21e8dcc 100644 --- a/generation/cuda_2d_sampling.py +++ b/generation/cuda_2d_sampling.py @@ -10,7 +10,6 @@ def filling_sequence_cuda_2d( args, mems=None, invalid_slices=[], - iterative_step=20, **kwargs): ''' seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ] @@ -76,14 +75,14 @@ def filling_sequence_cuda_2d( step_cnt += 1 # warmup - real_topk = 15 - real_temp = 0.5 #- min(1,((step_cnt) / iterative_step)) * .3 - # real_temp = args.temperature - if step_cnt <= 5: - real_temp = 0.4 - elif step_cnt == 6: + real_topk = 10 + warmup_steps = 3 + iterative_step= warmup_steps + 6 + if step_cnt <= warmup_steps: + real_temp = 0.1 + elif step_cnt == warmup_steps + 1: real_temp = 0.55 - elif step_cnt > 6: + elif step_cnt > warmup_steps + 1: real_temp = 0.45 # if 5 < step_cnt: # real_topk = 200 @@ -96,9 +95,9 @@ def filling_sequence_cuda_2d( topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1) ent = -(topraw * topraw.log()).sum(dim=-1) # topsum = topraw.sum(dim=-1) - if step_cnt > 5: + if step_cnt > warmup_steps: # import pdb;pdb.set_trace() - real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.2).unsqueeze(-1) + 0.6 + real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6 # import pdb;pdb.set_trace() else: real_temp2 = real_temp @@ -116,7 +115,7 @@ def filling_sequence_cuda_2d( # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1) # update unfixed choice = 1 - if choice == 0 and 5 < step_cnt: + if choice == 0 and warmup_steps < step_cnt: mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2])) # import pdb;pdb.set_trace() dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:]) @@ -130,11 +129,14 @@ def filling_sequence_cuda_2d( moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not() moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not() # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not() - elif choice == 1 and 5 < step_cnt: + elif choice == 1 and warmup_steps < step_cnt: new_fixed = unfixed & False - x = (step_cnt-5) // 4 - y = (step_cnt-5) % 4 - new_fixed[..., -4096:].view(batch_size, 16, 4, 16, 4)[:, :, x, :, y] = True + ll, rr = 4, 4 + for x in range(min(ll, step_cnt - warmup_steps)): + y = step_cnt - warmup_steps - x - 1 + if y < rr: + print(x,y) + new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True new_fixed &= unfixed else: new_fixed = unfixed & False # TODO -- GitLab