From 6b89c2bd0c0e78e111fa23b9a7c259a5b5e977ba Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 23 Sep 2021 07:24:16 +0000
Subject: [PATCH] reduce iters

---
 generation/cuda_2d_sampling.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
index 30ac2fd..21e8dcc 100644
--- a/generation/cuda_2d_sampling.py
+++ b/generation/cuda_2d_sampling.py
@@ -10,7 +10,6 @@ def filling_sequence_cuda_2d(
         args, 
         mems=None, 
         invalid_slices=[], 
-        iterative_step=20,
         **kwargs):
     '''
         seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
@@ -76,14 +75,14 @@ def filling_sequence_cuda_2d(
         step_cnt += 1
 
         # warmup 
-        real_topk = 15
-        real_temp = 0.5 #- min(1,((step_cnt) / iterative_step)) * .3
-        # real_temp = args.temperature
-        if step_cnt <= 5:
-            real_temp = 0.4
-        elif step_cnt == 6:
+        real_topk = 10
+        warmup_steps = 3
+        iterative_step= warmup_steps + 6
+        if step_cnt <= warmup_steps:
+            real_temp = 0.1
+        elif step_cnt == warmup_steps + 1:
             real_temp = 0.55
-        elif step_cnt > 6:
+        elif step_cnt > warmup_steps + 1:
             real_temp = 0.45
         # if  5 < step_cnt:
         #     real_topk = 200
@@ -96,9 +95,9 @@ def filling_sequence_cuda_2d(
         topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
         ent = -(topraw * topraw.log()).sum(dim=-1)
         # topsum = topraw.sum(dim=-1)
-        if step_cnt > 5:
+        if step_cnt > warmup_steps:
             # import pdb;pdb.set_trace()
-            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.2).unsqueeze(-1) + 0.6
+            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
             # import pdb;pdb.set_trace()
         else:
             real_temp2 = real_temp
@@ -116,7 +115,7 @@ def filling_sequence_cuda_2d(
         # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
         # update unfixed
         choice = 1
-        if choice == 0 and 5 < step_cnt:
+        if choice == 0 and warmup_steps < step_cnt:
             mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
             # import pdb;pdb.set_trace()
             dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
@@ -130,11 +129,14 @@ def filling_sequence_cuda_2d(
             moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
             moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
             # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
-        elif choice == 1 and 5 < step_cnt:
+        elif choice == 1 and warmup_steps < step_cnt:
             new_fixed = unfixed & False
-            x = (step_cnt-5) // 4
-            y = (step_cnt-5) % 4
-            new_fixed[..., -4096:].view(batch_size, 16, 4, 16, 4)[:, :, x, :, y] = True
+            ll, rr = 4, 4
+            for x in range(min(ll, step_cnt - warmup_steps)):
+                y = step_cnt - warmup_steps - x - 1
+                if y < rr:
+                    print(x,y)
+                    new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
             new_fixed &= unfixed
         else:
             new_fixed = unfixed & False # TODO
-- 
GitLab