Skip to content
Snippets Groups Projects
Commit 08c82cba authored by Ming Ding's avatar Ming Ding
Browse files

change to more sr

parent 69eab316
Branches master
No related tags found
No related merge requests found
......@@ -76,26 +76,29 @@ def filling_sequence_cuda_2d(
step_cnt += 1
# warmup
real_topk = 200
# real_temp = 0.7 #- min(1,((step_cnt) / iterative_step)) * .3
real_topk = 15
real_temp = 0.5 #- min(1,((step_cnt) / iterative_step)) * .3
# real_temp = args.temperature
if step_cnt <= 5:
real_temp = 0.1
real_temp = 0.4
elif step_cnt == 6:
real_temp = 0.55
elif step_cnt > 6:
real_temp = 0.45
if 5 < step_cnt:
real_topk = 200
# if 5 < step_cnt:
# real_topk = 200
# sampling
for invalid_slice in invalid_slices: # forbide to generate other tokens
logits[..., invalid_slice] = -float('Inf')
assert args.top_k > 0
probs0 = F.softmax(logits/real_temp, dim=-1)
topsum = torch.topk(probs0, 20, dim=-1)[0].sum(dim=-1)
if step_cnt >= 6:
real_temp2 = torch.tensor([[[real_temp]]], device=probs0.device).expand(*probs0.shape[:2], 1) * (topsum < 0.95).unsqueeze(-1) + 0.6
# probs0 = F.softmax(logits/real_temp, dim=-1)
topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
ent = -(topraw * topraw.log()).sum(dim=-1)
# topsum = topraw.sum(dim=-1)
if step_cnt > 5:
# import pdb;pdb.set_trace()
real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.2).unsqueeze(-1) + 0.6
# import pdb;pdb.set_trace()
else:
real_temp2 = real_temp
......@@ -137,9 +140,10 @@ def filling_sequence_cuda_2d(
new_fixed = unfixed & False # TODO
new_fixed[:, -1] = True
with open(f'bed{step_cnt}.txt', 'w') as fout:
for i, prob in enumerate(topsum[0, -4096:]):
fout.write(f'{i} {prob}\n')
# with open(f'bed{step_cnt}.txt', 'w') as fout:
# for i, prob in enumerate(topraw[0, -4096:]):
# s = ' '.join([str(x) for x in prob.tolist()])
# fout.write(f'{i} {s}\n')
unfixed &= new_fixed.logical_not()
# update seq and tokens
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment