diff --git a/arguments.py b/arguments.py
index b06c987ea51f430abef05a49caa2df6734fa8a98..a9e722aa247c317239666c5f82906be55fbfb433 100755
--- a/arguments.py
+++ b/arguments.py
@@ -147,6 +147,8 @@ def add_training_args(parser):
     group.add_argument('--warmup', type=float, default=0.01,
                        help='percentage of data to warmup on (.01 = 1% of all '
                             'training iters). Default 0.01')
+    group.add_argument('--restart-iter', type=int, default=0,
+                       help='restart with warmup from this iteration.')
     # model checkpointing
     group.add_argument('--save', type=str, default=None,
                        help='Output directory to save checkpoints to.')
@@ -301,19 +303,20 @@ def add_sparse_args(parser):
     group.add_argument("--key-window-times", type=int, default=6)
     group.add_argument("--num-pivot", type=int, default=768)
     # for cuda_2d
-    group.add_argument("--kernel-size", type=int, default=11)
+    group.add_argument("--kernel-size", type=int, default=9)
     group.add_argument("--kernel-size2", type=int, default=7)
-    group.add_argument("--layout", type=str, default='0,64,1088,5184')
+    group.add_argument("--layout", type=str, default='64,1088,5184')
     return parser
 
 def make_sparse_config(args):
+    args.layout = [int(x) for x in args.layout.split(',')]
     sparse_config = argparse.Namespace(sparse_type=args.sparse_type)
     if args.sparse_type == 'standard':
         pass
     if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation':
         sparse_config.kernel_size = args.kernel_size
         sparse_config.kernel_size2 = args.kernel_size2
-        sparse_config.layout = [int(x) for x in args.layout.split(',')]
+        sparse_config.layout = args.layout
     elif args.sparse_type == 'torch_1d':
         raise NotImplementedError
     args.sparse_config = sparse_config
diff --git a/data_utils/datasets.py b/data_utils/datasets.py
index 877b53bd896c6bc55c48514f7ed1327226292c7b..7b1203839f85f14874d96647c4f7c0fd7be7d042 100755
--- a/data_utils/datasets.py
+++ b/data_utils/datasets.py
@@ -80,16 +80,16 @@ class BinaryDataset(Dataset):
     def __getitem__(self, index):
         return self.process_fn(self.bin[index])
 
-def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):       
+def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): 
+    kwargs_to_dataset = {}      
 
     tokenizer = get_tokenizer()
-    if args.finetune and args.max_position_embeddings_finetune > args.max_position_embeddings:
-        ml = args.max_position_embeddings_finetune
+    if args.layout[-1] > args.max_position_embeddings:
+        ml = args.layout[-1]
     else:
         ml = args.max_position_embeddings
 
     def pad_to_len(ret):
-        
         if len(ret) < ml: # pad
             return np.concatenate((ret, 
                 np.array([tokenizer['[PAD]']] * (ml - len(ret)))),
@@ -117,14 +117,27 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
                 }
 
     elif dataset_type == 'CompactBinaryDataset':
+        layout = args.layout
         DS_CLASS = BinaryDataset
+        kwargs_to_dataset['length_per_sample'] = layout[-1]
         def process_fn(row):
-            text, code = row[:64].astype(np.int64), row[64:].astype(np.int64) # must 64 + 1024
-            text = text[text>-1]
-            ret = TextCodeTemplate(text, code)
-            ret, attention_mask_sep = pad_to_len(ret)
+            row = row.astype(np.int64)
+            # THIS IS Reverse order, TODO 
+            lens = list(reversed([layout[i] - layout[i-1] for i in range(1, len(layout))]))
+            codes = [row[layout[0]: layout[0]+lens[0]]]
+            if len(lens) > 1:
+                codes.append(row[layout[0]+lens[0]: layout[0]+lens[0]+lens[1]])
+            text = row[:layout[0]]
+            text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
+            n_pad = layout[0]-3-len(text)
+            parts = [
+                np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
+                TextCodeTemplate(text, codes[-1]),
+                *reversed(codes[:-1])
+            ]
+            ret = np.concatenate(parts, axis=0)
             return {'text': ret, 
-                'loss_mask':  np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep))
+                'loss_mask':  np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS]
                 }
     elif dataset_type == 'BinaryDataset':
         DS_CLASS = BinaryDataset
@@ -134,5 +147,5 @@ def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset):
                 'loss_mask':  loss_mask
                 }
 
-    return DS_CLASS(path, process_fn)
+    return DS_CLASS(path, process_fn, **kwargs_to_dataset)
 
diff --git a/data_utils/vqvae_tokenizer.py b/data_utils/vqvae_tokenizer.py
index 13addb42fa8e20b8ae368a8e5f49b1eec206b17f..56ee251126fdcb901d4b88cea114342b1dfccdb7 100755
--- a/data_utils/vqvae_tokenizer.py
+++ b/data_utils/vqvae_tokenizer.py
@@ -50,12 +50,15 @@ class VQVAETokenizer(object):
         self.device = device
         self.image_tokens = model.quantize_t.n_embed
         self.num_tokens = model.quantize_t.n_embed
+        self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
 
     def __len__(self):
         return self.num_tokens
 
-    def EncodeAsIds(self, img):
+    def EncodeAsIds(self, img, add_normalization=False):
         assert len(img.shape) == 4 # [b, c, h, w]
+        if add_normalization:
+            img = self.tr_normalize(img)
         return img2code(self.model, img)
 
     def DecodeIds(self, code, shape=None):
@@ -78,7 +81,6 @@ class VQVAETokenizer(object):
         img = tr(Image.open(path))
         if img.shape[0] == 4:
             img = img[:-1]
-        tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
-        img = tr_normalize(img)
+        img = self.tr_normalize(img)
         img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w]
         return img  
\ No newline at end of file
diff --git a/draw_diff.py b/draw_diff.py
index 6dd12a980c6e3b2ea859726d9ad1202331ed22b4..b4baab18c797077231138786e377f66adee62f62 100644
--- a/draw_diff.py
+++ b/draw_diff.py
@@ -23,9 +23,10 @@ transform = transforms.Compose([
             ])
 img = torchvision.io.read_image('bao.jpeg')
 img = transform(img) / 255.
-a = np.array(loadbao('bao.txt'))
-b = np.array(loadbao('bao2.txt'))
-for t in (a-b>2).nonzero()[0]:
+a = np.array(loadbao('bao2.txt'))
+b = np.array(loadbao('bao3.txt'))
+for t in (b-a>1).nonzero()[0]:
     x,y = t // 32, t % 32
     sq(img, x*16, y*16, 15, 15)
+print(a.mean(), b.mean())
 torchvision.utils.save_image(img, 'example_bao.jpg')
diff --git a/generate_samples.py b/generate_samples.py
index aff3e469b9f458bb93e63166994a663f31532933..4ae8fa02da9c751920e649b057705aab2a5587c4 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -41,7 +41,7 @@ from pretrain_gpt2 import get_model
 import math
 from copy import deepcopy
 from tqdm import tqdm
-from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local
+from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local, filling_sequence_cuda_2d
 from torchvision.utils import save_image
 import torch.distributed as dist
 
@@ -167,7 +167,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
             # profile = line_profiler.LineProfiler(model.module.forward)
             # profile = line_profiler.LineProfiler(standard_attention)
             # profile.enable()
-            fill_fn = filling_sequence_local if args.generation_task == 'cuda-2d generation' else filling_sequence
+            fill_fn = filling_sequence_cuda_2d if args.generation_task == 'cuda-2d generation' else filling_sequence
             output_tokens_list.append(fill_fn(model, seq.clone(), args))
             # torch.cuda.empty_cache()
             # profile.disable()  # 停止分析
@@ -212,7 +212,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
 
 def generate_images_continually(model, args):
     if args.generation_task == 'text2image':
-        query_template = '[ROI1] {} [BASE] [BOI1] [Image200]bao.jpeg'
+        query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024'
     elif args.generation_task == 'image2text':
         query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] [MASK]*20'
     elif args.generation_task == 'low-level super-resolution':
@@ -222,7 +222,7 @@ def generate_images_continually(model, args):
     elif args.generation_task == 'post-selection':
         query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] {}'
     elif args.generation_task == 'cuda-2d generation':
-        query_template = '[CLS] {} [BASE] [Image200]bao.jpeg [MASK]*4096'
+        query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1] [MASK]*4096'
     else:
         raise NotImplementedError
     for raw_text, seq, output_path in get_context(args, query_template):
diff --git a/generation/__init__.py b/generation/__init__.py
index b73ab2313cef87d5306e653894797d211f4a2c36..94f6ea1d0e968244e57ee1bf024dadb9fd1bd654 100755
--- a/generation/__init__.py
+++ b/generation/__init__.py
@@ -1,3 +1,4 @@
 from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
 from .magnify import magnify
