diff --git a/arguments.py b/arguments.py
index 067a9a7cc579a49bc0374963f46e836b1636076c..b06c987ea51f430abef05a49caa2df6734fa8a98 100755
--- a/arguments.py
+++ b/arguments.py
@@ -218,7 +218,8 @@ def add_text_generate_args(parser):
                                 'super-resolution',
                                 'low-level super-resolution',
                                 'post-selection',
-                                'raw'
+                                'raw',
+                                'cuda-2d generation'
                                 ],
                        help='what type of inference task to use')
     group.add_argument('--input-source', type=str, default='interactive',
@@ -300,7 +301,7 @@ def add_sparse_args(parser):
     group.add_argument("--key-window-times", type=int, default=6)
     group.add_argument("--num-pivot", type=int, default=768)
     # for cuda_2d
-    group.add_argument("--kernel-size", type=int, default=9)
+    group.add_argument("--kernel-size", type=int, default=11)
     group.add_argument("--kernel-size2", type=int, default=7)
     group.add_argument("--layout", type=str, default='0,64,1088,5184')
     return parser
@@ -309,7 +310,7 @@ def make_sparse_config(args):
     sparse_config = argparse.Namespace(sparse_type=args.sparse_type)
     if args.sparse_type == 'standard':
         pass
-    elif args.sparse_type == 'cuda_2d':
+    if args.sparse_type == 'cuda_2d' or args.generation_task == 'cuda-2d generation':
         sparse_config.kernel_size = args.kernel_size
         sparse_config.kernel_size2 = args.kernel_size2
         sparse_config.layout = [int(x) for x in args.layout.split(',')]
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index 4e4cb0b5ca4d5ba382777eba71a71fb6ab7d248d..08d6a59d9176f778819c5bf70703254ea5abafe0 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -37,12 +37,13 @@ def make_data_loader(dataset, batch_size, num_iters, args):
     drop_last = distributed
     # the GPUs in the same model parallel group receive the same data
     if distributed:
+        gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
         batch_sampler = DistributedBatchSampler(sampler,
                                                                     batch_size,
                                                                     drop_last,
                                                                     rank,
                                                                     world_size,
-                                                                    gradient_accumulation_steps=args.gradient_accumulation_steps)
+                                                                    gradient_accumulation_steps=gradient_accumulation_steps)
     else:
         batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                         batch_size,
diff --git a/draw_diff.py b/draw_diff.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dd12a980c6e3b2ea859726d9ad1202331ed22b4
--- /dev/null
+++ b/draw_diff.py
@@ -0,0 +1,31 @@
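+# Visualize where two per-token loss dumps differ: load the losses in bao.txt
+# and bao2.txt (one token per line, 32x32 grid), and draw green boxes on the
+# corresponding 16x16 patches of bao.jpeg wherever the difference exceeds 2.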
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+def loadbao(name):
+    # read the second column of each line and return its absolute value
+    ret = []
+    with open(name, 'r') as fin:
+        for line in fin:
+            a, b = line.split()
+            ret.append(abs(float(b)))
+    return ret
+
+def sq(img, x, y, lx, ly):
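+    # draw a green rectangle outline of size lx x ly with top-left corner (x, y)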
+    assert len(img.shape) == 3
+    img[:,x:x+lx,y] = torch.tensor([0,1,0]).unsqueeze(-1)
+    img[:,x:x+lx,y+ly] = torch.tensor([0,1,0]).unsqueeze(-1)
+    img[:,x,y:y+ly] = torch.tensor([0,1,0]).unsqueeze(-1)
+    img[:,x+lx,y:y+ly] = torch.tensor([0,1,0]).unsqueeze(-1)
+
+transform = transforms.Compose([
+                transforms.Resize(512),
+                transforms.CenterCrop(512),
+            ])
+img = torchvision.io.read_image('bao.jpeg')
+img = transform(img) / 255.
+a = np.array(loadbao('bao.txt'))
+b = np.array(loadbao('bao2.txt'))
+for t in (a-b>2).nonzero()[0]:
+    x,y = t // 32, t % 32
+    sq(img, x*16, y*16, 15, 15)
+torchvision.utils.save_image(img, 'example_bao.jpg')
diff --git a/env/baai_setup_connection.py b/env/baai_setup_connection.py
new file mode 100644
index 0000000000000000000000000000000000000000..420e90c136e0831f7ea8028326aebb94d9e0266c
--- /dev/null
+++ b/env/baai_setup_connection.py
@@ -0,0 +1,17 @@
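+# Read the node IPs from /home/hostfile.json, record them in env/input.txt, and
+# reuse setup_connection.main to generate ~/.ssh/config and the hostfile.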
+import json
+import os
+import sys
+
+with open('/home/hostfile.json', 'r') as fin:
+    t = json.load(fin)
+input_txt_path = os.path.join(os.path.dirname(__file__), 'input.txt')
+with open(input_txt_path, 'w') as fout:
+    ip_list = []
+    for x in t:
+        fout.write(x['ip'])
+        fout.write(' ')
+        ip_list.append(x['ip'])
+sys.path.append(os.path.dirname(__file__))
+from setup_connection import main
+main(ip_list, 22)
+
diff --git a/env/setup_connection.py b/env/setup_connection.py
index 71a41340acb5738fd4e3dea11b9af01d59528a43..a17cee03841a0e56345187b31d1b8fef6829e631 100644
--- a/env/setup_connection.py
+++ b/env/setup_connection.py
@@ -13,19 +13,20 @@ import math
 import random
 import base64
 
-if __name__ == "__main__":
+def main(ip_list, port=2222):
     ssh_config = ''
-    line_format = 'Host node{}\n\tUser root\n\tPort 2222\n\tHostname {}\n'
-    for i, ip in enumerate(sys.argv[1:]):
-        ssh_config += line_format.format(i, ip)
+    line_format = 'Host node{}\n\tUser root\n\tPort {}\n\tHostname {}\n'
+    for i, ip in enumerate(ip_list):
+        ssh_config += line_format.format(i, port, ip)
     
     ret = os.system(f'echo \"{ssh_config}\" > ~/.ssh/config && chmod 600 ~/.ssh/config')
     assert ret == 0
 
     hostfile_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'hostfile')
     with open(hostfile_path, 'w') as fout:
-        for i, ip in enumerate(sys.argv[1:]):
+        for i, ip in enumerate(ip_list):
             fout.write(f'node{i} slots=8\n')
     print(f'Successfully generated hostfile \'{hostfile_path}\'!')
 
-
+if __name__ == "__main__":
+    main(sys.argv[1:])
\ No newline at end of file
diff --git a/generate_samples.py b/generate_samples.py
index b31630b3208e10a78061b50085bc7ecf37927dca..aff3e469b9f458bb93e63166994a663f31532933 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -41,14 +41,13 @@ from pretrain_gpt2 import get_model
 import math
 from copy import deepcopy
 from tqdm import tqdm
-from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score
+from generation import get_batch, filling_sequence, add_interlacing_beam_marks, magnify, inverse_prompt_score, filling_sequence_local
 from torchvision.utils import save_image
 import torch.distributed as dist
 
 
 def setup_model(args):
     """Setup model and optimizer."""
-
     model = get_model(args)
 
     if args.load is not None:
@@ -163,12 +162,13 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
         assert num < mbz or num % mbz == 0
         output_tokens_list = []
         for tim in range(max(num // mbz, 1)):
-            import line_profiler
-            from mpu.sparse_transformer import standard_attention
+            # import line_profiler
+            # from mpu.sparse_transformer import standard_attention
             # profile = line_profiler.LineProfiler(model.module.forward)
             # profile = line_profiler.LineProfiler(standard_attention)
             # profile.enable()
-            output_tokens_list.append(filling_sequence(model, seq.clone(), args))
+            fill_fn = filling_sequence_local if args.generation_task == 'cuda-2d generation' else filling_sequence
+            output_tokens_list.append(fill_fn(model, seq.clone(), args))
             # torch.cuda.empty_cache()
             # profile.disable()  # stop profiling
             # import sys
@@ -182,8 +182,11 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
         for seq in output_tokens_list:
             decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
             for i in range(len(decoded_imgs)):
-                if decoded_imgs[i].shape[-1] == 128:
-                    decoded_imgs[i] = torch.nn.functional.interpolate(decoded_imgs[i], size=(256, 256))
+                if decoded_imgs[i].shape[-1] < 512:
+                    decoded_imgs[i] = torch.nn.functional.interpolate(decoded_imgs[i], size=(512, 512))
+                # decoded_imgs[i].view(3, 32, 16, 32, 16)[:, :, :4, :, :4] = 0
+                # decoded_imgs[i].view(3, 32, 16, 32, 16)[0, :, :4, :, :4] = 1
+                # decoded_imgs[i].view(3, 32, 16, 32, 16)[1, :12, :4, :16, :4] = 1
             if args.debug:
                 imgs.extend(decoded_imgs)
             else:
@@ -209,7 +212,7 @@ def generate_images_once(model, args, raw_text, seq=None, num=8, query_template=
 
 def generate_images_continually(model, args):
     if args.generation_task == 'text2image':
-        query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024'
+        query_template = '[ROI1] {} [BASE] [BOI1] [Image200]bao.jpeg'
     elif args.generation_task == 'image2text':
         query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] [MASK]*20'
     elif args.generation_task == 'low-level super-resolution':
@@ -218,6 +221,8 @@ def generate_images_continually(model, args):
         query_template = '[ROI1] {} [BASE] [BOI1] [Image]{}'
     elif args.generation_task == 'post-selection':
         query_template = '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] {}'
+    elif args.generation_task == 'cuda-2d generation':
+        query_template = '[CLS] {} [BASE] [Image200]bao.jpeg [MASK]*4096'
     else:
         raise NotImplementedError
     for raw_text, seq, output_path in get_context(args, query_template):
diff --git a/generation/__init__.py b/generation/__init__.py
index 8620c9c18c17f4c46a59f42abf6adc16785a9ab2..b73ab2313cef87d5306e653894797d211f4a2c36 100755
--- a/generation/__init__.py
+++ b/generation/__init__.py
@@ -1,2 +1,3 @@
 from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
-from .magnify import magnify
\ No newline at end of file
+from .magnify import magnify
+from .local_sampling import filling_sequence_local
\ No newline at end of file
diff --git a/generation/local_sampling.py b/generation/local_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fba71aceb1c710076302889928f49f5bbafb1d1
--- /dev/null
+++ b/generation/local_sampling.py
@@ -0,0 +1,147 @@
+from .sampling import *
+import math
+import sys
+from copy import deepcopy
+
+def make_local_mask(sparse_config):
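+    # Build a boolean visibility mask of shape [layout[-1]+1, layout[-1], 1].
+    # Row i marks which cached positions token i may attend to: causal
+    # attention inside the text segment, the text prefix (up to layout[0]) for
+    # all image tokens, a causal k1-wide window on the token's own grid, and,
+    # for second-level tokens, an extra k2 x k2 window on the first-level grid.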
+    layout = sparse_config.layout
+    k1, k2 = sparse_config.kernel_size, sparse_config.kernel_size2
+    k1h = k1*2-1
+    h = w = int(math.sqrt(layout[2] - layout[1]) + 1e-3)
+    m = torch.zeros(layout[-1]+1, layout[-1], dtype=torch.bool, device='cuda')
+    for i in range(layout[1]):
+        m[i, :i] = True
+    m[layout[1]:, :layout[0]] = True
+    for i in tqdm(range(layout[1], layout[2])):
+        # m[i, layout[1]:i] = True
+        x = (i - layout[1]) // w
+        y = (i - layout[1]) % w
+        lx = max(0, x - k1h // 2)
+        ly = max(0, y - k1 // 2)
+        rx = min(h-1, x + k1h // 2)
+        ry = min(w-1, y + k1 // 2)
+        m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = True
+        m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = True
+        m[i, i] = False
+    for i in tqdm(range(layout[2], layout[3])):
+        x = (i - layout[2]) // (2*w)
+        y = (i - layout[2]) % (2*w)
+        lx = max(0, x - k1h // 2)
+        ly = max(0, y - k1 // 2)
+        rx = min(2*h-1, x + k1h // 2)
+        ry = min(2*w-1, y + k1 // 2)
+        m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = True
+        m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = True
+        x = x // 2
+        y = y // 2
+        lx = max(0, x - k2 // 2)
+        ly = max(0, y - k2 // 2)
+        rx = min(h-1, x + k2 // 2)
+        ry = min(w-1, y + k2 // 2)
+        m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = True
+        m[i, i] = False
+    return m.unsqueeze(-1)
+
+def update_mems_local(hiddens, mems, start, end, mem_bag, mask):
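+    # Write the freshly computed per-layer (k, v) states for positions
+    # [start, end) into the full-length cache mem_bag, then gather only the
+    # entries visible from position `end` according to the local mask.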
+    if isinstance(hiddens, list):
+        hiddens = torch.stack(hiddens)
+        mem_bag[:, :, start:end] = hiddens.to('cuda')
+    # first level
+    del mems
+    mems = mem_bag.masked_select(mask[end]).view(*mem_bag.shape[:2], -1, mem_bag.shape[3]).to(hiddens.device)
+    return mems
+
+
+def filling_sequence_local(
+        model, 
+        seq, 
+        args, 
+        mems=None, 
+        invalid_slices=[], 
+        **kwargs):
+    '''
+        seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1]
+        context_length: first non(-1)s
+    '''
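+    # Overall flow: pad the text prefix up to layout[1] with [POS8] tokens,
+    # generate the remaining positions autoregressively while keeping only the
+    # local-mask-filtered memories (see update_mems_local), then splice the
+    # first-level image tokens, an [EOI1] token, and the second-level tokens
+    # into the returned sequence.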
+    loss_sum, loss_n = 0, 0.0001
+    tokenizer = get_tokenizer()
+    device = seq.device
+    assert len(seq.shape) == 1
+    assert args.sparse_config.sparse_type == 'standard'
+    sparse_config = deepcopy(args.sparse_config)
+    # with open('tmp_save.bin', 'rb') as fout:
+    #     seq = torch.load(fout)[1]
+    #     seq = torch.cat((seq, torch.tensor([-1], device=seq.device)))
+    # import pdb; pdb.set_trace()
+    # sparse_config.layout[0] = seq.tolist().index(-1)
+    sparse_config.layout[0] = seq.tolist().index(tokenizer['[BASE]'])
+    n_pad = sparse_config.layout[1] - sparse_config.layout[0]
+    assert n_pad > 0 # TODO long trunc
+    seq = torch.cat((seq[:sparse_config.layout[0]], torch.tensor([tokenizer['[POS8]']]* n_pad, device=seq.device, dtype=seq.dtype), seq[sparse_config.layout[0]:]))
+    out_seq_length = len(seq)
+    # building the initial tokens, attention_mask, and position_ids
+    context_length = sparse_config.layout[1] + 1
+
+    tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
+    tokens = tokens.expand(-min(seq), *tokens.shape[1:])
+
+    counter = context_length - 1 # == len(tokens) - 1
+    index = 0 # len(mems)
+    if mems is None:
+        mems = []
+    mem_bag = torch.zeros(args.num_layers, tokens.shape[0], out_seq_length-1, 2*args.hidden_size, device='cuda')
+    local_mask = make_local_mask(sparse_config)
+    
+    while counter < (out_seq_length - 1):
+        if counter % 100 == 0:
+            print(counter, loss_sum / loss_n, file=sys.stderr)
+        # Now, we want to generate seq[counter + 1]
+        # token[:, index: counter+1] are just added.
+        # import pdb;pdb.set_trace()
+
+        if index == 0: # first
+            logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
+            mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask)
+            index = counter
+        else:
+            assert tokens.shape[1] == counter + 1 
+            position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0)
+            logits, *qkv = model(tokens[:, index: ], 
+                position_ids,
+                0, # rebuild in transformers (sep version)
+                *mems
+                )
+            mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask)
+            index = counter
+        counter += 1
+        index += 1
+
+        if seq[counter] >= 0: # provided
+            tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1)
+            loss_n += 1
+            loss_this = F.log_softmax(logits, dim=-1)[:, -1, seq[counter]].mean()
+            print(counter-64, loss_this.item())
+            loss_sum -= loss_this
+            continue
+
+        logits = logits[:, -1] # [batch size, vocab size]
+        temp = args.temperature
+        logits /= temp
+        for invalid_slice in invalid_slices: # forbid generating other tokens
+            logits[..., invalid_slice] = -float('Inf')
+        logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
+        log_probs = F.softmax(logits, dim=-1)
+
+        # expand beams
+        prev = torch.multinomial(log_probs, num_samples=1)
+        tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1)
+
+    output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous()
+    output_tokens_list = torch.cat(
+        (
+            output_tokens_list[:, :sparse_config.layout[0]],
+            output_tokens_list[:, sparse_config.layout[1]+1:sparse_config.layout[2]+1],
+            torch.tensor([[tokenizer['[EOI1]']]], dtype=tokens.dtype, device=tokens.device).expand(output_tokens_list.shape[0], 1),
+            output_tokens_list[:, sparse_config.layout[2]+1:]
+        ), dim=1)
+    return output_tokens_list
\ No newline at end of file
diff --git a/generation/sampling.py b/generation/sampling.py
index 8cf03263e78ce4d80ddc2dfa5832eb22976b4997..4c3acd2501e50f74589884963a87a76840d019dd 100755
--- a/generation/sampling.py
+++ b/generation/sampling.py
@@ -20,7 +20,6 @@ import torch.nn.functional as F
 from pretrain_gpt2 import get_masks_and_position_ids
 from data_utils import get_tokenizer
 
-
 def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     # This function has been mostly taken from huggingface conversational ai code at
     # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
@@ -61,6 +60,19 @@ def get_batch(context_tokens, device, args):
         tokens, args=args)
     return tokens, attention_mask, position_ids
 
+def update_mems(hiddens, mems, max_memory_length=10000):
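+    # Keep at most max_memory_length cached positions per layer by
+    # concatenating the new hidden states onto the existing memories.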
+    memory_length = mems[0].size(1) if mems else 0
+    query_length = hiddens[0].size(1)
+    new_memory_length = min(max_memory_length, memory_length + query_length)
+    new_mems = []
+    with torch.no_grad():
+        for i in range(len(hiddens)):
+            if new_memory_length <= query_length:
+                new_mems.append(hiddens[i][:, -new_memory_length:])
+            else:
+                new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
+    return new_mems
+
 def filling_sequence(
         model, 
         seq, 
@@ -115,8 +127,15 @@ def filling_sequence(
 
         if index == 0: # first 
             position_ids[position_ids > offset] -= offset
-            logits, *mems = model(tokens, position_ids, attention_mask, *mems)
+            logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
+            mems = update_mems(qkv, mems)
+
+            tmp = -F.log_softmax(logits, dim=-1)
+            tmp = tmp[0,:-1].gather(dim=-1,index=tokens[0,1:].unsqueeze(-1))[4:,0]
+            for i in range(1,len(tmp)):
+                print(i, tmp[i].item())
             index = counter
+            print(tmp[1:].mean(), file=sys.stderr)
         elif seq[counter + 1] >= 0: # provided
             if seq[counter + 1] == tokenizer['[ROI2]']:
                 offset = counter + 1
@@ -131,15 +150,18 @@ def filling_sequence(
             position_ids[position_ids > offset] -= offset
             # TODO each time, the input fed in cannot be too long (window size), or it will have a discrepancy from sparse training, but this is not very important. 
             tokens, mems, score = shrink_beams(tokens, mems, -seq[counter + 1], score)
-            logits, *mems = model(tokens[:, index: ], 
+            logits, *qkv = model(tokens[:, index: ], 
                 position_ids,
                 0, # rebuild in transformers (sep version)
                 *mems)
+            mems = update_mems(qkv, mems)
+
             index = counter
         nb = -seq[counter + 1]
         counter += 1
         index += 1
 
+
         logits = logits[:, -1] # [batch size, vocab size]
 
         temp = args.temperature
@@ -180,7 +202,7 @@ def shrink_beams(tokens, mems, nb, score):
     new_mems = [mem[max_idx: max_idx + 1] for mem in mems]
     return tokens, new_mems, score
 
-def add_interlacing_beam_marks(seq, nb=12, period=3000):
+def add_interlacing_beam_marks(seq, nb=12, period=30000):
     assert isinstance(seq, list) or len(seq.shape) == 1
     blk_cnt = 0
     for i in range(len(seq)):
@@ -203,7 +225,9 @@ def inverse_prompt_score(model, seq, args):
     assert tokenizer['[ROI1]'] == seq[0][botext]
 
     tokens, attention_mask, position_ids = get_batch(seq, device, args)
-    logits, *mems = model(tokens, position_ids, attention_mask)
+    logits, *qkv = model(tokens, position_ids, attention_mask)
+    mems = update_mems(qkv, mems)
+
     logits[..., :tokenizer.img_tokenizer.num_tokens] = -float('Inf')
     log_probs = torch.log(F.softmax(logits, dim=-1))
 
diff --git a/mpu/local_attention_function.py b/mpu/local_attention_function.py
index 5ba073ced5728ff51fcb03115f76afcbaae42451..1a5af91d91310bb2d84112ad5349cf09c2fa2950 100644
--- a/mpu/local_attention_function.py
+++ b/mpu/local_attention_function.py
@@ -28,6 +28,7 @@ class similarFunction(Function):
         x_ori, x_loc = ctx.saved_tensors
         kH, kW = ctx.kHW
         casual_mask = ctx.casual_mask
+        grad_outputs = grad_outputs.contiguous()
         grad_ori = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, True, casual_mask)
         grad_loc = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, False, casual_mask)
 
@@ -50,6 +51,7 @@ class weightingFunction(Function):
         x_ori, x_weight = ctx.saved_tensors
         kH, kW = ctx.kHW
         casual_mask = ctx.casual_mask
+        grad_outputs = grad_outputs.contiguous()
         grad_ori = weighting_backward_ori(x_ori, x_weight, grad_outputs, kH, kW, casual_mask)
         grad_weight = weighting_backward_weight(x_ori, x_weight, grad_outputs, kH, kW, casual_mask)
 
@@ -139,11 +141,3 @@ class TorchLocalAttention(nn.Module):
 
         return out
     
-    
-if __name__ == '__main__':
-    b, c, h, w = 8, 3, 32, 32
-    kH, kW = 5, 5
-    x = torch.rand(b, c, h, w).cuda()
-    m = LocalAttention(c, c, kH, kW)
-    m.cuda()
-    y = m(x)
\ No newline at end of file
diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py
index 33b5d42758385c222e1bd30bf53279de21e73eb2..7adc0dfbb13e25aba14aa64fa1f0285a457d2bbf 100755
--- a/mpu/sparse_transformer.py
+++ b/mpu/sparse_transformer.py
@@ -75,7 +75,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
     """
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, output_layer_init_method=None):
+                 init_method, output_layer_init_method=None,sparse_config=None):
         super(GPT2ParallelSelfAttention, self).__init__()
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
@@ -111,6 +111,11 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
             get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
             checkpoint = deepspeed.checkpointing.checkpoint
 
+        # self.offset_bias = torch.nn.Parameter(
+        #     torch.ones(num_attention_heads, sparse_config.kernel_size**2//2+1 +sparse_config.kernel_size2**2)) 
+
+        self.sparse_config = sparse_config
+
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
         size [b, np, s, hn].
@@ -122,20 +127,21 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
         return tensor.permute(0, 2, 1, 3)
 
 
-    def forward(self, hidden_states, mask, sparse_config, mem=None):
+    def forward(self, hidden_states, mask, mem=None):
+        # import pdb;pdb.set_trace()
+        sparse_config = self.sparse_config
         # hidden_states: [b, s, h]
         # ltor_mask: [1, 1, s, s]
 
         # Attention heads. [b, s, hp]
         query_length = hidden_states.size(1)
 
-        # if mem is None:
         mixed_raw_layer = self.query_key_value(hidden_states)
 
         (mixed_query_layer,
             mixed_key_layer,
             mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-        if mem is not None:
+        if mem is not None and len(mem) > 0:
             memk, memv = split_tensor_along_last_dim(mem, 2)
             mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
             mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
@@ -152,7 +158,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
             if sparse_config.sparse_type == 'standard':
                 context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
             else:
-                context_layer = sparse_attention(query_layer, key_layer, value_layer, sparse_config.pivot_idx, 
+                context_layer = sparse_attention_1d(query_layer, key_layer, value_layer, sparse_config.pivot_idx, 
                     mask, sparse_config.query_window, sparse_config.key_window_times, dropout_fn)
                 # inference: context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx)
             
@@ -163,12 +169,14 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
             context_layer = context_layer.view(*new_context_layer_shape)
             
         elif sparse_config.sparse_type == 'cuda_2d':
-            context_layer = sparse_attention_2d(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition,
-                 sparse_config.layout, mask, sparse_config.kernel_size, sparse_config.kernel_size2, attention_dropout=dropout_fn)
+            context_layer = sparse_attention_2dfull(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition,
+                 sparse_config.layout, mask, sparse_config.kernel_size,
+                  kernel_size2=sparse_config.kernel_size2,
+                  attention_dropout=dropout_fn
+                 )
 
         # Output. [b, s, h]
         output = self.dense(context_layer)
-        
         if self.training:
             output = self.output_dropout(output)
         
@@ -228,6 +236,7 @@ class GPT2ParallelMLP(torch.nn.Module):
 
     def forward(self, hidden_states):
         # [b, s, 4hp]
+
         intermediate_parallel = self.dense_h_to_4h(hidden_states)
         intermediate_parallel = gelu(intermediate_parallel)
 
@@ -292,7 +301,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
             attention_dropout_prob,
             output_dropout_prob,
             init_method,
-            output_layer_init_method=output_layer_init_method)
+            output_layer_init_method=output_layer_init_method,
+            sparse_config=sparse_config
+            )
 
         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(hidden_size,
@@ -320,7 +331,7 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, qkv = self.attention(layernorm_output1, ltor_mask, self.sparse_config, mem)
+        attention_output, qkv = self.attention(layernorm_output1, ltor_mask, mem)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -501,7 +512,12 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 x_, mask, mems_ = inputs[0], inputs[1], inputs[2:]
             
                 for i, layer in enumerate(layers_):
-                    mem_i_ = mems_[i] if mems_ else None
+                    if mems_:
+                        mem_i_ = mems_[i]  
+                    elif self.max_memory_length > 0:
+                        mem_i_ = []
+                    else:
+                        mem_i_ = None
                     x_, qkv = layer(x_, mask, mem=mem_i_)
                     if self.max_memory_length > 0:
                         mem_layers.append(qkv)
@@ -527,30 +543,25 @@ class GPT2ParallelTransformer(torch.nn.Module):
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask_saved]
 
-                mem_i = mems[i] if mems else None
+                if mems:
+                    mem_i = mems[i]  
+                elif self.max_memory_length > 0:
+                    mem_i = []
+                else:
+                    mem_i = None
                 hidden_states, qkv = layer(*args, mem=mem_i)
                 if self.max_memory_length > 0:
                     mem_layers.append(qkv) 
 
         # Final layer norm.
         output = self.final_layernorm(hidden_states)
-        if self.max_memory_length > 0: # TODO cache
-            mem_layers = self.update_mems(mem_layers, mems)
+        # if self.max_memory_length > 0: # TODO cache
+        #     if self.sparse_config.sparse_type != 'cuda_2d':
+        #         mem_layers = self.update_mems(mem_layers, mems)
+        #     else:
+        #         pass # handle update outside the model, because mems is not the full cached qkv.
 
         return (output, *mem_layers)
-
-    def update_mems(self, hiddens, mems):
-        memory_length = mems[0].size(1) if mems else 0
-        query_length = hiddens[0].size(1)
-        new_memory_length = min(self.max_memory_length, memory_length + query_length)
-        new_mems = []
-        with torch.no_grad():
-            for i in range(len(hiddens)):
-                if new_memory_length <= query_length:
-                    new_mems.append(hiddens[i][:, -new_memory_length:])
-                else:
-                    new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
-        return new_mems
         
 
 def _chunk(x, w, times):
@@ -600,7 +611,6 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte
     # [b, np, s, hn]
     
     context_layer = torch.matmul(attention_probs, value_layer)
-    
     return context_layer
 
 def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
@@ -682,10 +692,10 @@ def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=1
 
 def transpose_and_split(x, layout, n_head):
     x = x.transpose(1, 2)
-    x = x.reshape(x.shape[0] * n_head, x.shape[1] // n_head, x.shape[2])
+    x = x.reshape(x.shape[0]*n_head, x.shape[1] // n_head, x.shape[2])
     x_text = x[..., :layout[0]]
-    x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), -1).contiguous()
-    x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), -1).contiguous()
+    x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), sqrt(layout[2] - layout[1])).contiguous()
+    x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), sqrt(layout[3] - layout[2])).contiguous()
     return x, x_text, x0, x1
 
 def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
@@ -704,30 +714,35 @@ def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_s
     q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
     k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
     v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
-
     # import pdb; pdb.set_trace()
     # all to text
     scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
     scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
     # 0 to 0
-    scores_0_to_0 = f_similar(q0, k0, kernel_size, kernel_size, True)
+    scores_0_to_0 = f_similar(q0, k0, kernel_size*2-1, kernel_size, True)
     # 1 to 1
-    scores_1_to_1 = f_similar(q1, k1, kernel_size, kernel_size, True)    
+    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
     # 1 to 0
     scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
     # softmax
-    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
+    # if 'offset_bias' in kwargs:
+    #     p1, p2 = kernel_size**2//2 + 1, kernel_size2**2
+    #     offset_bias = kwargs['offset_bias'].expand(b, n_head, p1+p2).view(b*n_head, 1, p1+p2)
+    #     scores_0_to_0 = scores_0_to_0 * offset_bias[...,:p1]
+    #     scores_1_to_1 = scores_1_to_1 * offset_bias[...,:p1]
+    #     scores_1_to_0 = scores_1_to_0 * offset_bias[...,-p2:]
 
     scores_0 = torch.cat(
         (scores_all_to_text[:, layout[1]:layout[2]], 
         scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
         dim=-1)
-    probs_0 = F.softmax(scores_0, dim=-1) # 
     scores_1 = torch.cat(
         (scores_all_to_text[:, layout[2]:layout[3]],
          scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
          scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
          dim=-1)
+    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
+    probs_0 = F.softmax(scores_0, dim=-1) # 
     probs_1 = F.softmax(scores_1, dim=-1)
 
     if attention_dropout is not None:
@@ -747,11 +762,11 @@ def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_s
         probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
         v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
     
-    context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size, kernel_size, True)
+    context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size*2-1, kernel_size, True)
 
     context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
 
-    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size, kernel_size, True)
+    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
     
     context_all_to_01 =torch.cat(
         (
@@ -761,3 +776,78 @@ def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_s
         ), dim=-1).view(b, hn, -1).transpose(1, 2)
     return context_all_to_text + context_all_to_01 
 
+
+def sparse_attention_2dfull(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
+    '''
+    q, k, v: [batch_size, 64+1024+4096, hidden_size]
+    n_head: int
+    layout: [endoftext/startofpad, startof0, startof1, endofall]
+    attention_mask_text2d: [batch_size, sq_len, endoftext]
+    '''
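+    # Variant of sparse_attention_2d: first-level (32x32) tokens attend to each
+    # other with full causal attention (dense lower-triangular mask) instead of
+    # the f_similar local window; second-level tokens keep the local
+    # kernel_size / kernel_size2 attention.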
+    from .local_attention_function import f_similar, f_weighting
+    b, sq_len, hn = q.shape
+    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))
+
+    q = q / math.sqrt(hn // n_head) # normalization
+
+    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
+    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
+    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
+    # import pdb; pdb.set_trace()
+    # all to text
+    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
+    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
+    # 0 to 0
+    if not hasattr(sparse_attention_2dfull, 'attention_mask0'):
+        sparse_attention_2dfull.attention_mask0 = torch.ones((layout[2] - layout[1], layout[2] - layout[1]), device=q.device, dtype=q.dtype).tril_()
+    attention_mask0 = sparse_attention_2dfull.attention_mask0
+    scores_0_to_0 = torch.einsum('bhi,bhj->bij', q0.view(*q0.shape[:2], -1), k0.view(*k0.shape[:2], -1)) * attention_mask0 - 10000.0 * (1.0 - attention_mask0)
+    # 1 to 1
+    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)    
+    # 1 to 0
+    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
+    # softmax
+
+    scores_0 = torch.cat(
+        (scores_all_to_text[:, layout[1]:layout[2]], 
+        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
+        dim=-1)
+    scores_1 = torch.cat(
+        (scores_all_to_text[:, layout[2]:layout[3]],
+         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
+         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
+         dim=-1)
+    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
+    probs_0 = F.softmax(scores_0, dim=-1) # 
+    probs_1 = F.softmax(scores_1, dim=-1)
+
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            probs_0 = attention_dropout(probs_0)
+            probs_1 = attention_dropout(probs_1)
+    # weighting
+    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
+    probs_all_to_text = torch.cat((
+        probs_text,
+        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
+        probs_0[:, :, :layout[0]],
+        probs_1[:, :, :layout[0]]
+    ), dim=1)
+
+    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
+        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
+        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
+    
+    context_0_to_0 = torch.einsum('bcj,bij->bci', v0.view(*v0.shape[:2], -1), probs_0[..., layout[0]:].view_as(scores_0_to_0))
+
+    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
+
+    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size*2-1, kernel_size, True)
+    
+    context_all_to_01 = torch.cat(
+        (
+            pad.expand(b*n_head, hn//n_head, layout[1]),
+            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
+            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
+        ), dim=-1).view(b, hn, -1).transpose(1, 2)
+    return context_all_to_text + context_all_to_01 
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index f9d83e5184c2f9037a337228f86e9231e0dbba9d..784153e8b65c70c7004525dfdf208fd69c9cf35a 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -55,12 +55,10 @@ from data_utils import make_loaders, get_tokenizer, detect_new_datasets
 
 import stat
 
-def get_model(args):
+def get_model(args, sparse_config=None):
     """Build the model."""
 
     print_rank_0('building CogView2 model ...')
-    # print(args.vocab_size)
-    # ml = max(args.max_position_embeddings, args.max_position_embeddings_finetune)
     ml = args.max_position_embeddings
     model = GPT2Model(num_layers=args.num_layers,
                       vocab_size=args.vocab_size,
@@ -74,7 +72,7 @@ def get_model(args):
                       checkpoint_activations=args.checkpoint_activations,
                       checkpoint_num_layers=args.checkpoint_num_layers,
                       parallel_output=True,
-                      sparse_config=args.sparse_config,
+                      sparse_config=sparse_config if sparse_config is not None else args.sparse_config,
                       sandwich_ln=args.sandwich_ln
                       )
 
@@ -218,8 +216,9 @@ def get_masks_and_position_ids(data,
         # single direction, [PAD]s are at the end of the seq, so doesn't matter.
         if args.sparse_config.sparse_type == 'cuda_2d':
             layout = args.sparse_config.layout
-            unpad_indices = (data[:, :layout[1]] >= 0) * 10000.
-            starts = (torch.arange(layout[1], device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1]
+            unpad_indices = (data[:, :layout[1]+1] >= 0) * 10000.
+            unpad_indices[:, -1] = 9000.
+            starts = (torch.arange(layout[1]+1, device=data.device).expand_as(unpad_indices) + unpad_indices).min(dim=-1)[1]
             layout[0] = starts.max().item()
             attention_mask = torch.ones((batch_size, seq_length, layout[0]), device=data.device)
             for i in range(batch_size):
@@ -229,7 +228,22 @@ def get_masks_and_position_ids(data,
         elif args.sparse_config.sparse_type == 'standard':
             attention_mask = torch.ones((batch_size, seq_length, seq_length), device=data.device)
             attention_mask.tril_()
-            attention_mask = attention_mask.unsqueeze(1)
+            # attention_mask = torch.zeros((seq_length, seq_length), device=data.device)
+            # h = w = 32
+            # k1=9
+            # layout = [10, 64, 64+h*w, 64+h*w*5]
+            # for i in range(layout[1]):
+            #     attention_mask[i, :i+1] = 1
+            # for i in range(layout[1], layout[2]):
+            #     x = (i - layout[1]) // w
+            #     y = (i - layout[1]) % w
+            #     lx = max(0, x - k1 // 2)
+            #     ly = max(0, y - k1 // 2)
+            #     rx = min(h-1, x + k1 // 2)
+            #     ry = min(w-1, y + k1 // 2)
+            #     attention_mask[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
+            #     attention_mask[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
+            # attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
         elif args.sparse_config.sparse_type == 'torch_1d':
             raise NotImplementedError
 
@@ -283,6 +297,18 @@ def get_batch(data_iterator, args, timers):
     # Unpack.
     tokens_ = data_b['text'].long()
     loss_mask = data_b['loss_mask'].float()
+
+    #FIXME change order
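+    # The binary dataset stores each sample as [text(64), 64x64 tokens, 32x32 tokens];
+    # swap the two image blocks so the sequence matches the cuda_2d layout
+    # [text, 32x32 (level 0), 64x64 (level 1)].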
+    if args.sparse_config.sparse_type == 'cuda_2d':
+        assert args.sparse_config.layout[-1] == 64+32**2+64**2
+    tokens_new = tokens_.clone()
+    tokens_new[:, 64:64+32**2] = tokens_[:, 64+64**2:]
+    tokens_new[:, 64+32**2:] = tokens_[:, 64:64+64**2]
+    tokens_ = tokens_new
+    # only 32# 
+    # tokens_ = tokens_[:, :64+32**2]
+    # loss_mask = loss_mask[:, :64+32**2]
+    
     if args.dataset_type == 'BinaryDataset':
         labels = tokens_.contiguous()
         loss_mask = loss_mask.contiguous()
@@ -326,7 +352,7 @@ def forward_step(data_iterator, model, args, timers, mems):
     img_txt_sep = tokenizer.img_tokenizer.num_tokens
     img_indices_bool = (tokens.detach() < img_txt_sep) & (loss_mask > 0)
     txt_indices_bool = (~img_indices_bool) & (loss_mask > 0)
- 
+
     # Forward model.
     logits, *mems = model(tokens, position_ids, attention_mask, *mems)
     losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
@@ -339,6 +365,15 @@ def forward_step(data_iterator, model, args, timers, mems):
     loss = torch.sum(losses) / loss_mask.sum()
 
     # =====================   Log partial losses   ======================== #
+    if args.sparse_config.sparse_type == 'cuda_2d':
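+        # img_loss2 tracks the loss on second-level (64x64) tokens only;
+        # img_indices_bool is then restricted so img_loss below covers only
+        # the first-level (32x32) tokens.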
+        img_indices_bool2 = img_indices_bool.clone()
+        img_indices_bool2[:, :args.sparse_config.layout[2]] = False
+        img_loss2 = losses[img_indices_bool2.view(-1)].detach().sum() / max(img_indices_bool2.sum(), 1)
+        torch.distributed.all_reduce(img_loss2.data)
+        img_loss2.data = img_loss2.data / args.world_size
+        img_indices_bool[:, args.sparse_config.layout[2]:] = False
+    else:
+        img_loss2 = 0
     img_indices_bool = img_indices_bool.view(-1)
     txt_indices_bool = txt_indices_bool.view(-1)
     img_loss = losses[img_indices_bool].detach().sum() / max(img_indices_bool.sum(), 1)
@@ -351,8 +386,10 @@ def forward_step(data_iterator, model, args, timers, mems):
     txt_loss.data = txt_loss.data / args.world_size
 
     # ===================== END OF BLOCK ======================= #
-    
-    return loss, mems, img_loss, txt_loss
+    # import pdb;pdb.set_trace()
+    # with open('tmp_save.bin', 'wb') as fout:
+    #     torch.save(tokens, fout)
+    return loss, mems, img_loss, txt_loss, img_loss2
 
 
 def backward_step(optimizer, model, lm_loss, args, timers):
@@ -423,12 +460,12 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     while True:
         # Forward model for one step.
         timers('forward').start()
-        lm_loss, mems, img_loss, txt_loss = forward_step(data_iterator, model, args, timers, mems)
+        lm_loss, mems, img_loss, txt_loss, img_loss2 = forward_step(data_iterator, model, args, timers, mems)
         timers('forward').stop()
 
         if (img_loss + txt_loss).isnan().any() or (img_loss + txt_loss).isinf().any():
             print('Skipping backward and optimizer step for nan or inf in forwarding!')
-            return (img_loss + txt_loss), 1, mems, img_loss, txt_loss
+            return (img_loss + txt_loss), 1, mems, img_loss, txt_loss, img_loss2
 
         # Calculate gradients, reduce across processes, and clip.
         timers('backward').start()
@@ -459,15 +496,17 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         timers('optimizer').stop()
         if complete:
             break
-    return lm_loss_reduced, skipped_iter, mems, img_loss, txt_loss
+    return lm_loss_reduced, skipped_iter, mems, img_loss, txt_loss, img_loss2
 
 
-def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, img_loss, txt_loss):
+def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, img_loss, txt_loss, img_loss2):
     log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time)
     log_string += ' learning rate {:.3E} |'.format(lr)
     log_string += ' lm loss {:.6E} |'.format(loss)
     log_string += ' img loss {:.6E} |'.format(img_loss)
+    if args.sparse_config.sparse_type == 'cuda_2d':
+        log_string += ' img loss2 {:.6E} |'.format(img_loss2)
     log_string += ' unscaled txt loss {:.6E} |'.format(txt_loss)
     if args.fp16:
         log_string += ' loss scale {:.1f} |'.format(
@@ -514,13 +553,13 @@ def train(model, optimizer, lr_scheduler,
         if args.iteration % 100 == 0:
             new_loaders = detect_new_datasets(args)
             if new_loaders is not None:
-                print(f'Loatding new datasets ... Now we train models on {args.train_data}.')
+                print(f'Loading new datasets ... Now we train models on {args.train_data}.')
                 train_data_iterator = iter(new_loaders[0])
                 val_data_iterator = iter(new_loaders[1])
                 # TODO close the original
 
 
-        lm_loss, skipped_iter, mems, img_loss, txt_loss = train_step(train_data_iterator,
+        lm_loss, skipped_iter, mems, img_loss, txt_loss, img_loss2 = train_step(train_data_iterator,
                                            model,
                                            optimizer,
                                            lr_scheduler,
@@ -544,7 +583,7 @@ def train(model, optimizer, lr_scheduler,
             elapsed_time = timers('interval time').elapsed()
             report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
-                                    avg_img_loss, avg_txt_loss)
+                                    avg_img_loss, avg_txt_loss, img_loss2)
             total_lm_loss = 0.0
             total_img_loss = 0.0
             total_txt_loss = 0.0
@@ -588,6 +627,7 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
 
     total_lm_loss = 0
     mems = []
+    # with open('grad_scale_fp32.txt', 'w') as fout: 
     with torch.no_grad():
         iteration = 0
         while iteration < args.eval_iters:
@@ -595,8 +635,21 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
             if verbose and iteration % args.log_interval == 0:
                 print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
             # Forward evaluation.
-            lm_loss, mems, img_loss, txt_loss = forward_step(data_iterator, model, args, timers, mems=mems)
-
+            lm_loss, mems, img_loss, txt_loss, img_loss2 = forward_step(data_iterator, model, args, timers, mems=mems)
+
+            # (lm_loss).backward()
+            # for name, param in model.named_parameters():
+            #     v_max = param.data.abs().max().item()
+            #     v_mean = param.data.abs().mean().item()
+            #     fout.write(f'name: {name}, v_max: {v_max}, v_mean: {v_mean}\n')
+            #     if param.grad is not None:
+            #         g_max = param.grad.max().item()
+            #         g_zero_bool = param.grad.abs() == 0
+            #         g_zero = (g_zero_bool.sum() / param.grad.numel()).item()
+            #         g_nz_mean = param.grad[~g_zero_bool].abs().mean().item()
+            #         fout.write(f'g_max: {g_max}, g_zero: {g_zero}, g_nz_mean: {g_nz_mean}\n')
+            # fout.flush()
+            # import pdb;pdb.set_trace()
             '''when contiguous memory optimizations are enabled, the buffers
             allocated by the optimizations are deallocated during backward pass
             in the absence of backward pass the buffers should be reset after each
@@ -825,6 +878,16 @@ def main():
         evaluate_and_print_results(prefix, test_data_iterator,
                                    model, args, timers, True)
 
+def seed_torch(seed=1029):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.enabled = False
 
 if __name__ == "__main__":
+    torch.backends.cuda.matmul.allow_tf32 = False
     main()
diff --git a/random_display.py b/random_display.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59c51adaf56937df7c1bcef072759890006ca21
--- /dev/null
+++ b/random_display.py
@@ -0,0 +1,25 @@
+from data_utils.datasets import BinaryDataset
+from data_utils import get_tokenizer
+import argparse
+import os
+import torch
+import random
+test_dir = 'tmp'
+# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin'
+bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/quanjing005/quanjing005.bin.part_0.cogdata'
+bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=64*64+32*32+64, dtype='int32', preload=False)
+args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None)
+tokenizer = get_tokenizer(args)
+
+bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(16)]
+for x in bin_ds:
+    end = x.tolist().index(-1)
+    print(tokenizer.DecodeIds(x[:end])[0])
+
+from torchvision.utils import save_image
+imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+64**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0)
+save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True)
+imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2:64+64**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
+save_image(imgs, os.path.join(test_dir, 'testcase256.jpg'), normalize=True)
+# imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+64**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
+# save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True)
\ No newline at end of file
diff --git a/scripts/cuda_2d_text2image.sh b/scripts/cuda_2d_text2image.sh
new file mode 100755
index 0000000000000000000000000000000000000000..46783ccda9f37af232fd7f46de97c96463d5fccf
--- /dev/null
+++ b/scripts/cuda_2d_text2image.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+CHECKPOINT_PATH=data/checkpoints/cogview-fixgrad-small08-25-09-38
+# CHECKPOINT_PATH=data/checkpoints/cogview-compare
+NLAYERS=16
+NHIDDEN=1024
+NATT=16
+MAXSEQLEN=5184
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+MPSIZE=1
+
+#SAMPLING ARGS
+TEMP=1.05
+# If TOPK/TOPP are 0 it defaults to greedy sampling; top-k also overrides top-p
+TOPK=100
+TOPP=0
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+MASTER_PORT=${MASTER_PORT} python generate_samples.py \
+       --deepspeed \
+       --model-parallel-size $MPSIZE \
+       --num-layers $NLAYERS \
+       --hidden-size $NHIDDEN \
+       --load $CHECKPOINT_PATH \
+       --num-attention-heads $NATT \
+       --max-position-embeddings 5184 \
+       --fp16 \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --top_p $TOPP \
+       --sandwich-ln \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --sparse-type standard \
+       --max-position-embeddings-finetune $MAXSEQLEN \
+       --generation-task "cuda-2d generation" \
+       --input-source ./input.txt \
+       --output-path samples_text2image \
+       --batch-size 2 \
+       --max-inference-batch-size 4 \
+       --device 0 \
+       --sparse-type standard \
+       $@
+
+
diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json
index 5a1ea682f08d854ba98468980ad27e46c7879aee..1f0f35b84c9f2720dc39d1240cafd63a9e26b2f7 100755
--- a/scripts/ds_config_zero.json
+++ b/scripts/ds_config_zero.json
@@ -1,6 +1,6 @@
 {
-  "train_micro_batch_size_per_gpu": 4,
-  "gradient_accumulation_steps": 1,
+  "train_micro_batch_size_per_gpu": 6,
+  "gradient_accumulation_steps": 5,
   "steps_per_print": 1,
   "gradient_clipping": 0.1,
   "zero_optimization": {
@@ -23,7 +23,7 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00005,
+      "lr": 0.0005,
       "betas": [
         0.9,
         0.95
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index 4af4c0560f9f4a899b8444e55f1447462100d446..a03cf273fcb587611f27c318d469a96d62af2366 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -2,7 +2,7 @@
 
 # Change for multinode config
 
-NUM_WORKERS=2
+NUM_WORKERS=10
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
 
@@ -11,35 +11,44 @@ script_dir=$(dirname $script_path)
 main_dir=$(dirname $script_dir)
 
 # OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
-OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=ib0 NCCL_NET_GDR_LEVEL=2"
-HOST_FILE_PATH="hostfile"
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile2"
 # OPTIONS_NCCL=""
 # HOST_FILE_PATH="hostfile_single"
 
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata"
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin"
 
-config_json="$script_dir/ds_config_zero.json"
+config_json="$script_dir/ds_config.json"
 gpt_options=" \
-       --experiment-name cogview-ali_fashion_tutorial-12-1024-16 \
+       --experiment-name cogview-fixgrad-small-test \
        --img-tokenizer-num-tokens 8192 \
-       --dataset-type TokenizedDataset \
+       --dataset-type BinaryDataset \
        --model-parallel-size ${MP_SIZE} \
-       --num-layers 12 \
+       --num-layers 16 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --save $main_dir/data/checkpoints \
-       --train-iters 40000 \
+       --train-iters 300000 \
        --resume-dataloader \
-       --train-data ./data/ali_vqvae_hard_biggerset_011.lmdb \
+       --train-data ${full_data} \
        --split 949,50,1 \
        --distributed-backend nccl \
        --lr-decay-style cosine \
        --warmup .1 \
        --checkpoint-activations \
        --deepspeed-activation-checkpointing \
-       --max-position-embeddings 1089 \
+       --max-position-embeddings 5184 \
        --max-memory-length 0 \
+       --sandwich-ln \
+       --txt-loss-scale 10 \
+       --sparse-type cuda_2d \
        --fp16 \
+       --save-interval 2000 \
+       --load data/checkpoints/cogview-compare
 "
+       #        
+
 
 
 gpt_options="${gpt_options}
diff --git a/scripts/pretrain_single_node.sh b/scripts/pretrain_single_node.sh
index a55298542516da980a31f96ee4dcb8f0ecca3a1a..84436360d7ec7fbdaa64fe7f7b69d086a436cbe4 100755
--- a/scripts/pretrain_single_node.sh
+++ b/scripts/pretrain_single_node.sh
@@ -15,13 +15,13 @@ OPTIONS_NCCL="NCCL_DEBUG=info"
 HOST_FILE_PATH="hostfile_single"
 
 
-config_json="$script_dir/ds_config.json"
+config_json="$script_dir/ds_config_zero.json"
 gpt_options=" \
        --experiment-name cogview-testlocal \
        --img-tokenizer-num-tokens 8192 \
        --dataset-type BinaryDataset \
        --model-parallel-size ${MP_SIZE} \
-       --num-layers 12 \
+       --num-layers 48 \
        --hidden-size 2560 \
        --num-attention-heads 40 \
        --save $main_dir/data/checkpoints \
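Note on the enlarged single-node config: with 48 layers at hidden size 2560, the transformer blocks alone come to roughly 12 * num_layers * hidden_size**2 parameters (a standard approximation that ignores embeddings, biases and layernorms):

    num_layers, hidden_size = 48, 2560
    approx_params = 12 * num_layers * hidden_size ** 2
    print(f"{approx_params / 1e9:.2f}B")   # ~3.77B parameters in the attention/MLP weights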
diff --git a/scripts/testnan.sh b/scripts/testnan.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a263aa1747a80ac99eb8efa1f57584cff67826b8
--- /dev/null
+++ b/scripts/testnan.sh
@@ -0,0 +1,59 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=1
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+# OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
+OPTIONS_NCCL="NCCL_DEBUG=info"
+HOST_FILE_PATH="hostfile_single"
+
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/zijian/zijian.bin.part_0.cogdata"
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_3leveltokens/merge.bin"
+
+config_json="$script_dir/ds_config.json"
+gpt_options=" \
+       --experiment-name cogview-testlocal \
+       --img-tokenizer-num-tokens 8192 \
+       --dataset-type BinaryDataset \
+       --model-parallel-size ${MP_SIZE} \
+       --num-layers 16 \
+       --hidden-size 1024 \
+       --num-attention-heads 16 \
+       --save $main_dir/data/checkpoints \
+       --train-iters 100000 \
+       --resume-dataloader \
+       --test-data ${full_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --deepspeed-activation-checkpointing \
+       --max-position-embeddings 5184 \
+       --max-memory-length 0 \
+       --txt-loss-scale 2 \
+       --sandwich-ln \
+       --sparse-type cuda_2d \
+       --save-interval 2500 \
+       --load data/checkpoints/cogview-fixgrad-small08-25-09-38
+"
+       # --fp16 \
+
+gpt_options="${gpt_options}
+
+"
+            #    --deepspeed \
+            #    --deepspeed_config ${config_json} \
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/scripts/text2image.sh b/scripts/text2image.sh
index d3a6ae7ef5a7967201882e8546db7a746e2f0d9d..38509fbd8c2fd8ad07e6d59249a827bc9a0b4e8e 100755
--- a/scripts/text2image.sh
+++ b/scripts/text2image.sh
@@ -42,8 +42,8 @@ MASTER_PORT=${MASTER_PORT} python generate_samples.py \
        --generation-task text2image \
        --input-source ./input.txt \
        --output-path samples_text2image \
-       --batch-size 8 \
-       --max-inference-batch-size 8 \
+       --batch-size 4 \
+       --max-inference-batch-size 4 \
        --device 0 \
        $@
 
diff --git a/test_sparse_attention.py b/test_sparse_attention.py
index a8c22f8c9778360632ae00fd2531655ab8e2b5a0..abac1ed99ef230c2054eb997fe815c5cd119c9e0 100644
--- a/test_sparse_attention.py
+++ b/test_sparse_attention.py
@@ -4,7 +4,7 @@ from tqdm import tqdm
 
 import torch
 import numpy as np
-from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d
+from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d, sparse_attention_2dfull
 
 def test_sparse_attention_1d():       
     s, w, times = 4096 + 128, 128, 2
@@ -78,11 +78,12 @@ def test_sparse_attention_1d():
 def test_sparse_attention_2d():
     dtype = torch.float
     device = 'cuda'
-    b, n_head, hn = 1, 40, 2560
+    b, n_head, hn = 2, 16, 1024
     h = w = 32
-    layout = [10, 10, 10+h*w, 10+h*w*5]
+    layout = [10, 64, 64+h*w, 64+h*w*5]
     k1 = 9
     k2 = 7
+    k1h = k1*2-1
 
     qkv = torch.rand(3, b, layout[-1], hn, dtype=dtype, device=device)
     qkv2 = qkv.clone()
@@ -95,20 +96,22 @@ def test_sparse_attention_2d():
         m[i, :i+1] = 1
     m[layout[1]:, :layout[0]] = 1
     for i in tqdm(range(layout[1], layout[2])):
-        x = (i - layout[1]) // w
-        y = (i - layout[1]) % w
-        lx = max(0, x - k1 // 2)
-        ly = max(0, y - k1 // 2)
-        rx = min(h-1, x + k1 // 2)
-        ry = min(w-1, y + k1 // 2)
-        m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
-        m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
+        m[i, layout[1]:i+1] = 1
+    # for i in tqdm(range(layout[1], layout[2])):
+    #     x = (i - layout[1]) // w
+    #     y = (i - layout[1]) % w
+    #     lx = max(0, x - k1 // 2)
+    #     ly = max(0, y - k1 // 2)
+    #     rx = min(h-1, x + k1 // 2)
+    #     ry = min(w-1, y + k1 // 2)
+    #     m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
+    #     m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
     for i in tqdm(range(layout[2], layout[3])):
         x = (i - layout[2]) // (2*w)
         y = (i - layout[2]) % (2*w)
-        lx = max(0, x - k1 // 2)
+        lx = max(0, x - k1h // 2)
         ly = max(0, y - k1 // 2)
-        rx = min(2*h-1, x + k1 // 2)
+        rx = min(2*h-1, x + k1h // 2)
         ry = min(2*w-1, y + k1 // 2)
         m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1
         m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1
@@ -121,7 +124,7 @@ def test_sparse_attention_2d():
         ry = min(w-1, y + k2 // 2)
         m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1
     
-    # mask[1:] = mask[0]
+    mask[1:] = mask[0]
     # mask[1][layout[1]:, layout[0]-1] = 0
 
     print('finish making mask...')
@@ -134,22 +137,25 @@ def test_sparse_attention_2d():
     
     torch.cuda.synchronize()
     t1 = time.time()
-    r2 = sparse_attention_2d(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
+    r2 = sparse_attention_2dfull(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
     torch.cuda.synchronize()
     t2 = time.time()
     print('times: standard ', t1-t0, ' sparse ', t2-t1)
     print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max())
     print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max())
     qkv.retain_grad()
-    l2 = r2[:,layout[1]:].mean()
-    l1 = r1[:,layout[1]:].mean()
+    l2 = r2[:,layout[1]:].sum()
+    l1 = r1[:,layout[1]:].sum()
+
     l2.backward()
     l1.backward()
 
     g1 = qkv.grad
     g2 = qkv2.grad
     print( (g1-g2).abs().max())
-    # import pdb;pdb.set_trace()
+    print( ((g1-g2).abs() / (g1.abs()+g2.abs()+1e-5)).max())
+
+    import pdb;pdb.set_trace()
     
 
 def seed_torch(seed=1029):
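The gradient comparison above now reports a symmetric relative error rather than only the maximum absolute difference. A standalone helper in the same spirit (the function name and epsilon are illustrative, not part of the patch):

    import torch

    def max_relative_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Largest element-wise symmetric relative difference, as used to compare
        # the standard-attention and sparse_attention_2dfull gradients.
        return ((a - b).abs() / (a.abs() + b.abs() + eps)).max()

    # e.g. max_relative_error(g1, g2) after the two backward passes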
diff --git a/utils.py b/utils.py
index 257568e740f908e549e3b908765e582500ac16e0..c5079ce515c1559616b3e6d24dbf365791b82cb9 100755
--- a/utils.py
+++ b/utils.py
@@ -373,7 +373,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=
 
     if mpu.get_data_parallel_rank() == 0:
         print('  successfully loaded {}'.format(checkpoint_name))
-
+    del sd
     return iteration
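The added del sd drops the host-side copy of the checkpoint dict once the weights have been loaded, so it can be garbage-collected instead of lingering alongside the model. A minimal sketch of the pattern (names are illustrative, not the repository's API):

    import torch

    def load_weights(model, checkpoint_path):
        sd = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(sd, strict=False)   # hypothetical: real code unpacks nested keys
        del sd                                    # release the large dict after copying parameters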