diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
index 30ac2fdd9694236477e1bf756ea9c3b064edcb52..21e8dcc2707ea067a6e99b72e988347f15ad8395 100644
--- a/generation/cuda_2d_sampling.py
+++ b/generation/cuda_2d_sampling.py
@@ -10,7 +10,6 @@ def filling_sequence_cuda_2d(
         args, 
         mems=None, 
         invalid_slices=[], 
-        iterative_step=20,
         **kwargs):
     '''
         seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
@@ -76,14 +75,14 @@ def filling_sequence_cuda_2d(
         step_cnt += 1
 
         # warmup 
-        real_topk = 15
-        real_temp = 0.5 #- min(1,((step_cnt) / iterative_step)) * .3
-        # real_temp = args.temperature
-        if step_cnt <= 5:
-            real_temp = 0.4
-        elif step_cnt == 6:
+        real_topk = 10
+        warmup_steps = 3
+        iterative_step= warmup_steps + 6
+        if step_cnt <= warmup_steps:
+            real_temp = 0.1
+        elif step_cnt == warmup_steps + 1:
             real_temp = 0.55
-        elif step_cnt > 6:
+        elif step_cnt > warmup_steps + 1:
             real_temp = 0.45
         # if  5 < step_cnt:
         #     real_topk = 200
@@ -96,9 +95,9 @@ def filling_sequence_cuda_2d(
         topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
         ent = -(topraw * topraw.log()).sum(dim=-1)
         # topsum = topraw.sum(dim=-1)
-        if step_cnt > 5:
+        if step_cnt > warmup_steps:
             # import pdb;pdb.set_trace()
-            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.2).unsqueeze(-1) + 0.6
+            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
             # import pdb;pdb.set_trace()
         else:
             real_temp2 = real_temp
@@ -116,7 +115,7 @@ def filling_sequence_cuda_2d(
         # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
         # update unfixed
         choice = 1
-        if choice == 0 and 5 < step_cnt:
+        if choice == 0 and warmup_steps < step_cnt:
             mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
             # import pdb;pdb.set_trace()
             dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
@@ -130,11 +129,14 @@ def filling_sequence_cuda_2d(
             moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
             moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
             # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
-        elif choice == 1 and 5 < step_cnt:
+        elif choice == 1 and warmup_steps < step_cnt:
             new_fixed = unfixed & False
-            x = (step_cnt-5) // 4
-            y = (step_cnt-5) % 4
-            new_fixed[..., -4096:].view(batch_size, 16, 4, 16, 4)[:, :, x, :, y] = True
+            ll, rr = 4, 4
+            for x in range(min(ll, step_cnt - warmup_steps)):
+                y = step_cnt - warmup_steps - x - 1
+                if y < rr:
+                    print(x,y)
+                    new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
             new_fixed &= unfixed
         else:
             new_fixed = unfixed & False # TODO