-from .local_sampling import filling_sequence_local
\ No newline at end of file
+from .local_sampling import filling_sequence_local
+from .cuda_2d_sampling import filling_sequence_cuda_2d
\ No newline at end of file
diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5fe7f83935a99dcdbbbe5da57fe506c5ee700ab
--- /dev/null
+++ b/generation/cuda_2d_sampling.py
@@ -0,0 +1,121 @@
+from .sampling import *
+import math
+import sys
+from copy import deepcopy
+from torchvision.utils import save_image
+def filling_sequence_cuda_2d(
+        model, 
+        seq, 
+        args, 
+        mems=None, 
+        invalid_slices=[], 
+        iterative_step=20,
+        **kwargs):
+    '''
+        seq: [id[ROI1], 10000, 20000, id[BASE], id[BOI1], 1024 * -1/known tokens, id[EOI1], 4096 * -1..., ]
+    '''
+    tokenizer = get_tokenizer()
+    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
+    device = seq.device
+    assert args.sparse_config.sparse_type == 'cuda_2d'
+    std_config = deepcopy(args.sparse_config)
+    std_config.sparse_type = 'standard'
+    sparse_config = args.sparse_config
+    # split two parts
+    seq0, seq1 = seq[:-4097], seq[-4097:] # +1 for EOI1
+    # generate a batch of seq0
+    model.module.transformer.reset_sparse_config(std_config)
+    args.sparse_config = std_config
+    output0 = filling_sequence(model, seq0, args)
+    model.module.transformer.reset_sparse_config(sparse_config)
+    args.sparse_config = sparse_config
+    model.module.transformer.max_memory_length = 0
+
+
+    # filter bad generation & select top N=2, TODO
+    output0 = output0
+
+    from torchvision import transforms
+    tr = transforms.Compose([
+        transforms.Resize(512), 
+    ])
+    imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1024:].tolist())) for x in output0] # ground truth
+    blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(device), add_normalization=True) # blured image as init value
+
+    # pad seq to desired shape
+    n_pad = args.layout[1] - len(seq0)
+    batch_size = output0.shape[0]
+    assert n_pad > 0, "You should truncate long input before filling."
+    seq = torch.cat((
+        torch.tensor([tokenizer['[PAD]']]* n_pad, device=seq.device, dtype=seq.dtype)
+            .unsqueeze(0).expand(batch_size, n_pad),
+        output0,
+        seq1.unsqueeze(0).expand(batch_size, len(seq1))    
+        ), dim=1
+    ) # [b, layout[-1]]
+
+    # init 
+    step_cnt = 0
+    tokens = seq[:, :-1].clone()
+    unfixed = (seq < 0)
+    # tokens[unfixed[:, :-1]] = tokens[unfixed[:, :-1]].random_(0, tokenizer.img_tokenizer.num_tokens)
+    tokens[:, -4095:] = blur64[:, :-1]
+    attention_mask = torch.ones(args.layout[1], args.layout[1]).tril().to(device)
+    attention_mask[n_pad:, :n_pad] = 0
+    position_ids = torch.cat((
+        torch.zeros(n_pad, dtype=torch.long),
+        torch.arange(0, args.layout[1] - n_pad), 
+        torch.arange(0, args.layout[2]-args.layout[1]))).to(device)
+    # iterate
+    imgs = []
+    # import pdb;pdb.set_trace()
+    while unfixed.sum() > 0:
+        print(unfixed.sum())
+        logits, *_dump = model(tokens, position_ids, attention_mask)
+        step_cnt += 1
+        last_logits = logits
+
+        # warmup 
+        real_topk = 5
+        real_temp = 2 - min(1,((step_cnt) / iterative_step)) * 1.9
+        # sampling
+        for invalid_slice in invalid_slices: # forbide to generate other tokens
+            logits[..., invalid_slice] = -float('Inf')
+        assert args.top_k > 0
+        tk_value, tk_idx = torch.topk(logits, real_topk, dim=-1)
+        tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
+        prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
+        prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
+        # update unfixed
+        choice = 1
+        if choice == 0 and step_cnt > 5:
+            mprob = tk_probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
+            dprob = (mprob[:, 1:] < 0.5) & ((mprob[:, :-1] > 0.8)| (unfixed[:, 1:-1].logical_not()))
+            new_fixed = unfixed.clone()
+            new_fixed[:, 2:] &= dprob
+        else:
+            new_fixed = unfixed & False # TODO
+        new_fixed[:, -1] = True
+        unfixed &= new_fixed.logical_not()
+        # update seq and tokens
+        seq[new_fixed] = prev[new_fixed[:, 1:]]
+        tokens = seq[:, :-1].clone()
+        tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
+
+        if step_cnt == iterative_step: 
+            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
+            n_unfixed = unfixed.sum(dim=-1).tolist()
+            print(f'Exit with {n_unfixed} unfixed tokens.')
+            break
+        if args.debug:
+            from torchvision.utils import save_image
+            seqt = seq.clone()
+            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
+            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
+    if args.debug:
+        imgs = torch.cat(imgs, dim=0)
+        save_image(imgs, f'steps{device}.jpg', normalize=True)
+    
+    model.module.transformer.max_memory_length = args.max_memory_length
+
+    return seq
\ No newline at end of file
diff --git a/generation/sampling.py b/generation/sampling.py
index 4c3acd2501e50f74589884963a87a76840d019dd..8d614724a10a3f1ea1733d4c3fe23a602b70ed11 100755
--- a/generation/sampling.py
+++ b/generation/sampling.py
@@ -132,8 +132,8 @@ def filling_sequence(
 
             tmp = -F.log_softmax(logits, dim=-1)
             tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
-            for i in range(1,len(tmp)):
-                print(i, tmp[i].item())
+            # for i in range(1,len(tmp)):
+            #     print(i, tmp[i].item())
             index = counter
             print(tmp[1:].mean(), file=sys.stderr)
         elif seq[counter + 1] >= 0: # provided
diff --git a/learning_rates.py b/learning_rates.py
index 4749ae4891383d0cbf11212d085cfd696d463131..b68e8be1d26559e41ab318c70493b388efd43e01 100755
--- a/learning_rates.py
+++ b/learning_rates.py
@@ -17,13 +17,16 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 import math
+from utils import print_rank_0
+
 
 class AnnealingLR(_LRScheduler):
     """Anneals the learning rate from start to zero along a cosine curve."""
 
     DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
 
-    def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5):
+    def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5, restart_iter=0):
+        self.restart_iter = restart_iter
         assert warmup_iter <= num_iters
         self.optimizer = optimizer
         self.start_lr = start_lr
@@ -38,13 +41,16 @@ class AnnealingLR(_LRScheduler):
 
     def get_lr(self):
         # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
-        if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
-            return float(self.start_lr) * self.num_iters / self.warmup_iter
+        real_num_iters = self.num_iters - self.restart_iter
+        real_end_iter = self.end_iter - self.restart_iter
+        # print_rank_0(f'real_num_iters: {real_num_iters}')
+        if self.warmup_iter > 0 and real_num_iters <= self.warmup_iter:
+            return float(self.start_lr) * real_num_iters / self.warmup_iter
         else:
             if self.decay_style == self.DECAY_STYLES[0]:
-                return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter)
+                return self.start_lr*((real_end_iter-(real_num_iters-self.warmup_iter))/real_end_iter)
             elif self.decay_style == self.DECAY_STYLES[1]:
-                decay_step_ratio = min(1.0, (self.num_iters - self.warmup_iter) / self.end_iter)
+                decay_step_ratio = min(1.0, (real_num_iters - self.warmup_iter) / real_end_iter)
                 return self.start_lr / self.decay_ratio * (
                         (math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
             elif self.decay_style == self.DECAY_STYLES[2]:
@@ -73,8 +79,9 @@ class AnnealingLR(_LRScheduler):
         return sd
 
     def load_state_dict(self, sd):
+        import pdb;pdb.set_trace()
         # self.start_lr = sd['start_lr']
-        self.warmup_iter = sd['warmup_iter']
+        # self.warmup_iter = sd['warmup_iter']
         self.num_iters = sd['num_iters']
         # self.end_iter = sd['end_iter']
         self.decay_style = sd['decay_style']
diff --git a/model/gpt2_modeling.py b/model/gpt2_modeling.py
index 6c47e98ada30ff7fb73881f4d93ba960e2693900..c854f1bce00f0a0397ebb796c1f7f17f7e3c1b4b 100755
--- a/model/gpt2_modeling.py
+++ b/model/gpt2_modeling.py
@@ -41,15 +41,14 @@ def gpt2_get_params_for_weight_decay_optimization(module):
         if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
             no_weight_decay_params['params'].extend(
                 [p for p in list(module_._parameters.values())
-                 if p is not None])
+                 if p is not None and p.requires_grad])
         else:
             weight_decay_params['params'].extend(
                 [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
+                 if p is not None and n != 'bias' and p.requires_grad])
             no_weight_decay_params['params'].extend(
                 [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
-
+                 if p is not None and n == 'bias' and p.requires_grad])
     return weight_decay_params, no_weight_decay_params
 
 
@@ -74,7 +73,8 @@ class GPT2Model(torch.nn.Module):
                  sandwich_ln,
                  checkpoint_num_layers=1,
                  parallel_output=True,
