From 08c82cbaca65f09752bbc70c727fc63d39dd250d Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Wed, 22 Sep 2021 17:13:11 +0000
Subject: [PATCH] change to more sr

---
 generation/cuda_2d_sampling.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
index 69d3d7c..30ac2fd 100644
--- a/generation/cuda_2d_sampling.py
+++ b/generation/cuda_2d_sampling.py
@@ -76,26 +76,29 @@ def filling_sequence_cuda_2d(
         step_cnt += 1
 
         # warmup
-        real_topk = 200
-        # real_temp = 0.7 #- min(1,((step_cnt) / iterative_step)) * .3
+        real_topk = 15
+        real_temp = 0.5 #- min(1,((step_cnt) / iterative_step)) * .3
         # real_temp = args.temperature
         if step_cnt <= 5:
-            real_temp = 0.1
+            real_temp = 0.4
         elif step_cnt == 6:
             real_temp = 0.55
         elif step_cnt > 6:
             real_temp = 0.45
-        if 5 < step_cnt:
-            real_topk = 200
+        # if 5 < step_cnt:
+        #     real_topk = 200
 
         # sampling
         for invalid_slice in invalid_slices: # forbide to generate other tokens
             logits[..., invalid_slice] = -float('Inf')
         assert args.top_k > 0
-        probs0 = F.softmax(logits/real_temp, dim=-1)
-        topsum = torch.topk(probs0, 20, dim=-1)[0].sum(dim=-1)
-        if step_cnt >= 6:
-            real_temp2 = torch.tensor([[[real_temp]]], device=probs0.device).expand(*probs0.shape[:2], 1) * (topsum < 0.95).unsqueeze(-1) + 0.6
+        # probs0 = F.softmax(logits/real_temp, dim=-1)
+        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
+        ent = -(topraw * topraw.log()).sum(dim=-1)
+        # topsum = topraw.sum(dim=-1)
+        if step_cnt > 5:
+            # import pdb;pdb.set_trace()
+            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.2).unsqueeze(-1) + 0.6
             # import pdb;pdb.set_trace()
         else:
             real_temp2 = real_temp
@@ -137,9 +140,10 @@ def filling_sequence_cuda_2d(
 
             new_fixed = unfixed & False # TODO
             new_fixed[:, -1] = True
-            with open(f'bed{step_cnt}.txt', 'w') as fout:
-                for i, prob in enumerate(topsum[0, -4096:]):
-                    fout.write(f'{i} {prob}\n')
+            # with open(f'bed{step_cnt}.txt', 'w') as fout:
+            #     for i, prob in enumerate(topraw[0, -4096:]):
+            #         s = ' '.join([str(x) for x in prob.tolist()])
+            #         fout.write(f'{i} {s}\n')
             unfixed &= new_fixed.logical_not()
 
             # update seq and tokens
-- 
GitLab
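
Note on the change above: the patch replaces the top-20 probability-mass gate (topsum < 0.95) with an entropy gate over the renormalized top-5 logits, so positions where the model is still uncertain among its best candidates keep a higher sampling temperature, while confident positions are sharpened toward 0.6. The snippet below is a minimal, self-contained sketch of that entropy-gated temperature, assuming a [batch, seq, vocab] logits tensor; the function name adaptive_temperature is hypothetical, and the default arguments only mirror the constants that appear in the patch.

import torch
import torch.nn.functional as F

def adaptive_temperature(logits, base_temp=0.5, ent_threshold=1.2, floor=0.6, k=5):
    # Renormalize the top-k logits and measure their entropy: a flat top-k
    # distribution (high entropy) means the model is unsure among its best
    # candidates, so those positions get base_temp + floor; peaked positions
    # (low entropy) fall back to the lower `floor` temperature.
    topraw = torch.topk(logits, k, dim=-1)[0].softmax(dim=-1)   # [batch, seq, k]
    ent = -(topraw * topraw.log()).sum(dim=-1)                  # [batch, seq]
    return base_temp * (ent > ent_threshold).float().unsqueeze(-1) + floor  # [batch, seq, 1]

# Hypothetical usage with random logits standing in for model output:
logits = torch.randn(1, 4096, 16384)
temp = adaptive_temperature(logits)              # per-position temperature
probs = F.softmax(logits / temp, dim=-1)
tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1).view(1, 4096)

Gating on the entropy of only the top-k candidates keeps the check cheap and insensitive to the long tail of the vocabulary, which is presumably why the patch computes it from the raw top-5 logits rather than from the full softmax it comments out.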