diff --git a/arguments.py b/arguments.py
index a9e722aa247c317239666c5f82906be55fbfb433..7ea04322c89471000927a29dc854f46b19ead1fd 100755
--- a/arguments.py
+++ b/arguments.py
@@ -311,12 +311,12 @@ def add_sparse_args(parser):
 def make_sparse_config(args):
     args.layout = [int(x) for x in args.layout.split(',')]
     sparse_config = argparse.Namespace(sparse_type=args.sparse_type)
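+    # layout is used by every sparse type, so set it before branching on sparse_type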
+    sparse_config.layout = args.layout
     if args.sparse_type == 'standard':
         pass
     if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation':
         sparse_config.kernel_size = args.kernel_size
         sparse_config.kernel_size2 = args.kernel_size2
-        sparse_config.layout = args.layout
     elif args.sparse_type == 'torch_1d':
         raise NotImplementedError
     args.sparse_config = sparse_config
diff --git a/create_gt.py b/create_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cd3efdb6677a4ce919100d253e73edbac9118a
--- /dev/null
+++ b/create_gt.py
@@ -0,0 +1,16 @@
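+# Round-trip an image through the VQVAE tokenizer (encode to ids, decode back) to inspect reconstruction quality.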
+# %%
+p = 'people.jpeg'
+from data_utils.vqvae_tokenizer import VQVAETokenizer
+model = VQVAETokenizer(
+    'pretrained/vqvae/vqvae_hard_biggerset_011.pt'
+)
+img = model.read_img(p, img_size=512)
+# %%
+test_dir = 'tmp'
+import os
+import torch
+from torchvision.utils import save_image
+img = model.EncodeAsIds(img)
+imgs = model.DecodeIds(torch.tensor(img))
+save_image(imgs, os.path.join(test_dir, 'show512_people.jpg'), normalize=True)
+# %%
diff --git a/data_utils/datasets.py b/data_utils/datasets.py
index 7b1203839f85f14874d96647c4f7c0fd7be7d042..d12d8f7ab827230b66e2cf4be06fe46911982050 100755
--- a/data_utils/datasets.py
+++ b/data_utils/datasets.py
@@ -117,23 +117,21 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
                 }
 
     elif dataset_type == 'CompactBinaryDataset':
-        layout = args.layout
+        layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME
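+        # cumulative offsets: 64 text tokens, then 16x16 (256), 32x32 (1024) and 64x64 (4096) image codes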
         DS_CLASS = BinaryDataset
         kwargs_to_dataset['length_per_sample'] = layout[-1]
         def process_fn(row):
             row = row.astype(np.int64)
-            # THIS IS Reverse order, TODO 
-            lens = list(reversed([layout[i] - layout[i-1] for i in range(1, len(layout))]))
-            codes = [row[layout[0]: layout[0]+lens[0]]]
-            if len(lens) > 1:
-                codes.append(row[layout[0]+lens[0]: layout[0]+lens[0]+lens[1]])
+
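+            # codes[0], codes[1], codes[2] hold the 16x16, 32x32 and 64x64 token blocks, in that order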
+            codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
+
             text = row[:layout[0]]
             text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
             n_pad = layout[0]-3-len(text)
             parts = [
                 np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
-                TextCodeTemplate(text, codes[-1]),
-                *reversed(codes[:-1])
+                TextCodeTemplate(text, codes[1]), # FIXME
+                *codes[2:] # FIXME
             ]
             ret = np.concatenate(parts, axis=0)
             return {'text': ret, 
diff --git a/draw_diff.py b/draw_diff.py
index b4baab18c797077231138786e377f66adee62f62..6c2d74b5cf57c469752a5fb3c5170574115e6c94 100644
--- a/draw_diff.py
+++ b/draw_diff.py
@@ -7,6 +7,15 @@ def loadbao(name):
             a, b = line.split()
             ret.append(abs(float(b)))
     return ret
+
+def loadlion(name):
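+    # parse two-column files, e.g. the lion.txt written by the commented-out probe in generation/sampling.py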
+    ret1, ret2 = [], []
+    with open(name, 'r') as fin:
+        for line in fin:
+            a, b = line.split()
+            ret1.append(abs(float(a)))
+            ret2.append(abs(float(b)))
+    return ret1, ret2
 import torchvision
 import torchvision.transforms as transforms
 
@@ -20,13 +29,13 @@ def sq(img, x, y, lx, ly):
 transform = transforms.Compose([
                 transforms.Resize(512),
                 transforms.CenterCrop(512),
-            ])
-img = torchvision.io.read_image('bao.jpeg')
+            ])
+img = torchvision.io.read_image('cat2.jpeg')
 img = transform(img) / 255.
-a = np.array(loadbao('bao2.txt'))
-b = np.array(loadbao('bao3.txt'))
-for t in (b-a>1).nonzero()[0]:
-    x,y = t // 32, t % 32
-    sq(img, x*16, y*16, 15, 15)
-print(a.mean(), b.mean())
-torchvision.utils.save_image(img, 'example_bao.jpg')
+# a,b = np.array(loadlion('bed6.txt'))
+b = np.array(loadbao('bed6.txt'))
+for t in (b<0.999).nonzero()[0]:
+    x,y = t // 64, t % 64
+    sq(img, x*8, y*8, 7, 7)
+print(b.sum())
+torchvision.utils.save_image(img, 'example_cat6.jpg')
diff --git a/generate_samples.py b/generate_samples.py
index 4ae8fa02da9c751920e649b057705aab2a5587c4..f0c5ba54e71a87d8443e08c342714345f1bb000d 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -325,13 +325,13 @@ def main():
         torch.cuda.set_device(device)
 
     # Random seeds for reproducibility.
-    set_random_seed(args.seed)
 
     # get the tokenizer
     tokenizer = prepare_tokenizer(args)
 
     # Model, optimizer, and learning rate.
     model = setup_model(args)
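+    # seed after model setup so that sampling does not depend on how many RNG draws initialization consumed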
+    set_random_seed(args.seed)
 
     generate_images_continually(model, args)
 
diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
index d5fe7f83935a99dcdbbbe5da57fe506c5ee700ab..69d3d7cb10d6a902ad2c2711d2e20abd0a61e5cc 100644
--- a/generation/cuda_2d_sampling.py
+++ b/generation/cuda_2d_sampling.py
@@ -1,3 +1,4 @@
+from vqvae.vqvae_zc import Encoder
 from .sampling import *
 import math
 import sys
@@ -37,7 +38,7 @@ def filling_sequence_cuda_2d(
 
     from torchvision import transforms
     tr = transforms.Compose([
-        transforms.Resize(512), 
+        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), 
     ])
     imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
     blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blured image as init value
@@ -73,29 +74,73 @@ def filling_sequence_cuda_2d(
         print(unfixed.sum())
         logits, *_dump = model(tokens, position_ids, attention_mask)
         step_cnt += 1
-        last_logits = logits
 
         # warmup 
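+        # hand-tuned temperature schedule: near-greedy (0.1) for the first 5 steps, a short spike to 0.55 at step 6, then 0.45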
-        real_topk = 5
-        real_temp = 2 - min(1,((step_cnt) / iterative_step)) * 1.9
+        real_topk = 200
+        if step_cnt <= 5:
+            real_temp = 0.1
+        elif step_cnt == 6:
+            real_temp = 0.55
+        else:
+            real_temp = 0.45
         # sampling
         for invalid_slice in invalid_slices: # forbide to generate other tokens
             logits[..., invalid_slice] = -float('Inf')
         assert args.top_k > 0
-        tk_value, tk_idx = torch.topk(logits, real_topk, dim=-1)
-        tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
-        prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
-        prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
+
+        probs0 = F.softmax(logits/real_temp, dim=-1)
+        topsum = torch.topk(probs0, 20, dim=-1)[0].sum(dim=-1)
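+        # adaptive temperature: positions whose top-20 probability mass is below 0.95 keep real_temp + 0.6; confident positions sample at 0.6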
+        if step_cnt >= 6:
+            real_temp2 = torch.tensor([[[real_temp]]], device=probs0.device).expand(*probs0.shape[:2], 1) * (topsum < 0.95).unsqueeze(-1) + 0.6
+        else:
+            real_temp2 = real_temp
+        probs = F.softmax(logits/real_temp2, dim=-1)
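+        # sample from the full distribution, then snap any draw that falls outside the top-k set to the k-th ranked token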
+        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
+        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
+        edge_idx = tk_idx[:, :, -1:]
+        edge_value = tk_value[:, :, -1:]
+        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
+        prev[edge_mask] = edge_idx[edge_mask]
+        prev.squeeze_(-1)
         # update unfixed
         choice = 1
-        if choice == 0 and step_cnt > 5:
-            mprob = tk_probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            dprob = (mprob[:, 1:] < 0.5) & ((mprob[:, :-1] > 0.8)| (unfixed[:, 1:-1].logical_not()))
+        if choice == 0 and step_cnt > 5:
+            mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
+            dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
+
             new_fixed = unfixed.clone()
-            new_fixed[:, 2:] &= dprob
+            moved_new_fixed = new_fixed[:, 2:]
+            moved_new_fixed &= dprob
+            moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
+            moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
+            # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
+            moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
+            moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
+            # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
+        elif choice == 1 and step_cnt > 5:
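+            # fix one sub-lattice of the 64x64 token map per step: positions with row % 4 == x and col % 4 == y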
+            new_fixed = unfixed & False
+            x = (step_cnt-5) // 4
+            y = (step_cnt-5) % 4
+            new_fixed[..., -4096:].view(batch_size, 16, 4, 16, 4)[:, :, x, :, y] = True
+            new_fixed &= unfixed
         else:
             new_fixed = unfixed & False # TODO
         new_fixed[:, -1] = True
+
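+        # debug dump of per-position top-20 probability mass; draw_diff.py (loadbao) visualizes these bed*.txt files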
+        with open(f'bed{step_cnt}.txt', 'w') as fout:
+            for i, prob in enumerate(topsum[0, -4096:]):
+                fout.write(f'{i} {prob}\n')
+
         unfixed &= new_fixed.logical_not()
         # update seq and tokens
         seq[new_fixed] = prev[new_fixed[:, 1:]]
@@ -115,7 +160,6 @@ def filling_sequence_cuda_2d(
     if args.debug:
         imgs = torch.cat(imgs, dim=0)
         save_image(imgs, f'steps{device}.jpg', normalize=True)
-    
     model.module.transformer.max_memory_length = args.max_memory_length
 
     return seq
\ No newline at end of file
diff --git a/generation/sampling.py b/generation/sampling.py
index 8d614724a10a3f1ea1733d4c3fe23a602b70ed11..b049ba8a7f3b6e5d33813a90e8db1d81e14ffd6b 100755
--- a/generation/sampling.py
+++ b/generation/sampling.py
@@ -19,6 +19,7 @@ import torch.nn.functional as F
 
 from pretrain_gpt2 import get_masks_and_position_ids
 from data_utils import get_tokenizer
+from copy import deepcopy
 
 def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     # This function has been mostly taken from huggingface conversational ai code at
@@ -26,8 +27,12 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
 
     if top_k > 0:
         # Remove all tokens with a probability less than the last token of the top-k
+        # s1 = (logits-logits.max()).exp().sum()
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value
+        # s2 = (logits-logits.max()).exp().sum()
+        # with open('lion.txt', 'a') as fout:
+        #     fout.write(f'{s1} {s2}\n')
 
     if top_p > 0.0:
         # convert to 1D
@@ -107,6 +112,12 @@ def filling_sequence(
             offset = context_length
         context_length += 1
     tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
+    txt_len = seq.tolist().index(tokenizer['[BASE]'])
+    print('txt_len:', txt_len)
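+    # temporarily shrink layout[0] to the actual prompt length so sparse attention treats only real tokens
+    # as the text segment; the original config is restored via reset_sparse_config at the end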
+    config = deepcopy(model.module.transformer.sparse_config)
+    ori_config = model.module.transformer.sparse_config
+    config.layout[0] = txt_len
+    model.module.transformer.reset_sparse_config(config)
 
     counter = context_length - 1 # == len(tokens) - 1
     index = 0 # len(mems)
@@ -130,12 +141,12 @@ def filling_sequence(
             logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
             mems = update_mems(qkv, mems)
 
-            tmp = -F.log_softmax(logits, dim=-1)
-            tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
+            # tmp = -F.log_softmax(logits, dim=-1)
+            # tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
             # for i in range(1,len(tmp)):
             #     print(i, tmp[i].item())
             index = counter
-            print(tmp[1:].mean(), file=sys.stderr)
+            # print(tmp[1:].mean(), file=sys.stderr)
         elif seq[counter + 1] >= 0: # provided
             if seq[counter + 1] == tokenizer['[ROI2]']:
                 offset = counter + 1
@@ -165,29 +176,44 @@ def filling_sequence(
         logits = logits[:, -1] # [batch size, vocab size]
 
         temp = args.temperature
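+        # use a tighter top-k (80) for the first 32 generated tokens; later tokens use args.top_k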
+        real_topk = args.top_k
+        if counter <= context_length + 32:
+            real_topk = 80
         # TODO since the temperature is crucial, how can we find a good setting?
         logits /= temp
-        for invalid_slice in invalid_slices: # forbide to generate other tokens
+        for invalid_slice in invalid_slices: # forbid generating other tokens
             logits[..., invalid_slice] = -float('Inf')
-        logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
-        log_probs = F.softmax(logits, dim=-1)
+        # logits = top_k_logits(logits, top_k=real_topk, top_p=args.top_p)
+        probs = F.softmax(logits, dim=-1)
+
+        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
 
         # expand beams
         if nb > 1 and tokens.shape[0] == 1: # 1->nb
             tokens = tokens.expand(nb, -1).contiguous()
             mems = [mem.expand(nb, -1, -1) for mem in mems]
-            prev = torch.multinomial(log_probs, num_samples=nb, replacement=True)
-            score = torch.log(torch.gather(log_probs, dim=1, index=prev)[0]).tolist()
+            prev = torch.multinomial(probs, num_samples=nb, replacement=True)
+            score = torch.log(torch.gather(probs, dim=1, index=prev)[0]).tolist()
         else: # nb -> nb
             assert tokens.shape[0] == nb
-            prev = torch.multinomial(log_probs, num_samples=1)
-            score_plus = torch.log(torch.gather(log_probs, dim=1, index=prev)[:, 0])
+            prev = torch.multinomial(probs, num_samples=1)
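+            # if a sample falls outside the top-k set, replace it with a draw from the tail (last ~100 ranks) of the top-k list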
+            for j in range(0, prev.shape[0]):
+                if probs[j, prev[j,-1]] < tk_value[j, -1]:
+                    prev[j, -1] = tk_idx[j,torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))]
+
+            score_plus = torch.log(torch.gather(probs, dim=1, index=prev)[:, 0])
             for idx in range(nb):
                 score[idx] += score_plus[idx]
         
         tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1)
 
     output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous()
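+    # restore the sparse config overridden at the start of this function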
+    model.module.transformer.reset_sparse_config(ori_config)
     return output_tokens_list
 
 def shrink_beams(tokens, mems, nb, score):
diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py
index 53a92de78d6b0409e520dbb139774d18d37af6f7..37f3f9dbaeacff4ddcf1c500c75613bee878add6 100755
--- a/mpu/sparse_transformer.py
+++ b/mpu/sparse_transformer.py
@@ -75,12 +75,13 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
     """
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, output_layer_init_method=None,sparse_config=None,
+                 init_method, layer_id, output_layer_init_method=None,sparse_config=None,
                  finetune=False):
         super(GPT2ParallelSelfAttention, self).__init__()
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
             output_layer_init_method = init_method
+        self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
         self.hidden_size_per_partition = divide(hidden_size, world_size)
@@ -177,7 +178,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
             
-            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
+            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, txt_len=layout[0] if not self.training else -1)
             
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + \
@@ -193,7 +194,9 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
                 text_len=sparse_config.layout[0],
                 kernel_size=sparse_config.kernel_size,
                 kernel_size2=sparse_config.kernel_size2,
-                attention_dropout=dropout_fn
+                attention_dropout=dropout_fn,
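+                # text_start: first non-padding text position, inferred from the last row of the attention mask (inference only)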
+                text_start=(1-mask[...,-1,:]).sum().long().item()+1 if not self.training else -1,
+                layer_id=self.layer_id
             )
 
         if sparse_config.sparse_type == 'cuda_2d':
@@ -308,6 +311,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
                  output_dropout_prob,
                  layernorm_epsilon,
                  init_method,
+                 layer_id,
                  output_layer_init_method=None,
                  sandwich_ln=True,
                  sparse_config=argparse.Namespace(sparse_type='standard'),
@@ -317,6 +321,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
             output_layer_init_method = init_method
+        self.layer_id = layer_id
 
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
@@ -328,6 +333,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
             attention_dropout_prob,
             output_dropout_prob,
             init_method,
+            layer_id,
             output_layer_init_method=output_layer_init_method,
             sparse_config=sparse_config,
             finetune=finetune
@@ -488,6 +494,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 output_dropout_prob,
                 layernorm_epsilon,
                 unscaled_init_method(init_method_std),
+                layer_id,
                 output_layer_init_method=output_layer_init_method,
                 sandwich_ln=sandwich_ln,
                 sparse_config=sparse_config,
@@ -624,7 +631,7 @@ def _chunk(x, w, times):
 
     return x.as_strided(size=chunk_size, stride=chunk_stride)
 
-def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None):
+def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None, layer_id=-1, txt_len=-1):
     # We disable PB-relax attention and only change the order of computation, since that is enough for most training.
     # The implementation in the paper is easy to add if you really need to train very deep transformers.
 
@@ -638,9 +645,17 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte
         attention_scores = torch.mul(attention_scores, attention_mask) - \
                     10000.0 * (1.0 - attention_mask)
     
-    # Attention probabilities. [b, np, s, s]
+    # Attention probabilities [b, np, s, s]
     attention_probs = F.softmax(attention_scores, dim=-1)
-    
+
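+    # inference-time text-attention upweighting: boost attention to the text tokens and renormalize,
+    # applied only while the image part is still at most ~32x32 tokens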
+    if txt_len > 0:
+        t = key_layer.shape[-2] - txt_len - 1
+        if t // 32 <= 32:
+            attention_probs[..., :, 1:txt_len] *= 6 if txt_len <= 10 else 4
+            attention_probs /= attention_probs.sum(dim=-1, keepdim=True)
+
     if attention_dropout is not None:
         with get_cuda_rng_tracker().fork():
             attention_probs = attention_dropout(attention_probs)
@@ -651,7 +666,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte
     return context_layer
 
 
-def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
+def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, text_start = -1, layer_id=-1, **kwargs):
     '''
     q0, k0, v0: [batch_size, 1088, hidden_size]
     q1, k1, v1: [batch_size, 4096, h2]
@@ -673,6 +688,9 @@ def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, te
     attention_scores = torch.mul(attention_scores, attention_mask) - \
                     10000.0 * (1.0 - attention_mask)
     attention_probs0 = F.softmax(attention_scores, dim=-1)
+    if text_start > 0:
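+        # the scale factor is currently 1 (a no-op), seemingly left as a hook for text-token upweighting at this level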
+        attention_probs0[..., :, text_start:text_len-2] *= 1
+        attention_probs0 /= attention_probs0.sum(dim=-1, keepdim=True)
     # local attention for level 1
     q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
     k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
@@ -701,10 +719,11 @@ def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, te
     # weighting for level 1
     probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
     context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
-    context1_to_1 = context1_to_1.view(b, n_head * h, l1**2)
+    context1 = context1_to_1.view(b, n_head * h, l1**2)
     # weighting for cross attention
     probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
     v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
     context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
     context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
-    return context0.transpose(1, 2).reshape(b, s0, h0), (context1_to_0 + context1_to_1).transpose(-1, -2)
\ No newline at end of file
+    context1 = context1 + context1_to_0
+    return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
\ No newline at end of file
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index 87844669294b4455e8a1687d405ac83e977f9ca2..8fb6d3b355c095d1616765b6a60f6969f9620b76 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -179,7 +179,6 @@ def get_learning_rate_scheduler(optimizer, args):
 
     return lr_scheduler
 
-
 def setup_model_and_optimizer(args):
     """Setup model and optimizer."""
 
diff --git a/random_display.py b/random_display.py
index 32096ddbf43b589665943d94bab91b701bec4501..7c9d56f0f0a30c323c1527d8ed1591e5723459a4 100644
--- a/random_display.py
+++ b/random_display.py
@@ -5,21 +5,24 @@ import os
 import torch
 import random
 test_dir = 'tmp'
-# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin'
-bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing003/quanjing003.bin.part_0.cogdata'
-bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False)
+# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin'
+bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata'
+bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=16**2+64*64+32*32+64, dtype='int32', preload=False)
 args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None)
 tokenizer = get_tokenizer(args)
 
 bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(32)]
 for x in bin_ds:
-    end = x.tolist().index(-1)
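+    # a text field that fills all 64 slots has no -1 terminator, so default to end = 64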
+    if x[63] != -1:
+        end = 64
+    else:
+        end = x.tolist().index(-1)
     print(tokenizer.DecodeIds(x[:end])[0])
 
 from torchvision.utils import save_image
-imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+64**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0)
-save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True)
-imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2:64+64**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
+imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+16**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0)
+save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True)
+imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+16**2:64+16**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
 save_image(imgs, os.path.join(test_dir, 'testcase256.jpg'), normalize=True)
-# imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
-# save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True)
\ No newline at end of file
+imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+16**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
+save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True)
\ No newline at end of file
diff --git a/scripts/cuda_2d_text2image.sh b/scripts/cuda_2d_text2image.sh
index 3f03974a9f479821f415107e4cefb84efae97ed2..6a32e31627fcdc70ab9b62699312800ea6acd0e2 100755
--- a/scripts/cuda_2d_text2image.sh
+++ b/scripts/cuda_2d_text2image.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-CHECKPOINT_PATH=data/checkpoints/cogview-long
+CHECKPOINT_PATH=data/checkpoints/cogview-base
 # CHECKPOINT_PATH=data/checkpoints/cogview-compare
 NLAYERS=48
 NHIDDEN=2560
@@ -10,9 +10,9 @@ MASTER_PORT=$(shuf -n 1 -i 10000-65535)
 MPSIZE=1
 
 #SAMPLING ARGS
-TEMP=1.03
+TEMP=1.
 #If TOPK/TOPP are 0, sampling is greedy; top-k also overrides top-p
-TOPK=100
+TOPK=200
 TOPP=0
 
 script_path=$(realpath $0)
@@ -36,13 +36,14 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
        --max-position-embeddings-finetune $MAXSEQLEN \
        --generation-task "cuda-2d generation" \
        --input-source ./input.txt \
-       --output-path samples_cuda_2d2 \
-       --batch-size 3 \
+       --output-path samples_cuda_2d3 \
+       --batch-size 4 \
        --max-inference-batch-size 4 \
        --device 0 \
        --finetune \
        --no-load-optim \
        --sparse-type cuda_2d \
+       --debug \
        $@
 
 
diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json
index 1f0f35b84c9f2720dc39d1240cafd63a9e26b2f7..b9855836dd584b47b0e170e94e6c4d07cf9ba6f2 100755
--- a/scripts/ds_config_zero.json
+++ b/scripts/ds_config_zero.json
@@ -1,6 +1,6 @@
 {
-  "train_micro_batch_size_per_gpu": 6,
-  "gradient_accumulation_steps": 5,
+  "train_micro_batch_size_per_gpu": 2,
+  "gradient_accumulation_steps": 1,
   "steps_per_print": 1,
   "gradient_clipping": 0.1,
   "zero_optimization": {
@@ -10,7 +10,8 @@
     "overlap_comm": true,
     "reduce_scatter": true,
     "reduce_bucket_size": 100000000,
-    "allgather_bucket_size": 1000000000
+    "allgather_bucket_size": 1000000000,
+    "load_from_fp32_weights": false
   },
   "zero_allow_untested_optimizer": true,
   "fp16": {
@@ -23,13 +24,13 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.0005,
+      "lr": 0.0002,
       "betas": [
         0.9,
         0.95
       ],
       "eps": 1e-8,
-      "weight_decay": 4e-2
+      "weight_decay": 1e-4
     }
   },
   "activation_checkpointing": {
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index a1245d3f2bd70fac8f5f1809ac6c7cdf79f0f68d..f099f8011dd4e174d0e4cd1d1dcc68afc641b392 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -16,12 +16,12 @@ HOST_FILE_PATH="hostfile"
 # OPTIONS_NCCL=""
 # HOST_FILE_PATH="hostfile_single"
 
-small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata"
-full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
 
 config_json="$script_dir/ds_config_zero.json"
 gpt_options=" \
-       --experiment-name cogview-base-continue-long \
+       --experiment-name cogview-base-long \
        --img-tokenizer-num-tokens 8192 \
        --dataset-type CompactBinaryDataset \
        --model-parallel-size ${MP_SIZE} \
@@ -47,9 +47,9 @@ gpt_options=" \
        --no-load-optim \
        --no-save-optim \
        --eval-interval 1000 \
-       --save /root/checkpoints \
+       --save $main_dir/data/checkpoints \
        --fast-load \
-       --load data/checkpoints/cogview-continue \
+       --load data/checkpoints/cogview-base \
        --finetune 
 "
           
diff --git a/scripts/text2image.sh b/scripts/text2image.sh
index fdb3bf3a17ddac30b3954a50d9f219691ed13f12..4f2f5d55940a5f95a4b5e343c6b01da07d65139a 100755
--- a/scripts/text2image.sh
+++ b/scripts/text2image.sh
@@ -6,8 +6,8 @@
 # NHIDDEN=1024
 # NATT=16
 
-CHECKPOINT_PATH=data/checkpoints/cogview-continue
-# CHECKPOINT_PATH=pretrained/cogview/cogview-base
+# CHECKPOINT_PATH=data/checkpoints/cogview-base
+CHECKPOINT_PATH=pretrained/cogview/cogview-base
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -16,7 +16,7 @@ MASTER_PORT=$(shuf -n 1 -i 10000-65535)
 MPSIZE=1
 
 #SAMPLING ARGS
-TEMP=1.
+TEMP=1
 #If TOPK/TOPP are 0, sampling is greedy; top-k also overrides top-p
 TOPK=200
 TOPP=0
@@ -46,6 +46,7 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
        --batch-size 8 \
        --max-inference-batch-size 8 \
        --device 0 \
+       --debug \
        $@
 
 
diff --git a/utils.py b/utils.py
index 0d4c51d64c77299f56b9e6861b376a4a37808518..3256dce3e6fc1daf6df72aaae7f15d1ca4916de0 100755
--- a/utils.py
+++ b/utils.py
@@ -323,8 +323,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=
     if args.deepspeed:
         
         checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune)
-        if args.finetune:
-            model.module.module.init_plus_from_old()
+        # if args.finetune:
+        #     model.module.module.init_plus_from_old()
         if (args.finetune or args.no_load_optim) and model.zero_optimization():
             model.optimizer.refresh_fp32_params()
         if "client_lr_scheduler" in sd and not args.finetune:
@@ -461,4 +461,4 @@ def move_weights(our, oai, dst2src=False):
     load_weights(transformer_model.wpe, our.position_embeddings, dst2src)
 
     for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
-        load_transformer_layer(our_layer, oai_layer, dst2src)
+        load_transformer_layer(our_layer, oai_layer, dst2src)
\ No newline at end of file