-                 sparse_config=argparse.Namespace(sparse_type='standard')
+                 sparse_config=argparse.Namespace(sparse_type='standard'),
+                 finetune=False
                  ):
 
         super(GPT2Model, self).__init__()
@@ -99,7 +99,8 @@ class GPT2Model(torch.nn.Module):
                                                        checkpoint_activations,
                                                        checkpoint_num_layers,
                                                        sandwich_ln=sandwich_ln,
-                                                       sparse_config=sparse_config
+                                                       sparse_config=sparse_config,
+                                                       finetune=finetune
                                                        )
 
     def forward(self, input_ids, position_ids, attention_mask, *mems):
@@ -120,3 +121,6 @@ class GPT2Model(torch.nn.Module):
             return (logits_parallel, *hidden_layers)
 
         return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)
+    
+    def init_plus_from_old(self):
+        self.transformer.init_plus_from_old()
diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py
index 7adc0dfbb13e25aba14aa64fa1f0285a457d2bbf..53a92de78d6b0409e520dbb139774d18d37af6f7 100755
--- a/mpu/sparse_transformer.py
+++ b/mpu/sparse_transformer.py
@@ -75,7 +75,8 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
     """
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, output_layer_init_method=None,sparse_config=None):
+                 init_method, output_layer_init_method=None,sparse_config=None,
+                 finetune=False):
         super(GPT2ParallelSelfAttention, self).__init__()
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
@@ -111,11 +112,30 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
             get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
             checkpoint = deepspeed.checkpointing.checkpoint
 
-        # self.offset_bias = torch.nn.Parameter(
-        #     torch.ones(num_attention_heads, sparse_config.kernel_size**2//2+1 +sparse_config.kernel_size2**2)) 
-
         self.sparse_config = sparse_config
 
+        if finetune: 
+            # build new branch
+            self.query_key_value_plus = ColumnParallelLinear(hidden_size, 3*hidden_size,
+                                                    stride=3,
+                                                    gather_output=False,
+                                                    init_method=init_method)
+            self.dense_plus = RowParallelLinear(hidden_size,
+                                       hidden_size,
+                                       input_is_parallel=True,
+                                       init_method=output_layer_init_method)
+
+    def init_plus_from_old(self):
+        self.query_key_value_plus.weight.data.copy_(self.query_key_value.weight.data)
+        if hasattr(self.query_key_value_plus, 'bias') and hasattr(self.query_key_value, 'bias'):
+            self.query_key_value_plus.bias.data.copy_(self.query_key_value.bias.data)
+        
+        self.dense_plus.weight.data.copy_(self.dense.weight.data)
+        if hasattr(self.dense_plus, 'bias') and hasattr(self.dense, 'bias'):
+            self.dense_plus.bias.data.copy_(self.dense.bias.data)
+    def reset_sparse_config(self, config):
+        self.sparse_config = config
+
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
         size [b, np, s, hn].
@@ -128,16 +148,15 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
 
 
     def forward(self, hidden_states, mask, mem=None):
-        # import pdb;pdb.set_trace()
         sparse_config = self.sparse_config
-        # hidden_states: [b, s, h]
-        # ltor_mask: [1, 1, s, s]
-
-        # Attention heads. [b, s, hp]
-        query_length = hidden_states.size(1)
+        layout = sparse_config.layout
+        if sparse_config.sparse_type == 'cuda_2d':
+            assert hidden_states.size(1) == sparse_config.layout[-1]
+            # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
+            hidden_states_plus = hidden_states[:, layout[1]:]
+            hidden_states = hidden_states[:, :layout[1]]
 
         mixed_raw_layer = self.query_key_value(hidden_states)
-
         (mixed_query_layer,
             mixed_key_layer,
             mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
@@ -145,38 +164,45 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
             memk, memv = split_tensor_along_last_dim(mem, 2)
             mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
             mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+        
+        if sparse_config.sparse_type == 'cuda_2d':
+            mixed_raw_layer_plus = self.query_key_value_plus(hidden_states_plus)
+            q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer_plus, 3)
 
         dropout_fn = self.attention_dropout if self.training else None
 
-        if sparse_config.sparse_type in ['standard', 'torch_1d']:
-            # Reshape and transpose [b, np, s, hn]
+        if sparse_config.sparse_type == 'standard':
             query_layer = self._transpose_for_scores(mixed_query_layer)
             
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
             
-            if sparse_config.sparse_type == 'standard':
-                context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
-            else:
-                context_layer = sparse_attention_1d(query_layer, key_layer, value_layer, sparse_config.pivot_idx, 
-                    mask, sparse_config.query_window, sparse_config.key_window_times, dropout_fn)
-                # inference: context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx)
+            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
             
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + \
                                     (self.hidden_size_per_partition,)
-            # [b, s, hp]
             context_layer = context_layer.view(*new_context_layer_shape)
             
         elif sparse_config.sparse_type == 'cuda_2d':
-            context_layer = sparse_attention_2dfull(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition,
-                 sparse_config.layout, mask, sparse_config.kernel_size,
-                  kernel_size2=sparse_config.kernel_size2,
-                  attention_dropout=dropout_fn
-                 )
-
-        # Output. [b, s, h]
-        output = self.dense(context_layer)
+            context_layer0, context_layer1 = sparse_attention_2d_light(
+                mixed_query_layer, mixed_key_layer, mixed_value_layer,
+                q1, k1, v1,
+                mask,
+                n_head=self.num_attention_heads_per_partition,
+                text_len=sparse_config.layout[0],
+                kernel_size=sparse_config.kernel_size,
+                kernel_size2=sparse_config.kernel_size2,
+                attention_dropout=dropout_fn
+            )
+
+        if sparse_config.sparse_type == 'cuda_2d':
+            output_0 = self.dense(context_layer0)
+            output_1 = self.dense_plus(context_layer1)
+            output = torch.cat((output_0, output_1), dim=1)
+        else:
+            output = self.dense(context_layer)
+            
         if self.training:
             output = self.output_dropout(output)
         
@@ -284,7 +310,8 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
                  init_method,
                  output_layer_init_method=None,
                  sandwich_ln=True,
-                 sparse_config=argparse.Namespace(sparse_type='standard')
+                 sparse_config=argparse.Namespace(sparse_type='standard'),
+                 finetune=False
                  ):
         super(GPT2ParallelTransformerLayer, self).__init__()
         # Set output layer initialization if not provided.
@@ -302,7 +329,8 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
             output_dropout_prob,
             init_method,
             output_layer_init_method=output_layer_init_method,
-            sparse_config=sparse_config
+            sparse_config=sparse_config,
+            finetune=finetune
             )
 
         # Layernorm on the input data.
@@ -324,6 +352,10 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
 
         self.sparse_config = sparse_config
 
+    def reset_sparse_config(self, config):
+            self.sparse_config = config
+            self.attention.reset_sparse_config(config)
+    
     def forward(self, hidden_states, ltor_mask, mem=None):
         # hidden_states: [b, s, h]
         # ltor_mask: [1, 1, s, s]
@@ -419,7 +451,8 @@ class GPT2ParallelTransformer(torch.nn.Module):
                  init_method_std=0.02,
                  use_scaled_init_for_output_weights=True,
                  sandwich_ln=True,
-                 sparse_config=argparse.Namespace(sparse_type='standard')
+                 sparse_config=argparse.Namespace(sparse_type='standard'),
+                 finetune=False
                  ):
         super(GPT2ParallelTransformer, self).__init__()
         # Store activation checkpoiting flag.
@@ -441,12 +474,11 @@ class GPT2ParallelTransformer(torch.nn.Module):
         # Initialize the position embeddings.
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
-        # TODO: after testing, this is not useful.
-        # self.img_type_embeddings = torch.nn.Parameter(torch.Tensor(64, hidden_size)) 
-        # torch.nn.init.normal_(self.img_type_embeddings, mean=0.0, std=init_method_std)
-        # self.txt_type_embeddings = torch.nn.Parameter(torch.Tensor(hidden_size)) 
-        # torch.nn.init.normal_(self.txt_type_embeddings, mean=0.0, std=init_method_std)
-
+        if finetune:
+            self.position_embeddings_plus = torch.nn.Embedding(4096, # FIXME
+                                                            hidden_size)
+            # Initialize the position embeddings.
+            torch.nn.init.normal_(self.position_embeddings_plus.weight, mean=0.0, std=init_method_std)
 
         def get_layer(layer_id):
             return GPT2ParallelTransformerLayer(
@@ -458,7 +490,8 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 unscaled_init_method(init_method_std),
                 output_layer_init_method=output_layer_init_method,
                 sandwich_ln=sandwich_ln,
-                sparse_config=sparse_config
+                sparse_config=sparse_config,
+                finetune=finetune
                 )
 
         # Transformer layers.
@@ -474,6 +507,16 @@ class GPT2ParallelTransformer(torch.nn.Module):
             checkpoint = deepspeed.checkpointing.checkpoint
         self.sparse_config = sparse_config
 
+    def init_plus_from_old(self):
+        self.position_embeddings_plus.weight.data.view(4, 1024, -1).copy_(self.position_embeddings.weight.data[-1024:]) # FIXME
+        for layer in self.layers:
+            layer.attention.init_plus_from_old()
+
+    def reset_sparse_config(self, config):
+            self.sparse_config = config
+            for layer in self.layers:
+                layer.reset_sparse_config(config)
+
     def forward(self, hidden_states, position_ids, attention_mask, *mems):
 
         batch_size, query_length = hidden_states.size()[:2]
@@ -495,13 +538,14 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 return m
             attention_mask = build_mask_matrix(query_length, key_length, sep)
 
-        # =====================   Image & Text Type Embedding   ======================== #
-        # TODO: after testing, this is not useful.
-        # extend_len = (key_length + 63) // 64
-        # hidden_states = hidden_states + txt_indices_bool.unsqueeze(-1) * self.txt_type_embeddings.view(1, 1, -1) + \
-        #     img_indices_bool.unsqueeze(-1) * self.img_type_embeddings.expand(extend_len, 64, -1).reshape(extend_len * 64, -1)[memory_length: key_length]
-        # ===================== END OF BLOCK ======================= #
-        position_embeddings = self.position_embeddings(position_ids)
+    
+        if self.sparse_config.sparse_type == 'cuda_2d':
+            position = position_ids[..., :self.sparse_config.layout[1]]
+            position_plus = position_ids[..., self.sparse_config.layout[1]:]
+            position_embeddings = torch.cat(
+                (self.position_embeddings(position), self.position_embeddings_plus(position_plus)), dim=-2)
+        else:
+            position_embeddings = self.position_embeddings(position_ids)
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
@@ -539,10 +583,8 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 hidden_states = checkpoint(custom(l, l + chunk_length), *args)
                 l += chunk_length
         else:
-            assert self.sparse_config.sparse_type == 'standard'
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask_saved]
-
                 if mems:
                     mem_i = mems[i]  
                 elif self.max_memory_length > 0:
@@ -555,11 +597,6 @@ class GPT2ParallelTransformer(torch.nn.Module):
 
         # Final layer norm.
         output = self.final_layernorm(hidden_states)
-        # if self.max_memory_length > 0: # TODO cache
-        #     if self.sparse_config.sparse_type != 'cuda_2d':
-        #         mem_layers = self.update_mems(mem_layers, mems)
-        #     else:
-        #         pass # handle update outside the model, because mems is not the full cached qkv.
 
         return (output, *mem_layers)
         
@@ -609,245 +646,65 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte
             attention_probs = attention_dropout(attention_probs)
     # Context layer.
     # [b, np, s, hn]
-    
-    context_layer = torch.matmul(attention_probs, value_layer)
-    return context_layer
-
-def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
-    ''' Sparse Attention
-    Args:
-        q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window
-        pivot_idx: [b, num_pivots]
-        pivot_attention_mask: [b, s, num_pivots]
-        query_window: .
-        key_window_times: key_window = query_window * key_window_times
-    '''
-
-    b, n_head, s, hn = q.shape
-    b, n_piv = pivot_idx.shape
-    w = query_window
-
-    pivot_idx_dummy = pivot_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
-    # =====================   Pivot Attention   ======================== #
-    pivot_k, pivot_v = torch.gather(k, 2, pivot_idx_dummy), torch.gather(v, 2, pivot_idx_dummy)
-    attention_scores = torch.matmul(q, pivot_k.transpose(-1, -2))
-    pivot_attention_mask = pivot_attention_mask.unsqueeze(1)
-
-    attention_scores_pivot = torch.mul(attention_scores, pivot_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - pivot_attention_mask)
-
-    attention_scores_pivot = attention_scores_pivot + math.log(s // n_piv)
-    # =====================   Window Attention   ======================= #
-    window_k = _chunk(k, query_window, key_window_times)
-    window_v = _chunk(v, query_window, key_window_times)
-    # window_k [b, n_head, s // w up int, w*times, hn]
-
-    if s % w == 0: # training # TODO args check
-        assert k.shape[2] == s
-        assert window_k.shape[2] == s // w
-        window_q = q.view(b, n_head, s // w, w, hn)        
-        attention_scores = torch.matmul(window_q, window_k.transpose(-1, -2))
-        window_attention_mask = torch.ones((w, w * key_window_times), dtype=attention_scores.dtype, device=q.device).tril_(diagonal=w * (key_window_times - 1))
-        attention_scores_window = torch.mul(attention_scores, window_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - window_attention_mask)
-        for t in range(1, key_window_times):
-            attention_scores_window[:, :, t - 1, :, :w * key_window_times - w * t] -= 10000.0
-    else: 
-        raise ValueError('The seq_len must be exactly divided by window_size.')
-    # =====================   Joint Softmax   ======================= #
-    attention_scores_window = attention_scores_window.view(b, n_head, s, w * key_window_times)
-    attention_scores = torch.cat((attention_scores_pivot, attention_scores_window), dim=-1)
-    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
-
-    if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
-            attention_probs = attention_dropout(attention_probs)
-
-    context_layer = torch.matmul(attention_probs[..., :-w * key_window_times], pivot_v) + torch.einsum('bcgwk,bcgkh->bcgwh', attention_probs[..., -w * key_window_times:].view(b, n_head, s // w, w, w * key_window_times), window_v).view(b, n_head, s, hn)
 
+    context_layer = torch.matmul(attention_probs, value_layer)
     return context_layer
 
-# def sparse_attention_inference_1d(q, k, v, pivot_and_window_idx, **kwargs):
-#     '''the inference process of sparse attention.
-#     The Qs are in the same block, but seq_len mod window size might != 0.
-
-#     The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs.
-
-#     '''
-#     b, n_head, sq, hn = q.shape
-#     sk = k.shape[2]
-#     _b, n_piv = pivot_and_window_idx.shape
-
-#     pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
-#     pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy)
-#     attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2))
-#     if sq > 1:
-#         query_part_scores = attention_scores[:, :, -sq:, -sq:]
-#         m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000.
-#         m.triu_(diagonal=1)
-#         query_part_scores += m
-
-#     attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
-
-#     context_layer = torch.matmul(attention_probs, pivot_v) 
-#     return context_layer
 
-def transpose_and_split(x, layout, n_head):
-    x = x.transpose(1, 2)
-    x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2])
-    x_text = x[..., :layout[0]]
-    x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous()
-    x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous()
-    return x, x_text, x0, x1
-
-def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
+def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len=64, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
     '''
-    q, k, v: [batch_size, 64+1024+4096, hidden_size]
+    q0, k0, v0: [batch_size, 1088, hidden_size]
+    q1, k1, v1: [batch_size, 4096, h2]
     n_head: int
-    layout: [endoftext/startofpad, startof0, startof1, endofall]
-    attention_mask_text2d: [batch_size, sq_len, endoftext]
+    attention_mask: [batch_size, 1088, 1088]
     '''
     from .local_attention_function import f_similar, f_weighting
-    b, sq_len, hn = q.shape
-    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
-
-    q = q / math.sqrt(hn // n_head) # normalization
-
-    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
-    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
-    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
-    # import pdb; pdb.set_trace()
-    # all to text
-    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
-    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
-    # 0 to 0
-    scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True)
-    # 1 to 1
+    b, s0, h0 = q0.shape
+    b, s1, h1 = q1.shape
+    assert v1.shape[-1] == h0, 'q1, k1 can be smaller, but v1 cannot.'
+    h = h0 // n_head
+    l0, l1 = int(math.sqrt(s0-text_len)+0.0001), int(math.sqrt(s1)+0.0001)
+
+    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
+    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
+    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
+    # standard attention for level 0
+    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
+    attention_scores = torch.mul(attention_scores, attention_mask) - \
+                    10000.0 * (1.0 - attention_mask)
+    attention_probs0 = F.softmax(attention_scores, dim=-1)
+    # local attention for level 1
+    q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
+    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
+    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
     scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
-    # 1 to 0
-    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
-    # softmax
-    # if 'offset_bias' in kwargs:
-    #     p1, p2 = kernel_size**2//2 + 1, kernel_size2**2
-    #     offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2)
-    #     scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1]
-    #     scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1]
-    #     scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:]
-
-    scores_0 = torch.cat(
-        (scores_all_to_text[:, layout[1]:layout[2]], 
-        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
-        dim=-1)
-    scores_1 = torch.cat(
-        (scores_all_to_text[:, layout[2]:layout[3]],
-         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
-         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
-         dim=-1)
-    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
-    probs_0 = F.softmax(scores_0, dim=-1) # 
-    probs_1 = F.softmax(scores_1, dim=-1)
-
-    if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
-            probs_0 = attention_dropout(probs_0)
-            probs_1 = attention_dropout(probs_1)
-    # weighting
-    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
-    probs_all_to_text = torch.cat((
-        probs_text,
-        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
-        probs_0[:, :, :layout[0]],
-        probs_1[:, :, :layout[0]]
-    ), dim=1)
-
-    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
-        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
-        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
-    
-    context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True)
-
-    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
+    # attention_probs1 = F.softmax(scores_1_to_1, dim=-1)
 
-    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
-    
-    context_all_to_01 =torch.cat(
+    # cross attention
+    k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
+    scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
+    scores_1 = torch.cat(
         (
-            pad.expand(b*n_head, hn//n_head, layout[1]),
-            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
-            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
-        ), dim=-1).view(b, hn, -1).transpose(1, 2)
-    return context_all_to_text + context_all_to_01 
-
-
-def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
-    '''
-    q, k, v: [batch_size, 64+1024+4096, hidden_size]
-    n_head: int
-    layout: [endoftext/startofpad, startof0, startof1, endofall]
-    attention_mask_text2d: [batch_size, sq_len, endoftext]
-    '''
-    from .local_attention_function import f_similar, f_weighting
-    b, sq_len, hn = q.shape
-    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
-
-    q = q / math.sqrt(hn // n_head) # normalization
-
-    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
-    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
-    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
-    # import pdb; pdb.set_trace()
-    # all to text
-    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
-    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
-    # 0 to 0
-    if not hasattr(sparse_attention_2dfull, 'attention_mask0'):
-        sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_()
-    attention_mask0 = sparse_attention_2dfull.attention_mask0
-    scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0)
-    # 1 to 1
-    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
-    # 1 to 0
-    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
-    # softmax
-
-    scores_0 = torch.cat(
-        (scores_all_to_text[:, layout[1]:layout[2]], 
-        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
+            scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]),
+            scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
+        ),
         dim=-1)
-    scores_1 = torch.cat(
-        (scores_all_to_text[:, layout[2]:layout[3]],
-         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
-         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
-         dim=-1)
-    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
-    probs_0 = F.softmax(scores_0, dim=-1) # 
-    probs_1 = F.softmax(scores_1, dim=-1)
+    attention_probs1 = F.softmax(scores_1, dim=-1)
 
     if attention_dropout is not None:
         with get_cuda_rng_tracker().fork():
-            probs_0 = attention_dropout(probs_0)
-            probs_1 = attention_dropout(probs_1)
-    # weighting
-    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
-    probs_all_to_text = torch.cat((
-        probs_text,
-        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
-        probs_0[:, :, :layout[0]],
-        probs_1[:, :, :layout[0]]
-    ), dim=1)
-
-    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
-        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
-        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
-    
-    context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0))
-
-    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
-
-    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
-    
-    context_all_to_01 =torch.cat(
-        (
-            pad.expand(b*n_head, hn//n_head, layout[1]),
-            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
-            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
-        ), dim=-1).view(b, hn, -1).transpose(1, 2)
-    return context_all_to_text + context_all_to_01 
+            attention_probs0 = attention_dropout(attention_probs0)
+            attention_probs1 = attention_dropout(attention_probs1)
+        
+    # weighting for level 0
+    context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
+    # weighting for level 1
+    probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
+    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
+    context1_to_1 = context1_to_1.view(b, n_head * h, l1**2)
+    # weighting for cross attention
+    probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
+    v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
+    context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
+    context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
+    return context0.transpose(1, 2).reshape(b, s0, h0), (context1_to_0 + context1_to_1).transpose(-1, -2)
\ No newline at end of file
diff --git a/mpu/unused_codes.py b/mpu/unused_codes.py
new file mode 100644
index 0000000000000000000000000000000000000000..23256e777d9053c75c0ce02ff619bdbf314cbd7d
--- /dev/null
+++ b/mpu/unused_codes.py
@@ -0,0 +1,247 @@
+
+def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
+    ''' Sparse Attention
+    Args:
+        q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window
+        pivot_idx: [b, num_pivots]
+        pivot_attention_mask: [b, s, num_pivots]
+        query_window: .
+        key_window_times: key_window = query_window * key_window_times
+    '''
+
+    b, n_head, s, hn = q.shape
+    b, n_piv = pivot_idx.shape
+    w = query_window
+
+    pivot_idx_dummy = pivot_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
+    # =====================   Pivot Attention   ======================== #
+    pivot_k, pivot_v = torch.gather(k, 2, pivot_idx_dummy), torch.gather(v, 2, pivot_idx_dummy)
+    attention_scores = torch.matmul(q, pivot_k.transpose(-1, -2))
+    pivot_attention_mask = pivot_attention_mask.unsqueeze(1)
+
+    attention_scores_pivot = torch.mul(attention_scores, pivot_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - pivot_attention_mask)
+
+    attention_scores_pivot = attention_scores_pivot + math.log(s // n_piv)
+    # =====================   Window Attention   ======================= #
+    window_k = _chunk(k, query_window, key_window_times)
+    window_v = _chunk(v, query_window, key_window_times)
+    # window_k [b, n_head, s // w up int, w*times, hn]
+
+    if s % w == 0: # training # TODO args check
+        assert k.shape[2] == s
+        assert window_k.shape[2] == s // w
+        window_q = q.view(b, n_head, s // w, w, hn)        
+        attention_scores = torch.matmul(window_q, window_k.transpose(-1, -2))
+        window_attention_mask = torch.ones((w, w * key_window_times), dtype=attention_scores.dtype, device=q.device).tril_(diagonal=w * (key_window_times - 1))
+        attention_scores_window = torch.mul(attention_scores, window_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - window_attention_mask)
+        for t in range(1, key_window_times):
+            attention_scores_window[:, :, t - 1, :, :w * key_window_times - w * t] -= 10000.0
+    else: 
+        raise ValueError('The seq_len must be exactly divided by window_size.')
+    # =====================   Joint Softmax   ======================= #
+    attention_scores_window = attention_scores_window.view(b, n_head, s, w * key_window_times)
+    attention_scores = torch.cat((attention_scores_pivot, attention_scores_window), dim=-1)
+    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
+
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            attention_probs = attention_dropout(attention_probs)
+
+    context_layer = torch.matmul(attention_probs[..., :-w * key_window_times], pivot_v) + torch.einsum('bcgwk,bcgkh->bcgwh', attention_probs[..., -w * key_window_times:].view(b, n_head, s // w, w, w * key_window_times), window_v).view(b, n_head, s, hn)
+
+    return context_layer
+
+
+def transpose_and_split(x, layout, n_head):
+    x = x.transpose(1, 2)
+    x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2])
+    x_text = x[..., :layout[0]]
+    x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous()
+    x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous()
+    return x, x_text, x0, x1
+
+def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
+    '''
+    q, k, v: [batch_size, 64+1024+4096, hidden_size]
+    n_head: int
+    layout: [endoftext/startofpad, startof0, startof1, endofall]
+    attention_mask_text2d: [batch_size, sq_len, endoftext]
+    '''
+    from .local_attention_function import f_similar, f_weighting
+    b, sq_len, hn = q.shape
+    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
+
+    q = q / math.sqrt(hn // n_head) # normalization
+
+    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
+    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
+    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
+    # import pdb; pdb.set_trace()
+    # all to text
+    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
+    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
+    # 0 to 0
+    scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True)
+    # 1 to 1
+    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
+    # 1 to 0
+    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
+    # softmax
+    # if 'offset_bias' in kwargs:
+    #     p1, p2 = kernel_size**2//2 + 1, kernel_size2**2
+    #     offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2)
+    #     scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1]
+    #     scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1]
+    #     scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:]
+
+    scores_0 = torch.cat(
+        (scores_all_to_text[:, layout[1]:layout[2]], 
+        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
+        dim=-1)
+    scores_1 = torch.cat(
+        (scores_all_to_text[:, layout[2]:layout[3]],
+         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
+         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
+         dim=-1)
+    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
+    probs_0 = F.softmax(scores_0, dim=-1) # 
+    probs_1 = F.softmax(scores_1, dim=-1)
+
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            probs_0 = attention_dropout(probs_0)
+            probs_1 = attention_dropout(probs_1)
+    # weighting
+    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
+    probs_all_to_text = torch.cat((
+        probs_text,
+        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
+        probs_0[:, :, :layout[0]],
+        probs_1[:, :, :layout[0]]
+    ), dim=1)
+
+    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
+        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
+        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
+    
+    context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True)
+
+    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
+
+    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
+    
+    context_all_to_01 =torch.cat(
+        (
+            pad.expand(b*n_head, hn//n_head, layout[1]),
+            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
+            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
+        ), dim=-1).view(b, hn, -1).transpose(1, 2)
+    return context_all_to_text + context_all_to_01 
+
+
+def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
+    '''
+    q, k, v: [batch_size, 64+1024+4096, hidden_size]
+    n_head: int
+    layout: [endoftext/startofpad, startof0, startof1, endofall]
+    attention_mask_text2d: [batch_size, sq_len, endoftext]
+    '''
+    from .local_attention_function import f_similar, f_weighting
+    b, sq_len, hn = q.shape
+    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
+
+    q = q / math.sqrt(hn // n_head) # normalization
+
+    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
+    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
+    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
+    # import pdb; pdb.set_trace()
+    # all to text
+    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
+    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
+    # 0 to 0
+    if not hasattr(sparse_attention_2dfull, 'attention_mask0'):
+        sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_()
+    attention_mask0 = sparse_attention_2dfull.attention_mask0
+    scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0)
+    # 1 to 1
+    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
+    # 1 to 0
+    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
+    # softmax
+
+    scores_0 = torch.cat(
+        (scores_all_to_text[:, layout[1]:layout[2]], 
+        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
+        dim=-1)
+    scores_1 = torch.cat(
+        (scores_all_to_text[:, layout[2]:layout[3]],
+         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
+         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
+         dim=-1)
+    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
+    probs_0 = F.softmax(scores_0, dim=-1) # 
+    probs_1 = F.softmax(scores_1, dim=-1)
+
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            probs_0 = attention_dropout(probs_0)
+            probs_1 = attention_dropout(probs_1)
+    # weighting
+    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
+    probs_all_to_text = torch.cat((
+        probs_text,
+        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
+        probs_0[:, :, :layout[0]],
+        probs_1[:, :, :layout[0]]
+    ), dim=1)
+
+    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
+        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
+        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
+    
+    context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0))
+
+    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
+
+    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
+    
+    context_all_to_01 =torch.cat(
+        (
+            pad.expand(b*n_head, hn//n_head, layout[1]),
+            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
+            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
+        ), dim=-1).view(b, hn, -1).transpose(1, 2)
+    return context_all_to_text + context_all_to_01 
+
+
+if args.sparse_config.sparse_type == 'cuda_2d':
+            layout = args.sparse_config.layout
+            unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000.
+            unpad_indices[:, -1] = 9000.
+            starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1]
+            layout[0] = starts.max().item()
+            attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device)
+            for i in range(batch_size):
+                attention_mask[i, :, starts[i]:layout[1]] = 0
+            attention_mask[:, :layout[0]].tril_()
+            attention_mask = attention_mask.unsqueeze(1)
+elif args.sparse_config.sparse_type == 'standard':
+            attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
+            attention_mask.tril_()
+            # attention_mask = torch.zeros((seq_length, seq_length), device=data.device)
+            # h = w = 32
+            # k1=9
+            # layout = [10, 64, 64+h*w, 64+h*w*5]
+            # for i in range(layout[1]):
+            #     attention_mask[i, :i+1] = 1
+            # for i in range(layout[1], layout[2]):
+            #     x = (i - layout[1]) // w
+            #     y = (i - layout[1]) % w
+            #     lx = max(0, x - k1 // 2)
+            #     ly = max(0, y - k1 // 2)
+            #     rx = min(h-1, x + k1 // 2)
+            #     ry = min(w-1, y + k1 // 2)
+            #     attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
+            #     attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
+            # attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
\ No newline at end of file
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index 784153e8b65c70c7004525dfdf208fd69c9cf35a..87844669294b4455e8a1687d405ac83e977f9ca2 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -73,7 +73,8 @@ def get_model(args, sparse_config=None):
                       checkpoint_num_layers=args.checkpoint_num_layers,
                       parallel_output=True,
                       sparse_config=sparse_config if sparse_config is not None else args.sparse_config,
-                      sandwich_ln=args.sandwich_ln
+                      sandwich_ln=args.sandwich_ln,
+                      finetune=args.finetune
                       )
 
     if mpu.get_data_parallel_rank() == 0:
@@ -163,7 +164,7 @@ def get_learning_rate_scheduler(optimizer, args):
         num_iters = args.lr_decay_iters
     else:
         num_iters = args.train_iters
-    num_iters = max(1, num_iters)
+    num_iters = max(1, num_iters - args.restart_iter)
     init_step = -1
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(optimizer,
@@ -172,7 +173,9 @@ def get_learning_rate_scheduler(optimizer, args):
                                num_iters=num_iters,
                                decay_style=args.lr_decay_style,
                                last_iter=init_step,
-                               decay_ratio=args.lr_decay_ratio)
+                               decay_ratio=args.lr_decay_ratio,
+                               restart_iter=args.restart_iter
+                               )
 
     return lr_scheduler
 
@@ -182,6 +185,12 @@ def setup_model_and_optimizer(args):
 
     model = get_model(args)
 
+    if args.finetune:
+        model.requires_grad_(False)
+        for name, param in model.named_parameters():
+            if name.find('_plus') > 0:
+                param.requires_grad_(True)
+
     param_groups = get_optimizer_param_groups(model)
 
     if args.train_data is not None:
@@ -213,38 +222,18 @@ def get_masks_and_position_ids(data,
 
     # Attention mask (lower triangular).
     if attention_mask is None:
-        # single direction, [PAD]s are at the end of the seq, so doesn't matter.
         if args.sparse_config.sparse_type == 'cuda_2d':
-            layout = args.sparse_config.layout
-            unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000.
-            unpad_indices[:, -1] = 9000.
-            starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1]
-            layout[0] = starts.max().item()
-            attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device)
+            # single direction, [PAD]s are at the start of the seq.
+            assert loss_mask is not None
+            # loss_mask has n_pad(+1 CLS and [1:] then) zeros, so it is the same as attention_mask, reuse.
+            attention_mask = loss_mask[:, :args.layout[1]].unsqueeze(-2).expand(batch_size, args.layout[1], args.layout[1]).tril()
             for i in range(batch_size):
-                attention_mask[i, :, starts[i]:layout[1]] = 0
-            attention_mask[:, :layout[0]].tril_()
+                attention_mask[i].fill_diagonal_(1)
             attention_mask = attention_mask.unsqueeze(1)
         elif args.sparse_config.sparse_type == 'standard':
             attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
             attention_mask.tril_()
-            # attention_mask = torch.zeros((seq_length, seq_length), device=data.device)
-            # h = w = 32
-            # k1=9
-            # layout = [10, 64, 64+h*w, 64+h*w*5]
-            # for i in range(layout[1]):
-            #     attention_mask[i, :i+1] = 1
-            # for i in range(layout[1], layout[2]):
-            #     x = (i - layout[1]) // w
-            #     y = (i - layout[1]) % w
-            #     lx = max(0, x - k1 // 2)
-            #     ly = max(0, y - k1 // 2)
-            #     rx = min(h-1, x + k1 // 2)
-            #     ry = min(w-1, y + k1 // 2)
-            #     attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
-            #     attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
-            # attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
-        elif args.sparse_config.sparse_type == 'torch_1d':
+        else:
             raise NotImplementedError
 
     # Loss mask.
@@ -252,26 +241,19 @@ def get_masks_and_position_ids(data,
         loss_mask = torch.ones(data.size(), dtype=data.dtype, device=data.device)
 
     # Position ids.
-    if args is not None and args.finetune and args.max_position_embeddings < args.max_position_embeddings_finetune:
-        # for each sample, find [ROI2] and split
-        # ([ROI1] text... [BOI1] img... [EOI1] [ROI2]<pos_id==1089> ...)
-        start_token = get_tokenizer()['[ROI2]']
-        tmp = torch.nonzero(data == start_token, as_tuple=False)
-        start_token_poses = [100000] * batch_size
-        for x, y in tmp:
-            start_token_poses[x] = min(start_token_poses[x], y)
-        assert 100000 not in start_token_poses, 'Some samples do not have [ROI2]!'
+    if args.sparse_config.sparse_type == 'cuda_2d':
+        assert loss_mask is not None
+        layout = args.layout
+        assert seq_length == layout[-1]
+        n_pads = seq_length - loss_mask.sum(dim=-1).long()
         position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
                                     device=data.device)
         for i in range(batch_size):
-            sep = start_token_poses[i]
-            torch.arange(start=0, end=sep, out=position_ids[i, :sep], 
+            torch.arange(layout[1] - n_pads[i], out=position_ids[i, n_pads[i]:layout[1]], 
                 dtype=torch.long, device=data.device)
-            second_pos = 0 # reuse
-            torch.arange(start=second_pos, end=second_pos + seq_length - sep, 
-                out=position_ids[i, sep:], 
+            torch.arange(layout[2] - layout[1], 
+                out=position_ids[i, layout[1]:],
                 dtype=torch.long, device=data.device)
-        position_ids[position_ids >= args.max_position_embeddings] = args.max_position_embeddings - 1
     else:
         position_ids = torch.arange(seq_length, dtype=torch.long,
                                     device=data.device)
@@ -298,28 +280,9 @@ def get_batch(data_iterator, args, timers):
     tokens_ = data_b['text'].long()
     loss_mask = data_b['loss_mask'].float()
 
-    #FIXME change order
-    if args.sparse_config.sparse_type == 'cuda_2d':
-        assert args.sparse_config.layout[-1] == 64+32**2+64**2
-    tokens_new = tokens_.clone()
-    tokens_new[:, 64:64+32**2] = tokens_[:, 64+64**2:]
-    tokens_new[:, 64+32**2:] = tokens_[:, 64:64+64**2]
-    tokens_ = tokens_new
-    # only 32# 
-    # tokens_ = tokens_[:, :64+32**2]
-    # loss_mask = loss_mask[:, :64+32**2]
-    
-    if args.dataset_type == 'BinaryDataset':
-        labels = tokens_.contiguous()
-        loss_mask = loss_mask.contiguous()
-        tokenizer = get_tokenizer()
-        cls_token = torch.zeros(tokens_.shape[0], 1, dtype=tokens_.dtype, device=tokens_.device) + tokenizer['[CLS]']
-        tokens = torch.cat((cls_token, tokens_[:, :-1]), dim=1)
-        tokens[:, 64] = tokenizer['[BASE]']
-    else:
-        labels = tokens_[:, 1:].contiguous()
-        loss_mask = loss_mask[:, 1:].contiguous()
-        tokens = tokens_[:, :-1].contiguous()
+    labels = tokens_[:, 1:].contiguous()
+    loss_mask = loss_mask[:, 1:].contiguous()
+    tokens = tokens_[:, :-1].contiguous()
     
     attention_mask = None        
 
@@ -344,7 +307,6 @@ def forward_step(data_iterator, model, args, timers, mems):
     timers('batch generator').start()
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator, args, timers)
-
     timers('batch generator').stop()
 
     # split img & txt positions, [PAD] not included # TODO check enough
@@ -352,7 +314,6 @@ def forward_step(data_iterator, model, args, timers, mems):
     img_txt_sep = tokenizer.img_tokenizer.num_tokens
     img_indices_bool = (tokens.detach() < img_txt_sep) & (loss_mask > 0)
     txt_indices_bool = (~img_indices_bool) & (loss_mask > 0)
-
     # Forward model.
     logits, *mems = model(tokens, position_ids, attention_mask, *mems)
     losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
@@ -363,15 +324,14 @@ def forward_step(data_iterator, model, args, timers, mems):
 
     losses = losses.view(-1) * loss_mask
     loss = torch.sum(losses) / loss_mask.sum()
-
     # =====================   Log partial losses   ======================== #
     if args.sparse_config.sparse_type == 'cuda_2d':
         img_indices_bool2 = img_indices_bool.clone()
-        img_indices_bool2[:, :args.sparse_config.layout[2]] = False
+        img_indices_bool2[:, :args.sparse_config.layout[1]] = False
         img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1)
         torch.distributed.all_reduce(img_loss2.data)
         img_loss2.data = img_loss2.data / args.world_size
-        img_indices_bool[:, args.sparse_config.layout[2]:] = False
+        img_indices_bool[:, args.sparse_config.layout[1]:] = False
     else:
         img_loss2 = 0
     img_indices_bool = img_indices_bool.view(-1)
@@ -386,9 +346,6 @@ def forward_step(data_iterator, model, args, timers, mems):
     txt_loss.data = txt_loss.data / args.world_size
 
     # ===================== END OF BLOCK ======================= #
-    # import pdb;pdb.set_trace()
-    # with open('tmp_save.bin', 'wb') as fout:
-    #     torch.save(tokens, fout)
     return loss, mems, img_loss, txt_loss, img_loss2
 
 
@@ -471,7 +428,6 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         timers('backward').start()
         lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers)
         timers('backward').stop()
-
         # Update parameters.
         skipped_iter, complete = 0, False
         timers('optimizer').start()
@@ -674,7 +630,14 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
 def evaluate_and_print_results(prefix, data_iterator, model,
                                args, timers, verbose=False, step=None, summary_writer=None):
     """Helper function to evaluate and dump results on screen."""
+    # import line_profiler
+    # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
+    # profile.enable()
+    # torch.cuda.empty_cache()
     lm_loss = evaluate(data_iterator, model, args, timers, verbose)
+    # profile.disable()  # 停止分析
+    # import sys
+    # profile.print_stats(sys.stdout)
     lm_ppl = math.exp(min(20, lm_loss))
     report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
 
diff --git a/random_display.py b/random_display.py
index c59c51adaf56937df7c1bcef072759890006ca21..32096ddbf43b589665943d94bab91b701bec4501 100644
--- a/random_display.py
+++ b/random_display.py
@@ -6,12 +6,12 @@ import torch
 import random
 test_dir = 'tmp'
 # bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin'
-bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing005/quanjing005.bin.part_0.cogdata'
+bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing003/quanjing003.bin.part_0.cogdata'
 bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False)
 args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None)
 tokenizer = get_tokenizer(args)
 
-bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(16)]
+bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(32)]
 for x in bin_ds:
     end = x.tolist().index(-1)
     print(tokenizer.DecodeIds(x[:end])[0])
diff --git a/scripts/cuda_2d_text2image.sh b/scripts/cuda_2d_text2image.sh
index 46783ccda9f37af232fd7f46de97c96463d5fccf..3f03974a9f479821f415107e4cefb84efae97ed2 100755
--- a/scripts/cuda_2d_text2image.sh
+++ b/scripts/cuda_2d_text2image.sh
@@ -1,16 +1,16 @@
 #!/bin/bash
 
-CHECKPOINT_PATH=data/checkpoints/cogview-fixgrad-small08-25-09-38
+CHECKPOINT_PATH=data/checkpoints/cogview-long
 # CHECKPOINT_PATH=data/checkpoints/cogview-compare
-NLAYERS=16
-NHIDDEN=1024
-NATT=16
+NLAYERS=48
+NHIDDEN=2560
+NATT=40
 MAXSEQLEN=5184
 MASTER_PORT=$(shuf -n 1 -i 10000-65535)
 MPSIZE=1
 
 #SAMPLING ARGS
-TEMP=1.05
+TEMP=1.03
 #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
 TOPK=100
 TOPP=0
@@ -25,22 +25,24 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
        --hidden-size $NHIDDEN \
        --load $CHECKPOINT_PATH \
        --num-attention-heads $NATT \
-       --max-position-embeddings 5184 \
+       --max-position-embeddings 1089 \
        --fp16 \
        --temperature $TEMP \
        --top_k $TOPK \
        --top_p $TOPP \
        --sandwich-ln \
        --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
-       --sparse-type standard \
+       --sparse-type cuda_2d \
        --max-position-embeddings-finetune $MAXSEQLEN \
        --generation-task "cuda-2d generation" \
        --input-source ./input.txt \
-       --output-path samples_text2image \
-       --batch-size 2 \
+       --output-path samples_cuda_2d2 \
+       --batch-size 3 \
        --max-inference-batch-size 4 \
        --device 0 \
-       --sparse-type standard \
+       --finetune \
+       --no-load-optim \
+       --sparse-type cuda_2d \
        $@
 
 
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index a03cf273fcb587611f27c318d469a96d62af2366..a1245d3f2bd70fac8f5f1809ac6c7cdf79f0f68d 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -2,7 +2,7 @@
 
 # Change for multinode config
 
-NUM_WORKERS=10
+NUM_WORKERS=19
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
 
@@ -12,23 +12,22 @@ main_dir=$(dirname $script_dir)
 
 # OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
 OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
-HOST_FILE_PATH="hostfile2"
+HOST_FILE_PATH="hostfile"
 # OPTIONS_NCCL=""
 # HOST_FILE_PATH="hostfile_single"
 
 small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata"
 full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin"
 
-config_json="$script_dir/ds_config.json"
+config_json="$script_dir/ds_config_zero.json"
 gpt_options=" \
-       --experiment-name cogview-fixgrad-small-test \
+       --experiment-name cogview-base-continue-long \
        --img-tokenizer-num-tokens 8192 \
-       --dataset-type BinaryDataset \
+       --dataset-type CompactBinaryDataset \
        --model-parallel-size ${MP_SIZE} \
-       --num-layers 16 \
-       --hidden-size 1024 \
-       --num-attention-heads 16 \
-       --save $main_dir/data/checkpoints \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
        --train-iters 300000 \
        --resume-dataloader \
        --train-data ${full_data} \
@@ -38,24 +37,35 @@ gpt_options=" \
        --warmup .1 \
        --checkpoint-activations \
        --deepspeed-activation-checkpointing \
-       --max-position-embeddings 5184 \
+       --max-position-embeddings 1089 \
        --max-memory-length 0 \
        --sandwich-ln \
-       --txt-loss-scale 10 \
+       --txt-loss-scale 0.1 \
        --sparse-type cuda_2d \
        --fp16 \
        --save-interval 2000 \
-       --load data/checkpoints/cogview-compare
+       --no-load-optim \
+       --no-save-optim \
+       --eval-interval 1000 \
+       --save /root/checkpoints \
+       --fast-load \
+       --load data/checkpoints/cogview-continue \
+       --finetune 
 "
-       #        
+          
+#        --finetune
+       # --save $main_dir/data/checkpoints \
+       #         --restart-iter 199000 
+      
+
 
 
 
 gpt_options="${gpt_options}
-               --deepspeed \
-               --deepspeed_config ${config_json} \
+       --deepspeed \
+       --deepspeed_config ${config_json} \
 "
-
+              
 
 run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}"
 echo ${run_cmd}
diff --git a/scripts/testnan.sh b/scripts/testnan.sh
index a263aa1747a80ac99eb8efa1f57584cff67826b8..2095c0cbad61617442c22799f1938d63fe3cbec2 100755
--- a/scripts/testnan.sh
+++ b/scripts/testnan.sh
@@ -21,11 +21,11 @@ config_json="$script_dir/ds_config.json"
 gpt_options=" \
        --experiment-name cogview-testlocal \
        --img-tokenizer-num-tokens 8192 \
-       --dataset-type BinaryDataset \
+       --dataset-type CompactBinaryDataset \
        --model-parallel-size ${MP_SIZE} \
-       --num-layers 16 \
-       --hidden-size 1024 \
-       --num-attention-heads 16 \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
        --save $main_dir/data/checkpoints \
        --train-iters 100000 \
        --resume-dataloader \
@@ -36,15 +36,19 @@ gpt_options=" \
        --warmup .1 \
        --checkpoint-activations \
        --deepspeed-activation-checkpointing \
-       --max-position-embeddings 5184 \
+       --max-position-embeddings 1089 \
        --max-memory-length 0 \
-       --txt-loss-scale 2 \
+       --txt-loss-scale 1 \
        --sandwich-ln \
-       --sparse-type cuda_2d \
+       --sparse-type standard \
        --save-interval 2500 \
-       --load data/checkpoints/cogview-fixgrad-small08-25-09-38
+       --fp16 \
+       --eval-iters 1000 \
+       --load pretrained/cogview/cogview-base
 "
-       # --fp16 \
+       # 
+              # --load data/checkpoints/cogview-fixgrad-small08-25-09-38
+
 
 gpt_options="${gpt_options}
 
diff --git a/scripts/text2image.sh b/scripts/text2image.sh
index 38509fbd8c2fd8ad07e6d59249a827bc9a0b4e8e..fdb3bf3a17ddac30b3954a50d9f219691ed13f12 100755
--- a/scripts/text2image.sh
+++ b/scripts/text2image.sh
@@ -6,7 +6,8 @@
 # NHIDDEN=1024
 # NATT=16
 
-CHECKPOINT_PATH=pretrained/cogview/cogview-base
+CHECKPOINT_PATH=data/checkpoints/cogview-continue
+# CHECKPOINT_PATH=pretrained/cogview/cogview-base
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -42,8 +43,8 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
        --generation-task text2image \
        --input-source ./input.txt \
        --output-path samples_text2image \
-       --batch-size 4 \
-       --max-inference-batch-size 4 \
+       --batch-size 8 \
+       --max-inference-batch-size 8 \
        --device 0 \
        $@
 
diff --git a/test_sparse_attention.py b/test_sparse_attention.py
index abac1ed99ef230c2054eb997fe815c5cd119c9e0..9716123c45b87b6078201b45937baad70ed7d32a 100644
--- a/test_sparse_attention.py
+++ b/test_sparse_attention.py
@@ -4,8 +4,7 @@ from tqdm import tqdm
 
 import torch
 import numpy as np
-from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d, sparse_attention_2dfull
-
+from mpu.sparse_transformer import standard_attention, sparse_attention_2d_light
 def test_sparse_attention_1d():       
     s, w, times = 4096 + 128, 128, 2
     num_pivot = 768
@@ -80,7 +79,7 @@ def test_sparse_attention_2d():
     device = 'cuda'
     b, n_head, hn = 2, 16, 1024
     h = w = 32
-    layout = [10, 64, 64+h*w, 64+h*w*5]
+    layout = [64, 64+h*w, 64+h*w*5]
     k1 = 9
     k2 = 7
     k1h = k1*2-1
@@ -94,9 +93,6 @@ def test_sparse_attention_2d():
     m = mask[0]
     for i in range(layout[1]):
         m[i, :i+1] = 1
-    m[layout[1]:, :layout[0]] = 1
-    for i in tqdm(range(layout[1], layout[2])):
-        m[i, layout[1]:i+1] = 1
     # for i in tqdm(range(layout[1], layout[2])):
     #     x = (i - layout[1]) // w
     #     y = (i - layout[1]) % w
@@ -106,15 +102,15 @@ def test_sparse_attention_2d():
     #     ry = min(w-1, y + k1 // 2)
     #     m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
     #     m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
-    for i in tqdm(range(layout[2], layout[3])):
-        x = (i - layout[2]) // (2*w)
-        y = (i - layout[2]) % (2*w)
+    for i in tqdm(range(layout[1], layout[2])):
+        x = (i - layout[1]) // (2*w)
+        y = (i - layout[1]) % (2*w)
         lx = max(0, x - k1h // 2)
         ly = max(0, y - k1 // 2)
         rx = min(2*h-1, x + k1h // 2)
         ry = min(2*w-1, y + k1 // 2)
-        m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1
-        m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1
+        m[i, layout[1]:layout[2]].view(h*2, w*2)[lx:x, ly:ry+1] = 1
+        m[i, layout[1]:layout[2]].view(h*2, w*2)[x, ly:y+1] = 1
 
         x = x // 2
         y = y // 2
@@ -122,7 +118,7 @@ def test_sparse_attention_2d():
         ly = max(0, y - k2 // 2)
         rx = min(h-1, x + k2 // 2)
         ry = min(w-1, y + k2 // 2)
-        m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1
+        m[i, layout[0]:layout[1]].view(h, w)[lx:rx+1, ly:ry+1] = 1
     
     mask[1:] = mask[0]
     # mask[1][layout[1]:, layout[0]-1] = 0
@@ -133,15 +129,18 @@ def test_sparse_attention_2d():
     torch.cuda.synchronize()
     t0 = time.time()
     qkv_tmp = qkv.view(3, b, layout[-1], n_head, hn//n_head).permute(0, 1, 3, 2, 4).contiguous()
-    r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[3], hn)
+    r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[2], hn)
     
     torch.cuda.synchronize()
     t1 = time.time()
-    r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
+    # r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
+    qkv20, qkv21 = qkv2[:, :, :layout[1]], qkv2[:, :, layout[1]:]
+    r20, r21 = sparse_attention_2d_light(*qkv20, *qkv21, mask[...,:layout[1],:layout[1]].unsqueeze(1), n_head, layout[0],kernel_size=k1, kernel_size2=k2)
+    r2 = torch.cat((r20, r21), dim=1)
     torch.cuda.synchronize()
     t2 = time.time()
     print('times: standard ', t1-t0, ' sparse ', t2-t1)
-    print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max())
+    print(( (r1[:,:layout[1]]-r2[:,:layout[1]]).abs() / (r1[:,:layout[1]].abs()+r2[:,:layout[1]].abs())).max())
     print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max())
     qkv.retain_grad()
     l2 = r2[:,layout[1]:].sum()
@@ -153,7 +152,7 @@ def test_sparse_attention_2d():
     g1 = qkv.grad
     g2 = qkv2.grad
     print( (g1-g2).abs().max())
-    print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-5)).max())
+    print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-3)).max())
 
     import pdb;pdb.set_trace()
     
diff --git a/utils.py b/utils.py
index c5079ce515c1559616b3e6d24dbf365791b82cb9..0d4c51d64c77299f56b9e6861b376a4a37808518 100755
--- a/utils.py
+++ b/utils.py
@@ -248,8 +248,34 @@ def save_ds_checkpoint(iteration, model, lr_scheduler, args):
         sd['torch_rng_state'] = torch.get_rng_state()
         sd['cuda_rng_state'] = torch.cuda.get_rng_state()
         sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
+    if args.no_save_optim:
+        save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
+    else:
+        model.save_checkpoint(args.save, str(iteration), client_state=sd)
+
+def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
+    
+    os.makedirs(save_dir, exist_ok=True)
+
+    if tag is None:
+        tag = f"global_step{model.global_steps}"
+
+    # Ensure tag is a string
+    tag = str(tag)
+
+    # Ensure checkpoint tag is consistent across ranks
+    model._checkpoint_tag_validation(tag)
+
+    if model.save_non_zero_checkpoint:
+        model._create_checkpoint_file(save_dir, tag, False)
+        model._save_checkpoint(save_dir, tag, client_state=client_state)
+
+    # Save latest checkpoint tag
+    if save_latest:
+        with open(os.path.join(save_dir, 'latest'), 'w') as fd:
+            fd.write(tag)
 
-    model.save_checkpoint(args.save, str(iteration), client_state=sd)
+    return True
 
 
 def get_checkpoint_iteration(args):
@@ -296,8 +322,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=
 
     if args.deepspeed:
         
-        checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim)
-        if "client_lr_scheduler" in sd:
+        checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim, load_module_strict=not args.finetune)
+        if args.finetune:
+            model.module.module.init_plus_from_old()
+        if (args.finetune or args.no_load_optim) and model.zero_optimization():
+            model.optimizer.refresh_fp32_params()
+        if "client_lr_scheduler" in sd and not args.finetune:
             lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
             print_rank_0("Load lr scheduler state")
         if checkpoint_name is None: