diff --git a/arguments.py b/arguments.py
index 7344533a58f37189ec2d76ec5fbd6e8ee1f93876..f87d809bd6229201ae25f97c82b76a9a604f7821 100755
--- a/arguments.py
+++ b/arguments.py
@@ -157,27 +157,15 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0)
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
-    # group.add_argument("--out-seq-length", type=int, default=256)
-    group.add_argument("--generation-task", type=str,
-                       default='text2image',
-                       choices=['text2image',
-                                'image2text',
-                                'super-resolution',
-                                'low-level super-resolution',
-                                'post-selection',
-                                'raw',
-                                'cuda-2d generation'
-                                ],
-                       help='what type of inference task to use')
+    group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
     group.add_argument('--output-path', type=str, default='./samples',
                        help='path to place the generated samples')
-    group.add_argument('--debug', action='store_true',
-                       help='Debug will merge all outputs.')
     group.add_argument('--with-id', action='store_true',
                        help='If each line is prepended with an id.')
     group.add_argument('--max-inference-batch-size', type=int, default=12)
+    group.add_argument('--device', type=int, default=0)
     return parser
 
 
@@ -218,7 +206,6 @@ def add_generation_api_args(parser):
     group.add_argument('--input_rec_path', default='input/')
     group.add_argument('--check_mode', default='code')
     group.add_argument('--time_interval', default=10)
-    group.add_argument('--device', default=None)
 
     return parser
     
diff --git a/generation/__init__.py b/generation/__init__.py
index 90c43c6e006fbf1ad1e6228ef38df90e7bf7e525..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100755
--- a/generation/__init__.py
+++ b/generation/__init__.py
@@ -1,3 +0,0 @@
-from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
-from .magnify import magnify
-from .cuda_2d_sampling import filling_sequence_cuda_2d
\ No newline at end of file
diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py
index f1a047ccbea954fb3bad78be2e8f8c0889f470fa..b29964f0a2fbd77945347a26a3cbfead95de7a8c 100644
--- a/generation/autoregressive_sampling.py
+++ b/generation/autoregressive_sampling.py
@@ -37,6 +37,8 @@ def update_mems(hiddens, mems, max_memory_length):
             if new_memory_length <= query_length:
                 new_mems.append(hiddens[i][:, -new_memory_length:])
             else:
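+                # cached mems may still have batch size 1 (shared context); broadcast them to the current batch before concatenating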
+                if mems[i].shape[0] < hiddens[i].shape[0]:
+                    mems[i] = mems[i].expand(hiddens[i].shape[0], *mems[i].shape[1:])
                 new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
     return new_mems
 
@@ -45,8 +47,9 @@ def filling_sequence(
         model, 
         seq, 
         batch_size,
+        strategy=BaseStrategy(),
         max_memory_length=100000,
-        strategy=BaseStrategy()
+        log_attention_weights=None
         ):
     '''
         seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
@@ -60,12 +63,12 @@ def filling_sequence(
     assert context_length > 0
     tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
     tokens = tokens[..., :context_length]
-
+    attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
     # initialize generation
     counter = context_length - 1 # Last fixed index is ``counter'' 
     index = 0 # Next forward starting index, also the length of cache.
     mems = [] # mems are the first-level citizens here, but we don't assume what is memorized.
-    
+        
     # step-by-step generation
     while counter < len(seq) - 1:
         # Now, we want to generate seq[counter + 1],
@@ -82,18 +85,20 @@ def filling_sequence(
             continue
 
         # forward
+        model.log_attention_weights = None if log_attention_weights is None else log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
         logits, *mem_kv = model(
             tokens[:, index:], 
             position_ids[..., index: counter+1],
-            attention_mask[..., index: counter+1, :counter+1], # TODO mem
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
             *mems
         )
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
-
         # sampling
         logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
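+        # broadcast the shared context to the full batch (a no-op once tokens already have batch_size rows)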
+        tokens = tokens.expand(batch_size, -1)
         tokens, mems = strategy.forward(logits, tokens, mems)
         
+    model.log_attention_weights = None
     return tokens
\ No newline at end of file
diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py
index bd755524ebcde00180058d111bfff13b448a9f21..20ffa38242c66421d28b8cdb84abb53373d51ec9 100644
--- a/generation/cuda2d_sampling.py
+++ b/generation/cuda2d_sampling.py
@@ -12,7 +12,7 @@ import sys
 import math
 import random
 import torch
-from .sampling_strategies import BaseStrategy
+from .sampling_strategies import IterativeEntfilterStrategy
 
 def filling_sequence(
         model, 
@@ -20,11 +20,15 @@ def filling_sequence(
         seq1, 
         warmup_steps=3,
         block_hw=(4, 4),
-        strategy=BaseStrategy(topk=10)
+        strategy=IterativeEntfilterStrategy(topk=10)
         ):
     '''
         seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
-            4095 {layout[2]} final_token
+            4095 {layout[2]} final_token.
+
+        Attention:
+        The sampling temperatures change across the refinement steps and are hard-coded here for now;
+        the temperature stored in the strategy is not used.
     '''
     assert hasattr(model, 'layout')
     layout = model.layout
@@ -46,7 +50,7 @@ def filling_sequence(
     assert seq.shape[1] == layout[-1] + 1
 
     # build initial tokens, attention_mask, and position_ids
-    tokens = seq[:, :-1].clone()
+    tokens = seq.clone()
     attention_mask = torch.ones(layout[1], layout[1]).tril().to(device)
     attention_mask[n_pad:, :n_pad] = 0
     position_ids = torch.cat((
@@ -54,104 +58,32 @@ def filling_sequence(
         torch.arange(0, layout[1] - n_pad), 
         torch.arange(0, layout[2]-layout[1]))).to(device)
 
-    # iterative refining
+    # prepare for iteration
+    unfixed = (tokens < 0)
+    unfixed[:, -(layout[-1] - layout[-2]):] = True # mark the whole high-resolution token block as unfixed
     ll, rr = block_hw
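+    # edge_len: side length of the square high-resolution token grid, e.g. 64 when layout[-1] - layout[-2] == 4096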
+    edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
     num_steps = warmup_steps + ll + rr - 2
-    for step_cnt in range(num_steps):
-        logits, *_dump = model(tokens, position_ids, attention_mask)
-
-        # warmup 
-        real_topk = 10
-        warmup_steps = 3
-        iterative_step= warmup_steps + 6
+    # iterative refining
+    for step_cnt in range(1, num_steps+1):
+        logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask)
         if step_cnt <= warmup_steps:
             real_temp = 0.1
-        elif step_cnt == warmup_steps + 1:
-            real_temp = 0.55
-        elif step_cnt > warmup_steps + 1:
-            real_temp = 0.45
-        # if  5 < step_cnt:
-        #     real_topk = 200
-        # sampling
-        for invalid_slice in invalid_slices: # forbide to generate other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        assert args.top_k > 0
-        
-        # probs0 = F.softmax(logits/real_temp, dim=-1)
-        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
-        ent = -(topraw * topraw.log()).sum(dim=-1)
-        # topsum = topraw.sum(dim=-1)
-        if step_cnt > warmup_steps:
-            # import pdb;pdb.set_trace()
-            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
-            # import pdb;pdb.set_trace()
+            new_tokens = strategy.forward(logits, tokens, real_temp)
+            tokens[unfixed] = new_tokens[unfixed] # keep the fixed context; only update unfixed positions
         else:
-            real_temp2 = real_temp
-        # import pdb;pdb.set_trace()
-        probs = F.softmax(logits/real_temp2, dim=-1)
-        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
-        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
-        edge_idx = tk_idx[:, :, -1:]
-        edge_value = tk_value[:, :, -1:]
-        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
-        prev[edge_mask] = edge_idx[edge_mask]
-        prev.squeeze_(-1)
-        # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
-        # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
-        # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
-        # update unfixed
-        choice = 1
-        if choice == 0 and warmup_steps < step_cnt:
-            mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            # import pdb;pdb.set_trace()
-            dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
-
-            new_fixed = unfixed.clone()
-            moved_new_fixed = new_fixed[:, 2:]
-            moved_new_fixed &= dprob
-            moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
-            moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
-            # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
-            moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
-            moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
-            # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
-        elif choice == 1 and warmup_steps < step_cnt:
-            new_fixed = unfixed & False
-            ll, rr = 4, 4
+            real_temp = 1.05
+            new_tokens = strategy.forward(
+                logits, tokens, real_temp,
+                entfilter=1.3,
+                filter_topk=5,
+                temperature2=0.6
+            )
+            tokens[unfixed] = new_tokens[unfixed]
+            # fixed tokens (update unfixed)
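+            # freeze one anti-diagonal of offsets inside every ll x rr block per step, sweeping from the top-left corner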
             for x in range(min(ll, step_cnt - warmup_steps)):
                 y = step_cnt - warmup_steps - x - 1
                 if y < rr:
                     print(x,y)
-                    new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
-            new_fixed &= unfixed
-        else:
-            new_fixed = unfixed & False # TODO
-        new_fixed[:, -1] = True
-
-        # with open(f'bed{step_cnt}.txt', 'w') as fout:
-        #     for i, prob in enumerate(topraw[0, -4096:]):
-        #         s = ' '.join([str(x) for x in prob.tolist()])
-        #         fout.write(f'{i} {s}\n')
-
-        unfixed &= new_fixed.logical_not()
-        # update seq and tokens
-        seq[new_fixed] = prev[new_fixed[:, 1:]]
-        tokens = seq[:, :-1].clone()
-        tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
-
-        if step_cnt == iterative_step: 
-            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            n_unfixed = unfixed.sum(dim=-1).tolist()
-            print(f'Exit with {n_unfixed} unfixed tokens.')
-            break
-        if args.debug:
-            from torchvision.utils import save_image
-            seqt = seq.clone()
-            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
-    if args.debug:
-        imgs = torch.cat(imgs, dim=0)
-        save_image(imgs, f'steps{device}.jpg', normalize=True)
-    model.module.transformer.max_memory_length = args.max_memory_length
-
-    return seq
\ No newline at end of file
+                    unfixed[..., -(layout[-1] - layout[-2]):].view(
+                        batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
+    return tokens
\ No newline at end of file
diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
index 21e8dcc2707ea067a6e99b72e988347f15ad8395..7cd04402897e113fcc4931aebb006e1c21fb7294 100644
--- a/generation/cuda_2d_sampling.py
+++ b/generation/cuda_2d_sampling.py
@@ -116,19 +116,20 @@ def filling_sequence_cuda_2d(
         # update unfixed
         choice = 1
         if choice == 0 and warmup_steps < step_cnt:
-            mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            # import pdb;pdb.set_trace()
-            dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
+            # mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
+            # # import pdb;pdb.set_trace()
+            # dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
 
-            new_fixed = unfixed.clone()
-            moved_new_fixed = new_fixed[:, 2:]
-            moved_new_fixed &= dprob
-            moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
-            moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
-            # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
-            moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
-            moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
-            # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
+            # new_fixed = unfixed.clone()
+            # moved_new_fixed = new_fixed[:, 2:]
+            # moved_new_fixed &= dprob
+            # moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
+            # moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
+            # # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
+            # moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
+            # moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
+            # # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
+            pass
         elif choice == 1 and warmup_steps < step_cnt:
             new_fixed = unfixed & False
             ll, rr = 4, 4
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
index f1096032cb3c88f2d3ba041d1974fe28a88bf43a..2f71e09703c38106088167808d3be758ae8c9b24 100644
--- a/generation/sampling_strategies/__init__.py
+++ b/generation/sampling_strategies/__init__.py
@@ -1 +1,2 @@
-from .base_strategy import BaseStrategy
\ No newline at end of file
+from .base_strategy import BaseStrategy
+from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index 99a0a74e6e54b6972866d0fb664c079cac75a170..e46a8ca4e505c2891bae2599ebe10746cc54d876 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -20,27 +20,20 @@ def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
     return logits
 
 class BaseStrategy:
-    def __init__(self, invalid_slices=[], temperature=1., topk=200, debias=False):
+    def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
         self.invalid_slices = invalid_slices
         self.temperature = temperature
         self.topk = topk
-        self.debias = debias
+        self.eps = eps
     def forward(self, logits, tokens, mems, temperature=None):
         if temperature is None:
             temperature = self.temperature 
         logits = logits / temperature
         for invalid_slice in self.invalid_slices:
-            logits[..., invalid_slice] = -float('Inf')
-        if self.debias:
-            probs = F.softmax(logits, dim=-1)
-            tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
-            pred = torch.multinomial(probs, num_samples=1)
-            for j in range(0, pred.shape[0]):
-                if probs[j, pred[j,-1]] < tk_value[j, -1]:
-                    pred[j, -1] = tk_idx[j, torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))] # 100 is the last N as outlier, which is chosen casually
-        else:
-            logits = top_k_logits_(logits)
-            probs = F.softmax(logits, dim=-1)
-            pred = torch.multinomial(probs, num_samples=1)
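+            # -65504 is the most negative finite fp16 value; used instead of -inf so masked logits stay finite in half precision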
+            logits[..., invalid_slice] = -65504
+            
+        logits = top_k_logits_(logits, self.topk)
+        probs = F.softmax(logits.float(), dim=-1) # the float cast is essential due to a bug in PyTorch
+        pred = torch.multinomial(probs, num_samples=1)
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
diff --git a/generation/sampling_strategies/iterative_entfilter_strategy.py b/generation/sampling_strategies/iterative_entfilter_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..84196d0c8fae8db48ed515ff265dea9e22136409
--- /dev/null
+++ b/generation/sampling_strategies/iterative_entfilter_strategy.py
@@ -0,0 +1,55 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   iterative_entfilter_strategy.py
+@Time    :   2021/10/09 14:32:29
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+import torch
+import torch.nn.functional as F
+
+def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits[indices_to_remove] = filter_value     
+    return logits
+
+class IterativeEntfilterStrategy:
+    def __init__(self, invalid_slices=[], temperature=1., topk=10):
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = topk
+
+    def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
+        # In the iterative strategy, logits are of shape [batch_size, seq_length, vocab_size]
+        if temperature is None:
+            temperature = self.temperature 
+        # check entropy filter
+        if entfilter is not None:
+            assert temperature2 is not None
+            topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
+            ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
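+            # positions whose top-k entropy exceeds entfilter keep the base temperature; confident positions are sharpened with temperature2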
+            temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -float('Inf')
+        
+        # debiased topk
+        probs = F.softmax(logits, dim=-1)
+        tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
+        pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
+        edge_idx = tk_idx[:, :, -1:]
+        edge_value = tk_value[:, :, -1:]
+        edge_mask = probs.gather(dim=-1, index=pred) < edge_value
+        pred[edge_mask] = edge_idx[edge_mask] # replace outliers (sampled outside the top-k) with the topk-th most likely token
+        pred.squeeze_(-1) # [batch_size, seq_length]
+        
+        assert tokens.shape[1] == pred.shape[1] + 1
+        tokens = torch.cat((tokens[:, :1], pred), dim=1)
+        return tokens
\ No newline at end of file
diff --git a/generation/utils.py b/generation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad073413b6c3efa6f71b3e46eeacb883bd504b0
--- /dev/null
+++ b/generation/utils.py
@@ -0,0 +1,78 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   utils.py
+@Time    :   2021/10/09 17:18:26
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import time
+import stat
+from datetime import datetime
+from torchvision.utils import save_image
+import torch.distributed as dist
+
+
+def timed_name(prefix, suffix=None, path=None):
+    return os.path.join(
+        path, 
+        f"{prefix}-{datetime.now().strftime('%m-%d-%H-%M-%S')}{suffix}"
+    )
+
+def save_multiple_images(imgs, path, debug=True):
+    # imgs: list of tensor images
+    if debug:
+        imgs = torch.cat(imgs, dim=0)
+        print("\nSave to: ", path, flush=True)
+        save_image(imgs, path, normalize=True)
+    else:
+        print("\nSave to: ", path, flush=True)
+        for i in range(len(imgs)):
+            save_image(imgs[i], os.path.join(path, f'{i}.jpg'), normalize=True)
+            os.chmod(os.path.join(path,f'{i}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+        save_image(torch.cat(imgs, dim=0), os.path.join(path,f'concat.jpg'), normalize=True)
+        os.chmod(os.path.join(path,f'concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+
+def generate_continually(func, input_source='interactive'):
+    if input_source == 'interactive':
+        while True:
+            raw_text = input("\nPlease Input Query (stop to exit) >>> ") 
+            raw_text = raw_text.strip()
+            if not raw_text:
+                print('Query should not be empty!')
+                continue
+            if raw_text == "stop":
+                return 
+            try:
+                start_time = time.time()
+                func(raw_text)
+                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+            except (ValueError, FileNotFoundError) as e:
+                print(e)
+                continue
+    else:
+        with open(input_source, 'r') as fin:
+            inputs = fin.readlines()
+        err_linenos = []
+        for line_no, raw_text in enumerate(inputs):
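+            # shard input lines across ranks round-robin so each query is handled by exactly one process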
+            if line_no % dist.get_world_size() != dist.get_rank():
+                continue
+            rk = dist.get_rank()
+            print(f'Working on No. {line_no} on {rk}... ')
+            raw_text = raw_text.strip()
+            if len(raw_text) == 0:
+                continue
+            try:
+                start_time = time.time()
+                func(raw_text)
+                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+            except (ValueError, FileNotFoundError) as e:
+                err_linenos.append(line_no)
+                continue
+        print('error line numbers:', err_linenos)
diff --git a/inference_cogview.py b/inference_cogview.py
new file mode 100644
index 0000000000000000000000000000000000000000..56bdd96ead78f155fce08552e6c582facbeffbdc
--- /dev/null
+++ b/inference_cogview.py
@@ -0,0 +1,97 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_cogview.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+
+from arguments import get_args
+from model.cached_autoregressive_model import CachedAutoregressiveModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from tokenization import get_tokenizer
+from generation.sampling_strategies import BaseStrategy
+from generation.autoregressive_sampling import filling_sequence
+from generation.utils import timed_name, save_multiple_images, generate_continually
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model 
+    model = CachedAutoregressiveModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    
+    # define function for each query
+    query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
+    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
+    strategy = BaseStrategy(invalid_slices, 
+        temperature=args.temperature, topk=args.top_k)
+    
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split()
+        print('raw text: ', raw_text)
+        text = query_template.format(raw_text)
+        seq = tokenizer.parse_query(text, img_size=args.img_size)
+        if len(seq) > 1088:
+            raise ValueError('text too long.')
+        # calibrate text length
+        txt_len = seq.index(tokenizer['[BASE]'])
+        log_attention_weights = torch.zeros(len(seq), len(seq), 
+            device=args.device, dtype=torch.half if args.fp16 else torch.float32)
+        log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
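+        # bias attention towards the text tokens (added to the attention scores inside standard_attention); shorter prompts get a larger boost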
+        # generation
+        seq = torch.cuda.LongTensor(seq, device=args.device)
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = []
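+        # generate in chunks of at most max-inference-batch-size, then concatenate the outputs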
+        for tim in range(max(args.batch_size // mbz, 1)):
+            output_list.append(
+                filling_sequence(model, seq.clone(),
+                    batch_size=min(args.batch_size, mbz),
+                    strategy=strategy,
+                    log_attention_weights=log_attention_weights
+                    )
+                )
+        output_tokens = torch.cat(output_list, dim=0)
+        # decoding
+        imgs, txts = [], []
+        for seq in output_tokens:
+            decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
+            imgs.append(decoded_imgs[-1]) # only the last image (target)
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id)
+            os.makedirs(full_path, exist_ok=True)
+            save_multiple_images(imgs, full_path, False)
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.jpg', args.output_path)
+            save_multiple_images(imgs, full_path, True)
+    
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--full-query', action='store_true')
+    py_parser.add_argument('--img-size', type=int, default=256)
+
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    
+    with torch.no_grad():
+        main(args)
\ No newline at end of file
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
index 225ff3e6f9ed6d83a34676a3918a30e4a16ba70f..1d19fb82569103aed762d25463d50b014a760c2b 100755
--- a/model/cached_autoregressive_model.py
+++ b/model/cached_autoregressive_model.py
@@ -37,14 +37,15 @@ class CachedAutoregressiveModel(BaseModel):
             mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
 
         # same as training
-        query_layer = self._transpose_for_scores(mixed_query_layer)
-        key_layer = self._transpose_for_scores(mixed_key_layer)
-        value_layer = self._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn=None, log_attention_weights=self.log_attention_weights)
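+        # attn_module is the wrapped attention layer; use its projection helpers instead of attributes on self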
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=self.log_attention_weights)
+        
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        output = self.dense(context_layer)
+        output = attn_module.dense(context_layer)
         
         # new mem this layer
         new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
diff --git a/scripts/text2image_cogview.sh b/scripts/text2image_cogview.sh
new file mode 100755
index 0000000000000000000000000000000000000000..4fc46365914d359466343eb0185c1c147d4c7155
--- /dev/null
+++ b/scripts/text2image_cogview.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+CHECKPOINT_PATH=pretrained/cogview/cogview-base
+NLAYERS=48
+NHIDDEN=2560
+NATT=40
+MAXSEQLEN=1089
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+MPSIZE=1
+
+#SAMPLING ARGS
+TEMP=1.03
+TOPK=200
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --mode inference \
+       --distributed-backend nccl \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --model-parallel-size $MPSIZE \
+       --num-layers $NLAYERS \
+       --hidden-size $NHIDDEN \
+       --load $CHECKPOINT_PATH \
+       --num-attention-heads $NATT \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --input-source ./input.txt \
+       --output-path samples_text2image \
+       --batch-size 8 \
+       --max-inference-batch-size 8 \
+       --device 0 \
+       $@
+
+
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0337ffc828400323c95991e8177fdcf2ff7201
--- /dev/null
+++ b/training/__init__.py
@@ -0,0 +1,2 @@
+from .deepspeed_training import initialize_distributed, set_random_seed, prepare_tokenizer
+from .model_io import load_checkpoint
\ No newline at end of file
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index cff164c1c2c5512f4b7f9361022f104a7ea9c747..80d0e422fac82ddb72d3a878bc8f22caccf7163d 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -59,14 +59,13 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     else:
         args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
         
-    # Pytorch distributed.
+    # PyTorch distributed init; must happen before seeding
     initialize_distributed(args)
     set_random_seed(args.seed) # Random seeds for reproducability.
-    
     # init tokenizer
-    tokenizer = get_tokenizer(args)
+    prepare_tokenizer(args) # sets args.vocab_size
     # Data stuff.
-    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args, hooks['create_dataset_function'])
+    train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'])
 
     # Model, optimizer, and learning rate.
     model, optimizer = setup_model_and_optimizer(args, model_cls)
@@ -514,72 +513,39 @@ def initialize_distributed(args):
 
     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
-        set_deepspeed_activation_checkpointing(args)
+        set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed
 
 
 def set_random_seed(seed):
     """Set random seed for reproducability."""
-
     if seed is not None and seed > 0:
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
-def get_train_val_test_data(args, create_dataset_function):
-    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
-
-    (train_data, val_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        train_data, val_data, test_data = make_loaders(args, create_dataset_function)
-        num_tokens = get_tokenizer().num_tokens
-
-        before = num_tokens
-        after = before
-        multiple = args.make_vocab_size_divisible_by * \
-                   mpu.get_model_parallel_world_size()
-        while (after % multiple) != 0:
-            after += 1
-        print_rank_0('> padded vocab (size: {}) with {} dummy '
-                     'tokens (new size: {})'.format(
-                         before, after - before, after))
-        token_counts = torch.cuda.LongTensor(
-            [after, int(args.do_train), int(args.do_valid), int(args.do_test)])
-    else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0])
-    # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    args.do_train = token_counts[1].item()
-    args.do_valid = token_counts[2].item()
-    args.do_test = token_counts[3].item()
-
-    return train_data, val_data, test_data, num_tokens
-
-def see_memory_usage(message, force=False):
-    if not force:
-        return
-    dist.barrier()
-    if dist.get_rank() == 0:
-        print(message)
-        print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
-        print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
-        print(" ")
-
-def seed_torch(seed=1029):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.enabled = False
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
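+        # favor reproducibility over speed: disable cuDNN autotuning, TF32 matmuls, and non-deterministic kernels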
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.enabled = False
+        torch.backends.cuda.matmul.allow_tf32 = False
+        if hasattr(mpu, 'model_parallel_cuda_manual_seed'):
+            mpu.model_parallel_cuda_manual_seed(seed)
+        
+        
+def prepare_tokenizer(args):
+    tokenizer = get_tokenizer(args)
+    num_tokens = tokenizer.num_tokens
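+    # pad the vocabulary so its size is divisible by make_vocab_size_divisible_by * model-parallel world size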
+    before = num_tokens
+    after = before
+    multiple = args.make_vocab_size_divisible_by * \
+               mpu.get_model_parallel_world_size()
+    while (after % multiple) != 0:
+        after += 1
+    print_rank_0('> padded vocab (size: {}) with {} dummy '
+                 'tokens (new size: {})'.format(before, after - before, after))
+    args.vocab_size = after
+    print("prepare tokenizer done", flush=True)
+    return tokenizer
+
 
diff --git a/training/model_io.py b/training/model_io.py
index ee2e8f1be426a0dc4a215d161e583246c387c583..ce434fdfb63396aab4fdcfa3f2b761625e7011f6 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -121,7 +121,7 @@ def load_checkpoint(model, args):
                 torch.distributed.get_rank(), checkpoint_name))
     sd = torch.load(checkpoint_name, map_location='cpu')
     
-    assert not args.do_train or args.deepspeed
+    assert not hasattr(args, 'do_train') or not args.do_train or args.deepspeed
     if args.deepspeed:
         module = model.module
     else: # inference without deepspeed
@@ -136,8 +136,10 @@ def load_checkpoint(model, args):
             raise ValueError(f'Missing keys for inference: {missing_keys}.')
         else: # new params
             assert all(name.find('mixins')>=0 for name in missing_keys)
+            assert args.mode == 'finetune'
             module.reinit() # initialize mixins
-    model.optimizer.refresh_fp32_params() # restore fp32 weights from module
+    if args.mode != 'inference':
+        model.optimizer.refresh_fp32_params() # restore fp32 weights from module
 
     # Iterations.
     if args.mode == 'finetune':