diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index dfa75c910489d592262a7483c6b0cfa2d901660b..83e3322219a3d8f17dd4f1ced5e78223744ca91d 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -33,6 +33,8 @@ def add_model_config_args(parser):
                        help='num of transformer attention heads')
     group.add_argument('--hidden-size', type=int, default=1024,
                        help='tansformer hidden size')
+    group.add_argument('--inner-hidden-size', type=int, default=None,
+                       help='inner hidden size of transformer FFN, None means 4 * hidden-size')
+    group.add_argument('--hidden-size-per-attention-head', type=int, default=None,
+                       help='hidden size per attention head, None means hidden-size / num-attention-heads')
     group.add_argument('--num-layers', type=int, default=24,
                        help='num decoder layers')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
@@ -129,6 +131,8 @@ def add_training_args(parser):
     
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run model in bf16 mode')
     
     return parser
 
@@ -146,7 +150,8 @@ def add_evaluation_args(parser):
                             'validation/test for')
     group.add_argument('--eval-interval', type=int, default=1000,
                        help='interval between running evaluation on validation set')
-
+    group.add_argument('--strict-eval', action='store_true',
+                       help='do not enlarge or randomly map eval-data; evaluate on the full eval-data.')
     return parser
 
 
@@ -297,25 +302,28 @@ def get_args(args_list=None):
         print('using world size: {} and model-parallel size: {} '.format(
             args.world_size, args.model_parallel_size))
 
-    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
-        with open(args.deepspeed_config) as file:
-            deepspeed_config = json.load(file)
-        if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
-            args.fp16 = True
-        else:
-            args.fp16 = False
+    if hasattr(args, "deepspeed") and args.deepspeed:
         if args.checkpoint_activations:
             args.deepspeed_activation_checkpointing = True
-        if "train_micro_batch_size_per_gpu" in deepspeed_config:
-            args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
-        if "gradient_accumulation_steps" in deepspeed_config:
-            args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
         else:
-            args.gradient_accumulation_steps = None
-        if "optimizer" in deepspeed_config:
-            optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
-            args.lr = optimizer_params_config.get("lr", args.lr)
-            args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
+            args.deepspeed_activation_checkpointing = False
+        if args.deepspeed_config is not None:
+            with open(args.deepspeed_config) as file:
+                deepspeed_config = json.load(file)
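+            # The fields read below follow the standard DeepSpeed config schema;
+            # an illustrative (not exhaustive) example:
+            # {
+            #     "fp16": {"enabled": true},
+            #     "train_micro_batch_size_per_gpu": 4,
+            #     "gradient_accumulation_steps": 1,
+            #     "optimizer": {"type": "Adam", "params": {"lr": 1e-4, "weight_decay": 0.01}}
+            # }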
+            if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
+                args.fp16 = True
+            else:
+                args.fp16 = False
+            if "train_micro_batch_size_per_gpu" in deepspeed_config:
+                args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
+            if "gradient_accumulation_steps" in deepspeed_config:
+                args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
+            else:
+                args.gradient_accumulation_steps = None
+            if "optimizer" in deepspeed_config:
+                optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
+                args.lr = optimizer_params_config.get("lr", args.lr)
+                args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
     return args
 
 
diff --git a/SwissArmyTransformer/data_utils/configure_data.py b/SwissArmyTransformer/data_utils/configure_data.py
index b7d6299c69ef5d14030876fb6462b89b2daec654..5d53cded92c2264a82021ffb16912d8be8a45ee9 100755
--- a/SwissArmyTransformer/data_utils/configure_data.py
+++ b/SwissArmyTransformer/data_utils/configure_data.py
@@ -24,14 +24,17 @@ from .samplers import DistributedBatchSampler
 from SwissArmyTransformer import mpu
 
 
-def make_data_loader(dataset, batch_size, num_iters, args):
+def make_data_loader(dataset, batch_size, args):
     world_size = torch.distributed.get_world_size(
         group=mpu.get_data_parallel_group())
     rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
     distributed = world_size > 1
 
     sampler = torch.utils.data.SequentialSampler(dataset)
-    drop_last = distributed
+    # drop_last = distributed
+    drop_last = True  # TODO: always drop the last batch to keep consistency;
+    # otherwise, how should the last (partial) batch be averaged in eval?
+
     # the GPUs in the same model parallel group receive the same data
     if distributed: # TODO reformat this, but it is not urgent
         gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
@@ -52,7 +55,7 @@ def make_data_loader(dataset, batch_size, num_iters, args):
     return data_loader
 
 
-def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
+def make_dataset_full(path, split, args, create_dataset_function, random_mapping=True, **kwargs):
     """function to create datasets+tokenizers for common options"""
     print('make dataset ...', path)
     if split is None:
@@ -67,12 +70,8 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     ds = ConcatDataset(ds)
     if should_split(split):
         ds = split_ds(ds, split, block_size=args.block_size)
-    else:
+    elif random_mapping:
         ds = RandomMappingDataset(ds)
-
-    # if should_split(split):
-    #     ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
-    # FIXME this will merge valid set and train set.
     return ds
 
 def make_loaders(args, create_dataset_function):
@@ -115,25 +114,25 @@ def make_loaders(args, create_dataset_function):
     # make training and val dataset if necessary
     if valid is None and args.valid_data is not None:
         eval_set_args['path'] = args.valid_data
-        valid = make_dataset(**eval_set_args, args=args)
+        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
     if test is None and args.test_data is not None:
         eval_set_args['path'] = args.test_data
-        test = make_dataset(**eval_set_args, args=args)
+        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
 
     # wrap datasets with data loader
     if train is not None and args.batch_size > 0:
-        train = make_data_loader(train, batch_size, args.train_iters, args)
+        train = make_data_loader(train, batch_size, args)
         args.do_train = True
     else:
         args.do_train = False
     eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
     if valid is not None:
-        valid = make_data_loader(valid, eval_batch_size, args.train_iters, args)
+        valid = make_data_loader(valid, eval_batch_size, args)
         args.do_valid = True
     else:
         args.do_valid = False
     if test is not None:
-        test = make_data_loader(test, eval_batch_size, len(test) // eval_batch_size + 1, args)
+        test = make_data_loader(test, eval_batch_size, args)
         args.do_test = True
     else:
         args.do_test = False
diff --git a/SwissArmyTransformer/data_utils/datasets.py b/SwissArmyTransformer/data_utils/datasets.py
index bec1207112194e70221c8f52852f2c9855ed951e..ef46e2da7bc185465320792a89b6c2804dce1ab6 100755
--- a/SwissArmyTransformer/data_utils/datasets.py
+++ b/SwissArmyTransformer/data_utils/datasets.py
@@ -65,3 +65,18 @@ class BinaryDataset(Dataset):
     def __getitem__(self, index):
         return self.process_fn(self.bin[index])
 
+class TSVDataset(Dataset):
+    def __init__(self, path, process_fn, with_heads=True, **kwargs):
+        self.process_fn = process_fn
+        with open(path, 'r') as fin:
+            if with_heads:
+                self.heads = fin.readline().rstrip('\n').split('\t')
+            else:
+                self.heads = None
+            self.items = [line.rstrip('\n').split('\t') for line in fin]
+
+    def __len__(self):
+        return len(self.items)
+    
+    def __getitem__(self, index):
+        return self.process_fn(self.items[index])
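+
+# Example usage (illustrative; `my_process_fn` and 'data.tsv' are placeholders, not part of the library):
+#
+#   def my_process_fn(fields):
+#       # `fields` is the list of tab-separated columns of one line
+#       return fields[0], int(fields[1])
+#
+#   ds = TSVDataset('data.tsv', my_process_fn, with_heads=True)
+#   text, label = ds[0]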
diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py
index 37a24a88a004bfff63a87ebdcf3c909b29256a20..eb5acf42c303db3340e64249108f7f8e697293b4 100644
--- a/SwissArmyTransformer/generation/autoregressive_sampling.py
+++ b/SwissArmyTransformer/generation/autoregressive_sampling.py
@@ -56,7 +56,8 @@ def filling_sequence(
         max_memory_length=100000,
         log_attention_weights=None,
         get_masks_and_position_ids=get_masks_and_position_ids_default,
-        mems=None
+        mems=None,
+        **kw_args
         ):
     '''
         seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
@@ -99,13 +100,15 @@ def filling_sequence(
         else:
             log_attention_weights_part = None
 
-        logits, *mem_kv = model(
+        logits, *output_per_layers = model(
             tokens[:, index:], 
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
             mems=mems,
-            log_attention_weights=log_attention_weights_part
+            log_attention_weights=log_attention_weights_part,
+            **kw_args
         )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
diff --git a/SwissArmyTransformer/model/__init__.py b/SwissArmyTransformer/model/__init__.py
index 32f46e4bce6d09b3f4780f9adbef6120960158be..4fbcd53865ddcd51433a352cec079a8562558b5b 100755
--- a/SwissArmyTransformer/model/__init__.py
+++ b/SwissArmyTransformer/model/__init__.py
@@ -2,4 +2,5 @@ from .base_model import BaseModel
 from .cached_autoregressive_model import CachedAutoregressiveModel
 from .cuda2d_model import Cuda2dModel
 from .glm_model import GLMModel
-from .encoder_decoder_model import EncoderDecoderModel
\ No newline at end of file
+from .encoder_decoder_model import EncoderDecoderModel
+from .t5_model import T5Model
diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index c9e1c9017782b073546fa12423690b71888a51b0..76e8c6a9e78527f6ba431c8c428add2245795917 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -7,6 +7,7 @@
 '''
 
 # here put the import lib
+from functools import partial
 import os
 import sys
 import math
@@ -14,19 +15,39 @@ import random
 import torch
 
 from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu.transformer import standard_attention
+
+def non_conflict(func):
+    func.non_conflict = True
+    return func
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
         super(BaseMixin, self).__init__()
         # define new params
+
     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
-    # can also define hook-functions here
+
+    # can define hook-functions here
     # ...
 
+    # If the hook is just a pre- or post- transformation,
+    # you can mark it with @non_conflict
+    # and call `old_impl` to keep it compatible with other mixins.
+    # E.g.,
+    #
+    # @non_conflict
+    # def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
+    #     new_q, new_k, new_v = pre_hack(q, k, v)
+    #     attn_result = old_impl(new_q, new_k, new_v, mask, dropout_fn, **kw_args)
+    #     attn_result = post_hack(attn_result)
+    #     return attn_result
+
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None, parallel_output=True):
+    def __init__(self, args, transformer=None, **kwargs):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -42,14 +63,16 @@ class BaseModel(torch.nn.Module):
                 embedding_dropout_prob=args.hidden_dropout,
                 attention_dropout_prob=args.attention_dropout,
                 output_dropout_prob=args.hidden_dropout,
+                inner_hidden_size=args.inner_hidden_size,
+                hidden_size_per_attention_head=args.hidden_size_per_attention_head,
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=parallel_output,
-                hooks=self.hooks
+                hooks=self.hooks,
+                **kwargs
             )
 
-    def reinit(self): # will be called when loading model
+    def reinit(self):  # will be called when loading model
         # if some mixins are loaded, overrides this function
         for m in self.mixins.values():
             m.reinit(self.transformer)
@@ -58,11 +81,11 @@ class BaseModel(torch.nn.Module):
         assert name not in self.mixins
         assert isinstance(new_mixin, BaseMixin)
 
-        self.mixins[name] = new_mixin # will auto-register parameters
-        object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
+        self.mixins[name] = new_mixin  # will auto-register parameters
+        object.__setattr__(new_mixin, 'transformer', self.transformer)  # cannot use pytorch set_attr
 
         if reinit:
-            new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
+            new_mixin.reinit(self.transformer, **self.mixins)  # also pass current mixins
         self.collect_hooks_()
 
     def del_mixin(self, name):
@@ -82,18 +105,25 @@ class BaseModel(torch.nn.Module):
 
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
-                'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                'branch_embedding_forward', 'branch_final_forward'
-                ]
+                 'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
+                 'cross_layer_embedding_forward',
+                 'attention_fn'
+                 ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
-                    if name in hooks: # conflict
-                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
-                    hooks[name] = getattr(m, name)
-                    hook_origins[name] = mixin_name
+                    if name in hooks: # if this hook name is already registered
+                        if hasattr(getattr(m, name), 'non_conflict'):
+                            hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
+                            hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
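+                            # e.g. if mixin A registered attention_fn first and mixin B also defines a
+                            # @non_conflict attention_fn, the final hook is
+                            # B.attention_fn(..., old_impl=A.attention_fn), recorded as 'B -> A'.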
+                        else: # conflict
+                            raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
+                    else: # new hook
+                        hooks[name] = getattr(m, name)
+                        hook_origins[name] = mixin_name
+
             if hasattr(self, name):
                 # if name in hooks: # defined in mixins, can override
                 #     print(f'Override {name} in {hook_origins[name]}...')
@@ -104,4 +134,4 @@ class BaseModel(torch.nn.Module):
         return hooks
 
     def disable_untrainable_params(self):
-        pass
\ No newline at end of file
+        pass
diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index a0e1699e923e9002e7a46ac69692ae7934032a57..8caed663b27e4b6c1b4382090ddd754a1c875d3e 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -13,43 +13,31 @@ import math
 import random
 import torch
 
-from .base_model import BaseModel, BaseMixin
+from .base_model import BaseModel, BaseMixin, non_conflict
 from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
 
 class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self):
-        super().__init__()
-        
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
-        attn_module = self.transformer.layers[layer_id].attention
-        mem = mems[layer_id] if mems is not None else None
-        
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-            mixed_key_layer,
-            mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-        
-        if mem is not None: # the first time, mem is None
-            b = mixed_key_layer.shape[0] # might change batch_size
-            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
-            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
-            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+        super().__init__()
+
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+                     **kw_args):
+        if not cross_attention:
+            mem = mems[kw_args['layer_id']] if mems is not None else None  # [batch, mem_seq_len, 2 * nh * hidden_size_per_head]
+            b, nh, seq_len, hidden_size = k.shape
+
+            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+            kw_args['output_this_layer']['mem_kv'] = cache_kv
+
+            if mem is not None: # the first time, mem is None
+                # might change batch_size
+                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+                memk, memv = mem[0], mem[1]
+                k = torch.cat((memk, k), dim=2)
+                v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
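+
+    # Typical usage (illustrative; the mixin name 'auto-regressive' is a placeholder):
+    #   model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
+    # so that autoregressive sampling (e.g. filling_sequence) can reuse the cached key/value memories.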
 
-        # same as training
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
-        
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-        
-        # new mem this layer
-        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
-            
-        return output, new_mem
 
 class CachedAutoregressiveModel(BaseModel):
     def __init__(self, args, transformer=None):
diff --git a/SwissArmyTransformer/model/common_layers.py b/SwissArmyTransformer/model/common_layers.py
deleted file mode 100644
index 90d24334511192b727ed0a0b4eeff588b8fdb960..0000000000000000000000000000000000000000
--- a/SwissArmyTransformer/model/common_layers.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   components.py
-@Time    :   2021/11/23 18:20:22
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-import torch
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
-from SwissArmyTransformer.mpu.transformer import standard_attention, LayerNorm
-
-class CrossAttention(torch.nn.Module):
-    def __init__(self, hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
-        super(CrossAttention, self).__init__()
-        # Set output layer initialization if not provided.
-        if output_layer_init_method is None:
-            output_layer_init_method = init_method
-        if inner_hidden_size is None:
-            inner_hidden_size = hidden_size
-        self.inner_hidden_size = inner_hidden_size
-        if enc_hidden_size is None:
-            enc_hidden_size = hidden_size
-        self.enc_hidden_size = enc_hidden_size
-
-        # To make user understand better, temporally not support model parallel
-        world_size = 1
-        self.hidden_size_per_partition = divide(hidden_size, world_size)
-        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
-        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-
-        # To map encoder outputs
-        self.kv_linear = torch.nn.Linear(
-            enc_hidden_size, inner_hidden_size * 2
-        )
-        init_method(self.kv_linear.weight)
-
-        # To map self
-        self.q_linear = torch.nn.Linear(
-            hidden_size, inner_hidden_size
-        )
-        init_method(self.q_linear.weight)
-
-        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
-
-        self.dense = torch.nn.Linear(
-            inner_hidden_size,
-            hidden_size,
-        )
-        output_layer_init_method(self.dense.weight)
-        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
-
-
-    def _transpose_for_scores(self, tensor):
-        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
-        size [b, np, s, hn].
-        """
-        new_tensor_shape = tensor.size()[:-1] + \
-                            (self.num_attention_heads_per_partition,
-                            self.hidden_size_per_attention_head)
-        tensor = tensor.view(*new_tensor_shape)
-        return tensor.permute(0, 2, 1, 3)
-
-    def forward(self, hidden_states, mask, encoder_outputs, **kw_args):
-        
-        query_layer = self.q_linear(hidden_states)
-        key_layer, value_layer = split_tensor_along_last_dim(self.kv_linear(encoder_outputs), 2)
-        
-        dropout_fn = self.attention_dropout if self.training else None
-
-        query_layer = self._transpose_for_scores(query_layer)
-        key_layer = self._transpose_for_scores(key_layer)
-        value_layer = self._transpose_for_scores(value_layer)
-        
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = self.dense(context_layer)
-        
-        if self.training:
-            output = self.output_dropout(output)
-        
-        return output
diff --git a/SwissArmyTransformer/model/cuda2d_model.py b/SwissArmyTransformer/model/cuda2d_model.py
index cda027c5d8bcdb46cb0d343f4c0a0338dfccb5e0..4fdc87d3bee12f1b132f25f979d33c2582d36fa8 100644
--- a/SwissArmyTransformer/model/cuda2d_model.py
+++ b/SwissArmyTransformer/model/cuda2d_model.py
@@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel):
         output_1 = dense_plus(context_layer1)
         output = torch.cat((output_0, output_1), dim=1)
         
-        return output, None
+        return output
     
     def disable_untrainable_params(self):
         self.transformer.requires_grad_(False)
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index e70245368e68e6d43ee06d97a959868194b712d6..341a800083d5a317dfc55cde1592915a8296c4ca 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -14,131 +14,83 @@ import random
 import torch
 import argparse
 from .base_model import BaseModel, BaseMixin
-from .common_layers import CrossAttention, LayerNorm
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 
 
-class CrossAttentionMixin(BaseMixin):
-    def __init__(self, num_layers, hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
-        super().__init__()
-            
-        self.cross_attentions = torch.nn.ModuleList(
-            [CrossAttention(
-                hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size, 
-                output_layer_init_method=output_layer_init_method
-            ) for layer_id in range(num_layers)]
-        ) # Just copy args
-        self.cross_lns = torch.nn.ModuleList(
-            [LayerNorm(hidden_size, 1e-5)
-            for layer_id in range(num_layers)]
-        )
-        
-
-    def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
-        layer = self.transformer.layers[layer_id]
-        encoder_outputs = kw_args['encoder_outputs']
-        '''
-            hidden_states: [batch, seq_len, hidden_size]
-            mask: [(1, 1), seq_len, seq_len]
-            encoder_outputs: [batch, enc_seq_len, enc_hidden_size]
-        '''
-        # Layer norm at the begining of the transformer layer.
-        layernorm_output = layer.input_layernorm(hidden_states)
-        attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args)
-        # Third LayerNorm
-        if layer.sandwich_ln:
-            attention_output = layer.third_layernorm(attention_output)
-        # Residual connection.
-        hidden_states = hidden_states + attention_output
-
-        # Cross attention.
-        layernorm_output = self.cross_lns[layer_id](hidden_states)
-        cross_attn_output = self.cross_attentions[layer_id](
-            layernorm_output, 
-            torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype), 
-            encoder_outputs
-            )
-        hidden_states = hidden_states + cross_attn_output
-
-        # Layer norm post the layer attention.
-        layernorm_output = layer.post_attention_layernorm(hidden_states)
-        # MLP.
-        mlp_output = layer.mlp(layernorm_output, **kw_args)
+class EncoderFinalMixin(BaseMixin):
+    def final_forward(self, logits, **kwargs):
+        logits = copy_to_model_parallel_region(logits)
+        return logits
 
-        # Fourth LayerNorm
-        if layer.sandwich_ln:
-            mlp_output = layer.fourth_layernorm(mlp_output)
-        output = hidden_states + mlp_output
-
-        return output, output_this_layer
-
-    
-class DecoderModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        dec_args = argparse.Namespace(**vars(args))
-        dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn
-        override_attrs = ['num_layers', 'vocab_size', 
-            'hidden_size', 'num_attention_heads', 
-            'max_sequence_length', 'sandwich_ln' # TODO
-            ]
-        for name in override_attrs:
-            dec_attr = getattr(dec_args, 'dec_' + name, None)
-            if dec_attr is not None: # else use encoder-config
-                setattr(dec_args, name, dec_attr)
-
-        super().__init__(dec_args, transformer=transformer)
-        self.add_mixin('cross_attention',
-            CrossAttentionMixin(
-                dec_args.num_layers,
-                dec_args.hidden_size, dec_args.num_attention_heads,
-                dec_args.attention_dropout, dec_args.hidden_dropout,
-                self.transformer.init_method, 
-                enc_hidden_size=dec_args.enc_hidden_size, 
-                inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None), 
-                output_layer_init_method=self.transformer.output_layer_init_method
-            )
-        )
 
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None):
+    def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
             self.encoder = encoder
         else:
-            self.encoder = BaseModel(args)
+            self.encoder = BaseModel(args, **kwargs)
+        self.encoder.add_mixin("final", EncoderFinalMixin())
         
         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
         else:
-            self.decoder = DecoderModel(args)
+            dec_args = argparse.Namespace(**vars(args))
+            dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
+            override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads',
+                              'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
+            for name in override_attrs:
+                dec_attr = getattr(dec_args, 'dec_' + name, None)
+                if dec_attr is not None:  # else use encoder-config
+                    setattr(dec_args, name, dec_attr)
+            self.decoder = BaseModel(dec_args, is_decoder=True, parallel_output=parallel_output, **kwargs)
+
+        self.tie_word_embeddings = tie_word_embeddings
+        if tie_word_embeddings:
+            self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
 
     def reinit(self):
         self.encoder.reinit()
         self.decoder.reinit()
-    
+
     def disable_untrainable_params(self):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()
+
+    def encode(self, input_ids, position_ids, attention_mask=None, **kw_args):
+        encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args)
+        return encoder_outputs
+    
+    def decode(self, input_ids, position_ids, attention_mask, encoder_outputs, cross_attention_mask=None, **kw_args):
+        if attention_mask is None:
+            batch_size, seq_length = input_ids.size()[:2]
+            seq_ids = torch.arange(seq_length, device=input_ids.device)
+            attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype)
+            attention_mask = attention_mask[:, None, :, :]
+        # If there is no context, please explicitly pass `encoder_outputs=None`.
+        return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
     
-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, branch_input=None, **kw_args):
-        mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype)
-        enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input, **kw_args)
-        dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args)
-        return enc_outputs, dec_outputs, *dec_mems
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
+        # Please use self.decoder for auto-regressive generation.
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
+        encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
 
     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart')
-        group.add_argument("--dec_num_layers", type=int, default=None)
-        group.add_argument("--dec_vocab_size", type=int, default=None)
-        group.add_argument("--dec_hidden_size", type=int, default=None)
-        group.add_argument("--dec_num_attention_heads", type=int, default=None)
-        group.add_argument("--dec_max_sequence_length", type=int, default=None)
-        group.add_argument("--dec_sandwich_ln", action='store_true')
-        group.add_argument("--dec_inner_hidden_size", type=int, default=None)
-        return parser
\ No newline at end of file
+        group.add_argument("--dec-num-layers", type=int, default=None)
+        group.add_argument("--dec-hidden-size", type=int, default=None)
+        group.add_argument("--dec-num-attention-heads", type=int, default=None)
+        group.add_argument("--dec-max-sequence-length", type=int, default=None)
+        group.add_argument("--dec-inner-hidden-size", type=int, default=None)
+        group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None)
+        return parser
diff --git a/SwissArmyTransformer/model/finetune/__init__.py b/SwissArmyTransformer/model/finetune/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f2cc6a0915d3df1762edd4f202b2f96b81a5609
--- /dev/null
+++ b/SwissArmyTransformer/model/finetune/__init__.py
@@ -0,0 +1,2 @@
+from .mlp_head import MLPHeadMixin
+from .prompt_tuning import PrefixTuningMixin, PTuningV2Mixin
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/finetune/mlp_head.py b/SwissArmyTransformer/model/finetune/mlp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e25d661d0d36a24887614dbcbb07faaa445956e2
--- /dev/null
+++ b/SwissArmyTransformer/model/finetune/mlp_head.py
@@ -0,0 +1,36 @@
+
+# -*- encoding: utf-8 -*-
+'''
+@File    :   mlp_head.py
+@Time    :   2021/12/12 20:44:09
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+import torch
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
+
+class MLPHeadMixin(BaseMixin):
+    def __init__(self, hidden_size, *output_sizes, bias=True, activation_func=torch.nn.functional.relu, init_mean=0, init_std=0.005):
+        super().__init__()
+        self.activation_func = activation_func
+        last_size = hidden_size
+        self.layers = torch.nn.ModuleList()
+        for sz in output_sizes:
+            this_layer = torch.nn.Linear(last_size, sz, bias=bias)
+            last_size = sz
+            torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std)
+            self.layers.append(this_layer)
+
+    def final_forward(self, logits, **kw_args):
+        for i, layer in enumerate(self.layers):
+            if i > 0:
+                logits = self.activation_func(logits)
+            logits = layer(logits)
+        return logits
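+
+# Example usage (illustrative; the mixin name and the layer sizes 2048 / 1 are placeholders):
+#   model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
+# This replaces the default final_forward with a 2-layer MLP head on top of the transformer output.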
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/finetune/prompt_tuning.py b/SwissArmyTransformer/model/finetune/prompt_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..51cd1326b2a42cbcb89a327ab3b15892e3c0a2c1
--- /dev/null
+++ b/SwissArmyTransformer/model/finetune/prompt_tuning.py
@@ -0,0 +1,45 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   prompt_tuning.py
+@Time    :   2021/12/12 20:45:18
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
+
+
+class PrefixTuningMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size_per_attention_head, num_attention_heads, prefix_len):
+        super().__init__()
+        self.prefix = torch.nn.ParameterList([
+            torch.nn.Parameter(torch.randn(2, num_attention_heads, prefix_len, hidden_size_per_attention_head)*0.01)
+            for layer_id in range(num_layers)
+        ])
+        self.prefix_len = prefix_len
+
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
+        prefix_k, prefix_v = self.prefix[kw_args['layer_id']]
+
+        b, nh, seq_len, hidden_size = k.shape
+        prefix_k = prefix_k.unsqueeze(0).expand(b, nh, -1, hidden_size)
+        prefix_v = prefix_v.unsqueeze(0).expand(b, nh, -1, hidden_size)
+
+        k = torch.cat((k, prefix_k), dim=2)
+        v = torch.cat((v, prefix_v), dim=2)
+        if mask.numel() > 1:
+            mask_prefixed = torch.ones(self.prefix_len, device=mask.device, dtype=mask.dtype)
+            mask_prefixed = mask_prefixed.expand(*(mask.size()[:-1]), -1)
+            mask = torch.cat((mask, mask_prefixed), dim=-1)
+        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
+
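+# Example usage (illustrative; prefix_len=16 and the assumption
+# hidden_size_per_attention_head == hidden_size // num_attention_heads are placeholders):
+#   model.add_mixin('prefix-tuning',
+#                   PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads,
+#                                     args.num_attention_heads, prefix_len=16))
+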
+PTuningV2Mixin = PrefixTuningMixin
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/mixins.py b/SwissArmyTransformer/model/mixins.py
index 2a76b8038265ff00b7dee078df70b169ebb34105..d1fcf89f7311f2f52835e54285dadfd016714fde 100644
--- a/SwissArmyTransformer/model/mixins.py
+++ b/SwissArmyTransformer/model/mixins.py
@@ -17,41 +17,45 @@ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
 from SwissArmyTransformer.mpu.transformer import unscaled_init_method
 from .base_model import BaseMixin
 from .cached_autoregressive_model import CachedAutoregressiveMixin
+from .finetune import *
 
 class PositionEmbeddingMixin(BaseMixin):
-    def __init__(self, additional_sequence_length, hidden_size, 
-                init_method_std=0.02, reinit_slice=slice(-1024, None)
-        ):
+    def __init__(self, additional_sequence_length, hidden_size,
+                 init_method_std=0.02, reinit_slice=slice(-1024, None)
+                 ):
         super(PositionEmbeddingMixin, self).__init__()
         self.reinit_slice = reinit_slice
         self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
+
     def reinit(self, *pre_mixins):
         old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
         old_len, hidden_size = old_weights.shape
         assert hidden_size == self.position_embeddings.weight.shape[-1]
         self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
 
+
 class AttentionMixin(BaseMixin):
     def __init__(self, num_layers,
-                hidden_size, 
-                init_method=unscaled_init_method(0.02),
-                output_layer_init_method=unscaled_init_method(0.02)
-        ):
+                 hidden_size,
+                 init_method=unscaled_init_method(0.02),
+                 output_layer_init_method=unscaled_init_method(0.02)
+                 ):
         super(AttentionMixin, self).__init__()
-        self.num_layers = num_layers # replace attention in the LAST n layers
+        self.num_layers = num_layers  # replace attention in the LAST n layers
         self.query_key_value = torch.nn.ModuleList(
-            [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
-                gather_output=False,init_method=init_method)
-                for layer_id in range(num_layers)
-            ])
+            [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
+                                  gather_output=False, init_method=init_method)
+             for layer_id in range(num_layers)
+             ])
         self.dense = torch.nn.ModuleList(
             [RowParallelLinear(hidden_size,
-                hidden_size,
-                input_is_parallel=True,
-                init_method=output_layer_init_method)
-                for layer_id in range(num_layers)
-            ])
+                               hidden_size,
+                               input_is_parallel=True,
+                               init_method=output_layer_init_method)
+             for layer_id in range(num_layers)
+             ])
+
     def reinit(self, *pre_mixins):
         start_layer = len(self.transformer.layers) - self.num_layers
         assert start_layer >= 0
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1a50951ea352babcd763a0178ed2de2013695dc
--- /dev/null
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -0,0 +1,287 @@
+import math
+import torch
+import torch.nn.functional as F
+from .mixins import BaseMixin
+from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
+from SwissArmyTransformer.mpu import get_model_parallel_world_size
+from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
+
+
+class T5PositionEmbeddingMixin(BaseMixin):
+    def position_embedding_forward(self, position_ids, **kw_args):
+        return None
+
+
+class T5LayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
+        """
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # layer norm should always be calculated in float32
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into float16 or bfloat16 if necessary
+        if self.weight.dtype == torch.float16:
+            hidden_states = hidden_states.to(torch.float16)
+        elif self.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.bfloat16)
+        return self.weight * hidden_states
+
+
+class T5AttentionMixin(BaseMixin):
+    def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False):
+        super().__init__()
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        world_size = get_model_parallel_world_size()
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets,
+                                                          self.num_attention_heads_per_partition)
+        self.is_decoder = is_decoder
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+        This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_position_if_large = max_exact + (
+                torch.log(relative_position.float() / max_exact)
+                / math.log(max_distance / max_exact)
+                * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length):
+        """Compute binned relative position bias"""
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_num_buckets,
+        )
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        # shape (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
+                     cross_attention=False, **kw_args):
+        log_attention_weights = None
+        if not cross_attention:
+            if position_bias is None:
+                seq_length = q.size(2)
+                key_length = k.size(2)
+                position_bias = self.compute_bias(key_length, key_length)
+                position_bias = position_bias[:, :, -seq_length:, :]
+            kw_args['output_cross_layer']['position_bias'] = position_bias
+            log_attention_weights = position_bias
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
+                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
+
+
+class T5DecoderFinalMixin(BaseMixin):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
+
+    def final_forward(self, logits, **kwargs):
+        logits_parallel = copy_to_model_parallel_region(logits)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
+        return logits_parallel
+
+
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
+class T5Model(EncoderDecoderModel):
+    def __init__(self, args, **kwargs):
+        self.init_method_std = args.init_method_std
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu,
+                         init_method=self._init_weights)
+        self.encoder.add_mixin(
+            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
+        )
+        self.encoder.add_mixin(
+            "t5-position", T5PositionEmbeddingMixin()
+        )
+        del self.encoder.transformer.position_embeddings
+        num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None else args.num_attention_heads
+        self.decoder.add_mixin(
+            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads, is_decoder=True)
+        )
+        self.decoder.add_mixin(
+            "t5-position", T5PositionEmbeddingMixin()
+        )
+        self.decoder.add_mixin(
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
+        )
+        del self.decoder.transformer.position_embeddings
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+
+    def _init_weights(self, weight, module, name):
+        init_method_std = self.init_method_std
+        if isinstance(module, MLP):
+            if name == "dense_h_to_4h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense_4h_to_h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, SelfAttention):
+            if name == "query_key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+                torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, CrossAttention):
+            if name == "query":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        else:
+            raise NotImplementedError(module)
+
+    @classmethod
+    def add_model_specific_args(cls, parser):
+        super().add_model_specific_args(parser)
+        parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
+        parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
+
+    def encode(self, input_ids, attention_mask=None, **kw_args):
+        return super().encode(input_ids, None, attention_mask, **kw_args)
+
+    def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+        return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
+                              cross_attention_mask=cross_attention_mask, **kw_args)
+
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
+                cross_attention_mask=None, **kw_args):
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length,
+                                            dtype=self.encoder.transformer.word_embeddings.weight.dtype,
+                                            device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
+        encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask,
+                                             encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask,
+                                             **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
diff --git a/SwissArmyTransformer/mpu/layers.py b/SwissArmyTransformer/mpu/layers.py
index 4af3f33180aae46bba28efa213a9a941df5cd3ce..8d7cf370e75f2629ae96b1b2f442f4dfe6bae70f 100755
--- a/SwissArmyTransformer/mpu/layers.py
+++ b/SwissArmyTransformer/mpu/layers.py
@@ -37,7 +37,7 @@ from .utils import VocabUtility
 
 def _initialize_affine_weight(weight, output_size, input_size,
                               per_partition_size, partition_dim, init_method,
-                              stride=1, return_master_weight=False):
+                              stride=1, return_master_weight=False, module=None, name=None):
     """Initialize affine weight for model parallel.
 
     Build the master weight on all processes and scatter
@@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
     if world_size == 1:
-        init_method(weight)
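+        # `module` and `name` are forwarded so that init_method can choose a
+        # per-parameter std (e.g. a model-level _init_weights method).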
+        init_method(weight, module=module, name=name)
         if return_master_weight:
             return weight
         return None
@@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     master_weight = torch.empty(output_size, input_size,
                                 dtype=weight.dtype,
                                 requires_grad=False)
-    init_method(master_weight)
+    init_method(master_weight, module=module, name=name)
 
     # Split and copy
     per_partition_per_stride_size = divide(per_partition_size, stride)
@@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module):
     """
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
         super(ColumnParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.output_size_per_partition, 0, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
@@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module):
     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
         super(RowParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.input_size_per_partition, 1, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 3764f0552b1c4163536b061700b8d5d7b6e3ff79..289056948feff75964127fd5a2e4b0f4e8ed3944 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -22,55 +22,62 @@ import torch
 import torch.nn.functional as F
 from apex.normalization.fused_layer_norm import FusedLayerNorm
 
+from SwissArmyTransformer import mpu
 from .initialize import get_model_parallel_world_size
 from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
 from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
-from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
+from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
 
 from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
 from .utils import split_tensor_along_last_dim
 
+
 class LayerNorm(FusedLayerNorm):
     def __init__(self, *args, pb_relax=False, **kwargs):
         super().__init__(*args, **kwargs)
         self.pb_relax = pb_relax
+
     def forward(self, x):
         if not self.pb_relax:
             return super().forward(x)
-        return super().forward(x / (x.abs().max().detach()/8))
+        return super().forward(x / (x.abs().max().detach() / 8))
+
 
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                    attention_dropout=None, log_attention_weights=None):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. 
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. 
 
-    attention_scores = torch.matmul(
-        query_layer / math.sqrt(query_layer.shape[-1]),
-        key_layer.transpose(-1, -2)
-    )
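+    # Optionally skip the 1/sqrt(d) query scaling; some models (e.g. T5) do not
+    # scale attention scores at all.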
+    if scaling_attention_score:
+        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
+    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
-    
-    if not(attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
+
+    if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
         # if auto-regressive, skip
         attention_scores = torch.mul(attention_scores, attention_mask) - \
-                10000.0 * (1.0 - attention_mask)
+                           10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
 
     if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
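+        # mpu.get_cuda_rng_tracker is set to None when DeepSpeed activation checkpointing
+        # is disabled (see initialize_distributed); fall back to plain dropout in that case.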
+        if mpu.get_cuda_rng_tracker is not None:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = attention_dropout(attention_probs)
+        else:
             attention_probs = attention_dropout(attention_probs)
 
     context_layer = torch.matmul(attention_probs, value_layer)
     return context_layer
 
+
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, layer_id, output_layer_init_method=None,
-                hooks={}):
+                 attention_dropout_prob, output_dropout_prob,
+                 init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True,
+                 hooks={}):
         super(SelfAttention, self).__init__()
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
@@ -79,47 +86,61 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
-        self.hidden_size_per_partition = divide(hidden_size, world_size)
-        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        self.hidden_size = hidden_size
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
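+        # inner_hidden_size (num_heads * head_dim) may differ from hidden_size when
+        # hidden_size_per_attention_head is given explicitly (e.g. T5-11B uses 128-dim
+        # heads with a 1024-dim model).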
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3*hidden_size,
+            3 * self.inner_hidden_size,
             stride=3,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias,
+            module=self,
+            name="query_key_value"
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
 
         self.dense = RowParallelLinear(
-            hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias,
+            module=self,
+            name="dense"
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
-
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
         size [b, np, s, hn].
         """
         new_tensor_shape = tensor.size()[:-1] + \
-                            (self.num_attention_heads_per_partition,
+                           (self.num_attention_heads_per_partition,
                             self.hidden_size_per_attention_head)
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
         else:
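+            # the 'attention_fn' hook lets mixins swap out standard_attention,
+            # e.g. to inject relative position bias into the attention scores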
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
-                mixed_key_layer,
-                mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+             mixed_key_layer,
+             mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
 
             dropout_fn = self.attention_dropout if self.training else None
 
@@ -127,7 +148,8 @@ class SelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
+
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.view(*new_context_layer_shape)
@@ -136,40 +158,137 @@ class SelfAttention(torch.nn.Module):
             if self.training:
                 output = self.output_dropout(output)
 
-            return output, None
+            return output
+
+
+class CrossAttention(torch.nn.Module):
+    """Parallel cross-attention layer for Transformer"""
+
+    def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
+                 layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, hooks={}):
+        super().__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.hooks = hooks
+        self.layer_id = layer_id
+        # Per attention head and per partition values.
+        world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
+        # Strided linear layer.
+        self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size,
+                                          gather_output=False,
+                                          init_method=init_method, bias=bias, module=self, name="query")
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size,
+                                              stride=2,
+                                              gather_output=False,
+                                              init_method=init_method, bias=bias, module=self, name="key_value")
+        # Dropout. Note that for a single iteration, this layer will generate
+        # different outputs on different numbers of parallel partitions, but
+        # on average it should not be partition dependent.
+        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+
+        # Output.
+        self.dense = RowParallelLinear(
+            self.inner_hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            init_method=output_layer_init_method, bias=bias, module=self, name="dense")
+        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
+
+    def _transpose_for_scores(self, tensor):
+        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
+        size [b, np, s, hn].
+        """
+        new_tensor_shape = tensor.size()[:-1] + \
+                           (self.num_attention_heads_per_partition,
+                            self.hidden_size_per_attention_head)
+        tensor = tensor.view(*new_tensor_shape)
+        return tensor.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
+        # hidden_states: [b, s, h]
+        if 'cross_attention_forward' in self.hooks:
+            return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
+        else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
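+            # queries are projected from the decoder hidden states; keys and values
+            # are projected from the encoder outputs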
+            mixed_query_layer = self.query(hidden_states)
+            mixed_x_layer = self.key_value(encoder_outputs)
+            (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
+
+            dropout_fn = self.attention_dropout if self.training else None
+            # Reshape and transpose [b, np, s, hn]
+            query_layer = self._transpose_for_scores(mixed_query_layer)
+            key_layer = self._transpose_for_scores(mixed_key_layer)
+            value_layer = self._transpose_for_scores(mixed_value_layer)
+
+            context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
+                                         cross_attention=True, **kw_args)
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+            # [b, s, hp]
+            context_layer = context_layer.view(*new_context_layer_shape)
+
+            # Output. [b, s, h]
+            output = self.dense(context_layer)
+            if self.training:
+                output = self.output_dropout(output)
+
+            return output
 
 
 class MLP(torch.nn.Module):
-    def __init__(self, hidden_size, output_dropout_prob, init_method,
-                output_layer_init_method=None, layer_id=None, hooks={}):
+    def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
+                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True, activation_func=gelu):
         super(MLP, self).__init__()
         self.layer_id = layer_id
+        self.activation_func = activation_func
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
-            hidden_size,
-            4*hidden_size,
+            self.hidden_size,
+            self.inner_hidden_size,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias,
+            module=self,
+            name="dense_h_to_4h"
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            4*hidden_size,
-            hidden_size,
+            self.inner_hidden_size,
+            self.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias,
+            module=self,
+            name="dense_4h_to_h"
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
 
     def forward(self, hidden_states, **kw_args):
         if 'mlp_forward' in self.hooks:
-            output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
+            output = self.hooks['mlp_forward'](hidden_states, **kw_args)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
-            intermediate_parallel = gelu(intermediate_parallel)
+            intermediate_parallel = self.activation_func(intermediate_parallel)
             output = self.dense_4h_to_h(intermediate_parallel)
 
         if self.training:
@@ -179,27 +298,34 @@ class MLP(torch.nn.Module):
 
 class BaseTransformerLayer(torch.nn.Module):
     def __init__(
-        self,
-        hidden_size,
-        num_attention_heads,
-        attention_dropout_prob,
-        output_dropout_prob,
-        layernorm_epsilon,
-        init_method,
-        layer_id,
-        output_layer_init_method=None,
-        sandwich_ln=True,
-        hooks={}
+            self,
+            hidden_size,
+            num_attention_heads,
+            attention_dropout_prob,
+            output_dropout_prob,
+            layernorm_epsilon,
+            init_method,
+            layer_id,
+            inner_hidden_size=None,
+            hidden_size_per_attention_head=None,
+            output_layer_init_method=None,
+            sandwich_ln=True,
+            layernorm=LayerNorm,
+            is_decoder=False,
+            use_bias=True,
+            activation_func=gelu,
+            hooks={}
     ):
         super(BaseTransformerLayer, self).__init__()
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
             output_layer_init_method = init_method
         self.layer_id = layer_id
+        self.is_decoder = is_decoder
         self.hooks = hooks
 
         # Layernorm on the input data.
-        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
         # Self attention.
         self.attention = SelfAttention(
@@ -209,37 +335,57 @@ class BaseTransformerLayer(torch.nn.Module):
             output_dropout_prob,
             init_method,
             layer_id,
+            hidden_size_per_attention_head=hidden_size_per_attention_head,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             hooks=hooks
         )
 
         # Layernorm on the input data.
-        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
         self.sandwich_ln = sandwich_ln
         if sandwich_ln:
-            self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
-            self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+            self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+            self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+        # Cross attention.
+        if self.is_decoder:
+            self.cross_attention = CrossAttention(
+                hidden_size,
+                num_attention_heads,
+                attention_dropout_prob,
+                output_dropout_prob,
+                init_method,
+                layer_id,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
+                output_layer_init_method=output_layer_init_method,
+                bias=use_bias,
+                hooks=hooks
+            )
+            self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
         # MLP
         self.mlp = MLP(
             hidden_size,
             output_dropout_prob,
             init_method,
+            inner_hidden_size=inner_hidden_size,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             layer_id=layer_id,
+            activation_func=activation_func,
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, *args, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
         '''
-
         # Layer norm at the begining of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
+        attention_output = self.attention(layernorm_output1, mask, **kw_args)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -249,6 +395,18 @@ class BaseTransformerLayer(torch.nn.Module):
         layernorm_input = hidden_states + attention_output
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        if self.is_decoder:
+            encoder_outputs = kw_args['encoder_outputs']
+            if encoder_outputs is not None:
+                assert 'cross_attention_mask' in kw_args
+                # Cross attention
+                attention_output = self.cross_attention(layernorm_output, **kw_args)
+                # Residual connection.
+                layernorm_input = layernorm_input + attention_output
+                # Layer norm post the cross attention
+                layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
 
@@ -259,7 +417,8 @@ class BaseTransformerLayer(torch.nn.Module):
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, output_this_layer # temporally, output_this_layer is only from attention
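+        # output_this_layer / output_cross_layer are dicts (threaded in via kw_args by
+        # BaseTransformer) that hooks can fill with extra per-layer outputs.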
+        return output, kw_args['output_this_layer'], kw_args['output_cross_layer']
+
 
 class BaseTransformer(torch.nn.Module):
     def __init__(self,
@@ -275,18 +434,26 @@ class BaseTransformer(torch.nn.Module):
                  checkpoint_num_layers=1,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
+                 inner_hidden_size=None,
+                 hidden_size_per_attention_head=None,
                  sandwich_ln=True,
                  parallel_output=True,
+                 is_decoder=False,
+                 use_bias=True,
+                 activation_func=gelu,
+                 layernorm=LayerNorm,
+                 init_method=None,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
 
         # recording parameters
+        self.is_decoder = is_decoder
         self.parallel_output = parallel_output
         self.checkpoint_activations = checkpoint_activations
         self.checkpoint_num_layers = checkpoint_num_layers
         self.max_sequence_length = max_sequence_length
-        self.hooks = copy.copy(hooks) # hooks will be updated each forward
+        self.hooks = copy.copy(hooks)  # hooks will be updated each forward
 
         # create embedding parameters
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -298,8 +465,13 @@ class BaseTransformer(torch.nn.Module):
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
         # create all layers
-        self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
-        self.init_method = unscaled_init_method(init_method_std)
+        if init_method is None:
+            self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
+            self.init_method = unscaled_init_method(init_method_std)
+        else:
+            self.output_layer_init_method = init_method
+            self.init_method = init_method
+
         def get_layer(layer_id):
             return BaseTransformerLayer(
                 hidden_size,
@@ -309,32 +481,39 @@ class BaseTransformer(torch.nn.Module):
                 layernorm_epsilon,
                 self.init_method,
                 layer_id,
+                inner_hidden_size=inner_hidden_size,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
                 output_layer_init_method=self.output_layer_init_method,
+                is_decoder=self.is_decoder,
                 sandwich_ln=sandwich_ln,
+                layernorm=layernorm,
+                use_bias=use_bias,
+                activation_func=activation_func,
                 hooks=self.hooks
-                )
+            )
+
         self.layers = torch.nn.ModuleList(
             [get_layer(layer_id) for layer_id in range(num_layers)])
 
         # Final layer norm before output.
-        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
-                **kw_args):
-        # sanity check 
+    def forward(self, input_ids, position_ids, attention_mask, *,
+                output_hidden_states=False, **kw_args):
+        # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
+        if attention_mask is None:
+            attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
+                next(self.parameters())
+            )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
-            len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-        assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
-        # branch_input is a new part of input need layer-by-layer update,
-        #   but with different hidden_dim and computational routine.
-        #   In most cases, you can just ignore it.
+               len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
 
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
-        else: # default
+        else:  # default
             hidden_states = self.word_embeddings(input_ids)
 
         if 'position_embedding_forward' in self.hooks:
@@ -343,62 +522,132 @@ class BaseTransformer(torch.nn.Module):
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
             position_embeddings = self.position_embeddings(position_ids)
-        hidden_states = hidden_states + position_embeddings
+        if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
-        hidden_states_outputs = [hidden_states] if output_hidden_states else []
-        # branch related embedding
-        if branch_input is None and 'branch_embedding_forward' in self.hooks:
-            branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
+        # initial output_cross_layer
+        if 'cross_layer_embedding_forward' in self.hooks:
+            output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
+        else:
+            output_cross_layer = {}
 
-        # define custom_forward for checkpointing
         output_per_layers = []
         if self.checkpoint_activations:
-            def custom(start, end):
+            # define custom_forward for checkpointing
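+            # checkpoint() takes only positional inputs, so keyword tensors are flattened
+            # into flat_inputs; kw_args_index / cross_layer_index record each value's slot
+            # so custom_forward can rebuild the dicts.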
+            def custom(start, end, kw_args_index, cross_layer_index):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
                     x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]
+
+                    # recover kw_args and output_cross_layer
+                    flat_inputs = inputs[2:]
+                    kw_args, output_cross_layer = {}, {}
+                    for k, idx in kw_args_index.items():
+                        kw_args[k] = flat_inputs[idx]
+                    for k, idx in cross_layer_index.items():
+                        output_cross_layer[k] = flat_inputs[idx]
+                    # -----------------
+
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
-                            x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                        if 'layer_forward' in self.hooks:
+                            layer_ret = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
+                                output_this_layer={}, output_cross_layer={}
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            layer_ret = layer(
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
+                                output_this_layer={}, output_cross_layer={}
+                            )
+                        if torch.is_tensor(layer_ret): # only hidden_states
+                            x_, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                        elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                            x_, output_this_layer = layer_ret
+                            output_cross_layer = {}
+                        elif len(layer_ret) == 3:
+                            x_, output_this_layer, output_cross_layer = layer_ret
+                        assert isinstance(output_this_layer, dict)
+                        assert isinstance(output_cross_layer, dict)
+                        if output_hidden_states:
+                            output_this_layer['hidden_states'] = x_
                         output_per_layers_part.append(output_this_layer)
-                    return x_, output_per_layers_part
+
+                    # flatten keyword outputs into a single list so they can be returned
+                    # through checkpoint(), recording each value's index for re-aggregation
+                    flat_outputs = []
+                    for output_this_layer in output_per_layers_part:
+                        for k in output_this_layer:
+                            # TODO add warning for depth>=2 grad tensors
+                            flat_outputs.append(output_this_layer[k])
+                            output_this_layer[k] = len(flat_outputs) - 1
+                    for k in output_cross_layer:
+                        flat_outputs.append(output_cross_layer[k])
+                        output_cross_layer[k] = len(flat_outputs) - 1
+                    # --------------------
+
+                    return x_, output_per_layers_part, output_cross_layer, flat_outputs
                 return custom_forward
 
+            # prevent hidden_states from losing requires_grad inside checkpointing.
+            # To save memory when only finetuning the final layers, disable checkpointing.
+            if self.training:
+                hidden_states.requires_grad_(True)
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
+            output_this_layer = []
             while l < num_layers:
                 args = [hidden_states, attention_mask]
-                if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
-                else:
-                    hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
-                if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                # flatten kw_args and output_cross_layer
+                flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
+                for k, v in kw_args.items():
+                    flat_inputs.append(v)
+                    kw_args_index[k] = len(flat_inputs) - 1
+                for k, v in output_cross_layer.items():
+                    flat_inputs.append(v)
+                    cross_layer_index[k] = len(flat_inputs) - 1
+                # --------------------
+
+                hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \
+                    checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
+
+                # recover output_per_layers_part, output_cross_layer
+                for output_this_layer in output_per_layers_part:
+                    for k in output_this_layer:
+                        output_this_layer[k] = flat_outputs[output_this_layer[k]]
+                for k in output_cross_layer:
+                    output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
+                # --------------------
+
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
+            output_this_layer = []
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask]
-                if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
-                elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
+
+                if 'layer_forward' in self.hooks: # customized layer_forward
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                        **kw_args,
+                        **output_cross_layer,
+                        output_this_layer={}, output_cross_layer={}
+                    )
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_args)
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
+                        output_this_layer={}, output_cross_layer={})
+                if torch.is_tensor(layer_ret): # only hidden_states
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                    hidden_states, output_this_layer = layer_ret
+                    output_cross_layer = {}
+                elif len(layer_ret) == 3:
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret
+
                 if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                    output_this_layer['hidden_states'] = hidden_states
                 output_per_layers.append(output_this_layer)
 
         # Final layer norm.
@@ -410,19 +659,10 @@ class BaseTransformer(torch.nn.Module):
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
 
-        # branch related embedding
-        if branch_input is None and 'branch_final_forward' in self.hooks:
-            branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
-
         if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
 
         outputs = [logits_parallel]
-        if branch_input is not None:
-            outputs.append(branch_input)
-        if output_hidden_states:
-            outputs.append(hidden_states_outputs)
         outputs.extend(output_per_layers)
 
         return outputs
-
diff --git a/SwissArmyTransformer/mpu/utils.py b/SwissArmyTransformer/mpu/utils.py
index c83f501889ccabb33cce33260f7d4fee9eadcab7..2f4f29f08227269634af831b91d2fe9b301b4e9e 100755
--- a/SwissArmyTransformer/mpu/utils.py
+++ b/SwissArmyTransformer/mpu/utils.py
@@ -75,7 +75,7 @@ def sqrt(x):
 
 def unscaled_init_method(sigma):
     """Init method based on N(0, sigma)."""
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
 
     return init_
@@ -83,7 +83,7 @@ def unscaled_init_method(sigma):
 def scaled_init_method(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=std)
 
     return init_
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index fa93f3fd575bb4049c0b3bbbddd7d66e196d1e2b..2279ee051c622f17991f0337206933d2a05c3653 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -29,7 +29,8 @@ def _export_vocab_size_to_args(args, original_num_tokens):
     print_rank_0('> padded vocab (size: {}) with {} dummy '
                  'tokens (new size: {})'.format(
         before, after - before, after))
-    args.vocab_size = after
+    if not args.vocab_size:
+        args.vocab_size = after
     print_rank_0("prepare tokenizer done")
     return tokenizer
 
@@ -51,7 +52,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             from .cogview import UnifiedTokenizer
             get_tokenizer.tokenizer = UnifiedTokenizer(
                 args.img_tokenizer_path,
-                # txt_tokenizer_type=args.tokenizer_type,
+                txt_tokenizer_type='cogview',
                 device=torch.cuda.current_device()
             )
         elif args.tokenizer_type.startswith('glm'):
@@ -63,6 +64,10 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             elif args.tokenizer_type == "glm_ChineseSPTokenizer":
                 from .glm import ChineseSPTokenizer
                 get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type.startswith('hf'):
+            from .hf_tokenizer import HFT5Tokenizer
+            if args.tokenizer_type == "hf_T5Tokenizer":
+                get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
diff --git a/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py b/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py
index ee8c907742dea9df7a715d1f44fb3ad067ddadb0..c9e731e6fdcc7fc52c6a0cb90d1c14f3ca0306fb 100755
--- a/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py
@@ -22,6 +22,8 @@ python setup.py install
 
 PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
      'embed_assets', 'chinese_sentencepiece/cog-pretrain.model')
+PRETRAINED_MODEL_FILE_ICE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
+     'embed_assets', 'chinese_sentencepiece/ice.model') # merged with 32,000 XLNet English tokens
 
 
 def get_pairs(word):
@@ -148,5 +150,10 @@ def get_encoder(encoder_file, bpe_file):
         )
 
 
-def from_pretrained():
-    return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
+def from_pretrained(tokenizer_type='cogview'):
+    if tokenizer_type == 'cogview_ICE':
+        return get_encoder(PRETRAINED_MODEL_FILE_ICE, "")
+    elif tokenizer_type == 'cogview':
+        return get_encoder(PRETRAINED_MODEL_FILE, "")
+    else:
+        raise ValueError('Unknown cogview tokenizer.')
\ No newline at end of file
diff --git a/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py b/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py
index 8a06004d332955aa4b2f1fde712b52f188bc2a59..7b17f751aa8ca1d583f4e2641f43270c8667fe13 100755
--- a/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py
@@ -20,10 +20,10 @@ from .sp_tokenizer import from_pretrained
 from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
 
 class UnifiedTokenizer(object):
-    def __init__(self, img_tokenizer_path, device):
+    def __init__(self, img_tokenizer_path, txt_tokenizer_type, device):
         self.device = device
         self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
-        self.txt_tokenizer = from_pretrained()
+        self.txt_tokenizer = from_pretrained(txt_tokenizer_type)
         self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens
         self.raw_command_tokens = [
             ('[PAD]', 0),
diff --git a/SwissArmyTransformer/tokenization/glm/tokenization.py b/SwissArmyTransformer/tokenization/glm/tokenization.py
index 67be818a0f1d42b0d35696c24577b5e5391ff3b4..9b9a8abc0202b46a95efc8a8338e90e0da9fd60f 100644
--- a/SwissArmyTransformer/tokenization/glm/tokenization.py
+++ b/SwissArmyTransformer/tokenization/glm/tokenization.py
@@ -312,11 +312,11 @@ class Tokenizer(object):
         tokenization.tokenization = [self.IdToToken(idx) for idx in tokenization.tokenization]
         return tokenization
 
-    def IdToToken(self, Id):
+    def IdToToken(self, idx):
         """convert Id to token accounting for command tokens"""
-        if isinstance(Id, CommandToken):
-            return Id.token
-        return self.tokens[Id]
+        if isinstance(idx, CommandToken):
+            return idx.token
+        return self.tokens[idx]
 
     def TokenToId(self, token):
         """convert token to Id accounting for command tokens"""
@@ -324,16 +324,16 @@ class Tokenizer(object):
             return token.Id
         return self.vocab[token]
 
-    def DecodeIds(self, Ids):
+    def DecodeIds(self, ids):
         """
         convert Ids to tokens accounting for command tokens, tokens
         are joined and returned as a string.
         """
         rtn_strs = []
         current_str = []
-        if isinstance(Ids, Tokenization):
-            Ids = Ids.tokenization
-        for Id in Ids:
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        for Id in ids:
             if isinstance(Id, CommandToken):
                 rtn_strs.append(self._decode(current_str))
                 current_str = []
@@ -353,11 +353,11 @@ class Tokenizer(object):
         output = self.clean_up_tokenization(output)
         return output
 
-    def DecodeTokens(self, Tokens):
+    def DecodeTokens(self, tokens):
         """
         convert tokens to a string accounting for command and type tokens.
         """
-        Ids = [self.TokenToId(token) for token in Tokens]
+        Ids = [self.TokenToId(token) for token in tokens]
         return self.DecodeIds(Ids)
 
 
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..790b60d56441267b9ccd883b15cc0b11ab857018
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -0,0 +1,80 @@
+from transformers import T5Tokenizer
+from .glm.tokenization import Tokenization, CommandToken
+
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small",
+    "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base",
+    "t5-large": "/dataset/fd5061f6/yanan/huggingface_models/t5-large",
+    "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b",
+    "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b"
+}
+
+class HFTokenizer:
+    def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None):
+        if model_type_or_path in PRETRAINED_VOCAB_FILES_MAP:
+            model_type_or_path = PRETRAINED_VOCAB_FILES_MAP[model_type_or_path]
+        self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir)
+        self.num_tokens = len(self.text_tokenizer)
+        self._command_tokens = []
+        self.command_name_map = {}
+        self.command_token_map = {}
+        self.command_id_map = {}
+
+    def __len__(self):
+        return len(self.text_tokenizer)
+
+    @property
+    def command_tokens(self):
+        return self._command_tokens
+
+    @command_tokens.setter
+    def command_tokens(self, command_tokens):
+        self._command_tokens = command_tokens
+        self.command_name_map = {tok.name: tok for tok in self.command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self.command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self.command_tokens}
+
+    def get_command(self, name):
+        """get command token corresponding to `name`"""
+        return self.command_name_map[name]
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        ids = self.text_tokenizer.encode(processed_text, add_special_tokens=False)
+        tokenization = Tokenization(ids, processed_text, text)
+        return tokenization
+
+    def DecodeIds(self, ids):
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        return self.text_tokenizer.decode(ids)
+
+    def DecodeTokens(self, tokens):
+        return self.text_tokenizer.convert_tokens_to_string(tokens)
+
+    def IdToToken(self, Id):
+        if isinstance(Id, CommandToken):
+            return Id.token
+        return self.text_tokenizer.convert_ids_to_tokens(Id)
+
+    def TokenToId(self, token):
+        if isinstance(token, CommandToken):
+            return token.Id
+        return self.text_tokenizer.convert_tokens_to_ids(token)
+
+
+class HFT5Tokenizer(HFTokenizer):
+    def __init__(self, model_type_or_path=None, cache_dir=None):
+        super().__init__(T5Tokenizer, model_type_or_path=model_type_or_path, cache_dir=cache_dir)
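+        # Map framework command tokens onto T5's special tokens: '<pad>' doubles as the
+        # start-of-piece ('sop') token and MASK{i} reuses T5's '<extra_id_{i}>' sentinels.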
+        command_tokens = [
+            CommandToken('eos', '</s>', self.TokenToId("</s>")),
+            CommandToken('pad', '<pad>', self.TokenToId("<pad>")),
+            CommandToken('sop', '<pad>', self.TokenToId("<pad>"))
+        ]
+        for i in range(100):
+            command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>')))
+        self.command_tokens = command_tokens
+
diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index c7860e3a9883811732a968b168894b2a668f29e4..beacffc39d8dfce4609c25fcf40c37c1d78e6c6e 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -100,14 +100,6 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
         if val_data is not None:
             start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
             val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
-    if train_data is not None:
-        train_data_iterator = iter(train_data)
-    else:
-        train_data_iterator = None
-    if val_data is not None:
-        val_data_iterator = iter(val_data)
-    else:
-        val_data_iterator = None
 
     # init hook before training
     if hooks['init_function'] is not None:
@@ -121,16 +113,12 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
                 def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
                     save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
                 iteration, skipped = train(model, optimizer,
-                                           lr_scheduler,
-                                           train_data_iterator,
-                                           val_data_iterator,
-                                           timers, args, summary_writer=summary_writer,
-                                           hooks=hooks
-                                           )
-        if args.do_valid:
-            prefix = 'the end of training for val data'
-            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
-                model, args, timers, False, hooks=hooks)
+                    lr_scheduler,
+                    train_data,
+                    val_data,
+                    timers, args, summary_writer=summary_writer,
+                    hooks=hooks
+                    )
 
     # final save
     if args.save and iteration != 0:  # TODO save
@@ -140,7 +128,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     if args.do_test and test_data is not None:
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, iter(test_data),
-            model, args, timers, True, hooks=hooks)
+            model, len(test_data) if args.strict_eval else args.eval_iters, args, timers, True, hooks=hooks)
 
 
 def get_model(args, model_cls):
@@ -156,6 +144,8 @@ def get_model(args, model_cls):
 
     if args.fp16:
         model.half()
+    elif args.bf16:
+        model.bfloat16()
     model.cuda(torch.cuda.current_device())
 
     return model
@@ -258,11 +248,21 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
 
 
 def train(model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args,
-          summary_writer=None, hooks={}):
+        train_data, val_data, timers, args,
+        summary_writer=None, hooks={}):
     """Train the model."""
+    if train_data is not None:
+        train_data_iterator = iter(train_data)
+    else:
+        train_data_iterator = None
+    if val_data is not None:
+        val_data_iterator = iter(val_data)
+    else:
+        val_data_iterator = None
+
     # Turn on training mode which enables dropout.
     model.train()
+
 
     # Tracking loss.
     total_lm_loss = 0.0
@@ -316,10 +316,14 @@ def train(model, optimizer, lr_scheduler,
 
         # Evaluation
         if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
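+            # with --strict-eval, rebuild the iterator and sweep the full validation set;
+            # otherwise evaluate args.eval_iters batches from the running iterator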
+            if args.strict_eval:
+                val_data_iterator = iter(val_data)
+                eval_iters = len(val_data)
+            else:
+                eval_iters = args.eval_iters
             prefix = 'iteration {}'.format(args.iteration)
             evaluate_and_print_results(
-                prefix, val_data_iterator, model, args, timers, False, step=args.iteration,
-                summary_writer=summary_writer, hooks=hooks)
+                prefix, val_data_iterator, model, eval_iters, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)
 
         if args.exit_interval and args.iteration % args.exit_interval == 0:
             torch.distributed.barrier()
@@ -412,21 +416,19 @@ def backward_step(optimizer, model, loss, args, timers):
 
     return
 
-
-def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
+def evaluate(data_iterator, model, eval_iters, args, timers, verbose=False, hooks={}):
     """Evaluation."""
     forward_step = hooks['forward_step']
-
     # Turn on evaluation mode which disables dropout.
     model.eval()
 
-    total_lm_loss = 0
+    total_lm_loss, metrics_total = 0, {}
     with torch.no_grad():
         iteration = 0
-        while iteration < args.eval_iters:
+        while iteration < eval_iters:
             iteration += 1
             if verbose and iteration % args.log_interval == 0:
-                print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
+                print_rank_0('Evaluating iter {}/{}'.format(iteration, eval_iters))
             # Forward evaluation.
             lm_loss, metrics = forward_step(data_iterator, model, args, timers)
             '''when contiguous memory optimizations are enabled, the buffers
@@ -436,27 +438,31 @@ def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
             if args.deepspeed and args.deepspeed_activation_checkpointing:
                 deepspeed.checkpointing.reset()
             total_lm_loss += lm_loss.data.detach().float().item()
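+            # accumulate each returned metric so it can be averaged over eval_iters below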
+            for name in metrics:
+                if name not in metrics_total:
+                    metrics_total[name] = 0.0
+                metrics_total[name] += metrics[name]
 
     # Move model back to the train mode.
     model.train()
 
-    total_lm_loss /= args.eval_iters
-    return total_lm_loss
+    total_lm_loss /= eval_iters
+    metrics_avg = {key: value / eval_iters for key, value in metrics_total.items()}
+    return total_lm_loss, metrics_avg
 
-
-def evaluate_and_print_results(prefix, data_iterator, model,
-                               args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
+def evaluate_and_print_results(prefix, data_iterator, model, eval_iters,
+                            args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
     """Helper function to evaluate and dump results on screen."""
     # import line_profiler
     # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
     # profile.enable()
     # torch.cuda.empty_cache()
-    lm_loss = evaluate(data_iterator, model, args, timers, verbose, hooks=hooks)
+    lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, verbose, hooks=hooks)
     # profile.disable()
     # import sys
     # profile.print_stats(sys.stdout)
     lm_ppl = math.exp(min(20, lm_loss))
-    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
+    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step, metrics)
 
     return lm_loss
 
@@ -480,10 +486,12 @@ def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time,
             summary_writer.add_scalar('Train/'+key, avg_metrics[key], step)
 
 
-def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
+def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step, avg_metrics):
     string = ' validation loss at {} | '.format(prefix)
     string += 'LM loss: {:.6E} | '.format(loss)
     string += 'LM PPL: {:.6E}'.format(ppl)
+    for key in avg_metrics:
+        string += ' | {} {:.6E}'.format(key, avg_metrics[key])
     length = len(string) + 1
     print_rank_0('-' * 100)
     print_rank_0('-' * length)
@@ -492,7 +500,9 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
     if summary_writer is not None:
         summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
         summary_writer.add_scalar(f'Train/valid_loss', loss, step)
-
+        for key in avg_metrics:
+            summary_writer.add_scalar('Train/valid_'+key, avg_metrics[key], step)
+
 
 '''
     Optional DeepSpeed Activation Checkpointing features
@@ -538,7 +548,8 @@ def initialize_distributed(args):
     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
         set_deepspeed_activation_checkpointing(args)  # TODO manual model-parallel seed
-
+    else:
+        mpu.get_cuda_rng_tracker = None
 
 def set_random_seed(seed):
     """Set random seed for reproducability."""
diff --git a/examples/glm/finetune_glm_sst2.py b/examples/glm/finetune_glm_sst2.py
new file mode 100755
index 0000000000000000000000000000000000000000..14ef44d9cfa606d78fa9ea461f609453b24d8e73
--- /dev/null
+++ b/examples/glm/finetune_glm_sst2.py
@@ -0,0 +1,109 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   finetune_glm_sst2.py
+@Time    :   2021/12/12 20:53:28
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# imports
+import os
+import sys
+import math
+import random
+
+import torch
+import argparse
+import numpy as np
+
+from SwissArmyTransformer import mpu, get_args, get_tokenizer
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
+from SwissArmyTransformer.training.deepspeed_training import training_main
+from SwissArmyTransformer.data_utils import TSVDataset
+from SwissArmyTransformer.model import GLMModel
+from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.model.mixins import MLPHeadMixin, PrefixTuningMixin
+
+class ClassificationModel(GLMModel):
+    def __init__(self, args, transformer=None, parallel_output=True):
+        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
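+        # Add a small MLP classification head and prefix-tuning parameters on top of the pretrained GLM.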
+        self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
+        self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
+    def disable_untrainable_params(self):
+        self.transformer.word_embeddings.requires_grad_(False)
+        # for layer_id in range(len(self.transformer.layers)):
+        #     self.transformer.layers[layer_id].requires_grad_(False)
+    
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['sentence', 'label']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['sentence'].long()
+    labels = data_b['label'].long()
+    batch_size, seq_length = tokens.size()
+    
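+    # Two-row position ids as expected by GLM; row 0 holds absolute positions, the second row is left at zero.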
+    position_ids = torch.zeros(2, seq_length, device=tokens.device, dtype=torch.long)
+    torch.arange(0, seq_length, out=position_ids[0, :seq_length])
+    position_ids = position_ids.unsqueeze(0)
+    
+    attention_mask = torch.ones((batch_size, 1, seq_length, seq_length), device=tokens.device)
+
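+    # Zero out attention to padding positions (padding token id is -1).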
+    attention_mask[...,:seq_length] -= (tokens==-1).view(batch_size, 1, 1, seq_length).float()
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+    return tokens, labels, attention_mask, position_ids, (tokens!=-1)
+
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    tokens, labels, attention_mask, position_ids, loss_mask = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+
+    logits, *mems = model(tokens, position_ids, attention_mask)
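+    # Mean-pool the per-token scalar logits over non-padding positions, then score with binary cross-entropy.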
+    pred = ((logits.contiguous().float().squeeze(-1)) * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)
+    loss = torch.nn.functional.binary_cross_entropy_with_logits(
+        pred, 
+        labels.float()
+        )
+    acc = ((pred > 0.).long() == labels).sum() / labels.numel()
+    return loss, {'acc': acc}
+
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    def process_fn(row):
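+        # Tokenize, wrap with ENC/eos commands, then truncate or pad (with -1) to sample_length.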
+        sentence, label = tokenizer._encode(row[0]), int(row[1])
+        sentence = [tokenizer.get_command('ENC').Id] + sentence + [tokenizer.get_command('eos').Id]
+        if len(sentence) >= args.sample_length:
+            sentence = sentence[:args.sample_length]
+        else:
+            sentence.extend([-1] * (args.sample_length-len(sentence)))
+        return {'sentence': np.array(sentence, dtype=np.int64), 'label': label}
+    return TSVDataset(path, process_fn, with_heads=True)
+
+if __name__ == '__main__':    
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--new_hyperparam', type=str, default=None)
+    py_parser.add_argument('--sample_length', type=int, default=80)
+    py_parser.add_argument('--prefix_len', type=int, default=16)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    # from cogdata.utils.ice_tokenizer import get_tokenizer as get_ice
+    # tokenizer = get_tokenizer(args=args, outer_tokenizer=get_ice())
+    training_main(args, model_cls=ClassificationModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
diff --git a/examples/glm/scripts/ds_config_ft.json b/examples/glm/scripts/ds_config_ft.json
new file mode 100755
index 0000000000000000000000000000000000000000..17effe048977096633ce31ef75a1cd0d9b2d813b
--- /dev/null
+++ b/examples/glm/scripts/ds_config_ft.json
@@ -0,0 +1,30 @@
+{
+  "train_micro_batch_size_per_gpu":64,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 10,
+  "gradient_clipping": 0.1,
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 400,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.00001,
+      "betas": [
+        0.9,
+        0.95
+      ],
+      "eps": 1e-8,
+      "weight_decay": 0
+    }
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "contiguous_memory_optimization": false
+  },
+  "wall_clock_breakdown": false
+}
diff --git a/examples/glm/scripts/finetune_sst2.sh b/examples/glm/scripts/finetune_sst2.sh
new file mode 100755
index 0000000000000000000000000000000000000000..4e4809a07557b96afa37b00e1ea74a5830d8608e
--- /dev/null
+++ b/examples/glm/scripts/finetune_sst2.sh
@@ -0,0 +1,62 @@
+#! /bin/bash
+
+# Change for multinode config
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=1
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+source $main_dir/config/model_glm_roberta_large.sh
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+en_data="/dataset/fd5061f6/english_data/glue_data/SST-2/train.tsv"
+eval_data="/dataset/fd5061f6/english_data/glue_data/SST-2/dev.tsv"
+
+
+config_json="$script_dir/ds_config_ft.json"
+gpt_options=" \
+       --experiment-name finetune-glm-sst2 \
+       --model-parallel-size ${MP_SIZE} \
+       --mode finetune \
+       --train-iters 6000 \
+       --resume-dataloader \
+       $MODEL_ARGS \
+       --train-data ${en_data} \
+       --valid-data ${eval_data} \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .02 \
+       --checkpoint-activations \
+       --fp16 \
+       --save-interval 6000 \
+       --eval-interval 100 \
+       --save /root/checkpoints \
+       --split 1 \
+       --strict-eval \
+       --eval-batch-size 8 
+"
+       # --load  /root/checkpoints/pretrain-bert-mid-std-fulltrain12-02-06-10
+       #  \       --sandwich-ln
+       # --split 949,50,1 \
+       # --load /root/checkpoints/pretrain-bert-mid11-28-15-38 \
+
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_glm_sst2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/examples/t5/config/config_t5_large.json b/examples/t5/config/config_t5_large.json
new file mode 100644
index 0000000000000000000000000000000000000000..1a4bffcaff366d29421f5d1d605b819457586b78
--- /dev/null
+++ b/examples/t5/config/config_t5_large.json
@@ -0,0 +1,31 @@
+{
+  "train_micro_batch_size_per_gpu": 4,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": 2,
+    "contiguous_gradients": false,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 50000000,
+    "allgather_bucket_size": 500000000
+  },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.0002,
+      "weight_decay": 0.1,
+      "betas": [
+        0.9,
+        0.98
+      ],
+      "eps": 1e-6
+    }
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "contiguous_memory_optimization": false
+  },
+  "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/examples/t5/config/model_t5_large.sh b/examples/t5/config/model_t5_large.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4e97805010050a195b521b1b08fe58c6cf1ec703
--- /dev/null
+++ b/examples/t5/config/model_t5_large.sh
@@ -0,0 +1,14 @@
+MODEL_TYPE="t5-large"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --vocab-size 32128 \
+            --num-layers 24 \
+            --hidden-size 1024 \
+            --inner-hidden-size 4096 \
+            --num-attention-heads 16 \
+            --hidden-size-per-attention-head 64 \
+            --max-sequence-length 513 \
+            --relative-attention-num-buckets 32 \
+            --layernorm-epsilon 1e-6 \
+            --tokenizer-type hf_T5Tokenizer \
+            --tokenizer-model-type t5-large"
\ No newline at end of file
diff --git a/examples/t5/finetune_t5.py b/examples/t5/finetune_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64423947cdcab706524fe766f3b000c825a6d66
--- /dev/null
+++ b/examples/t5/finetune_t5.py
@@ -0,0 +1,83 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   finetune_t5.py
+@Time    :   2021/12/11 02:39:13
+@Author  :   Ming Ding 
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# imports
+import os
+import sys
+import math
+import random
+
+import torch
+import argparse
+import numpy as np
+
+from SwissArmyTransformer import mpu, get_args, get_tokenizer
+from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
+from SwissArmyTransformer.training.deepspeed_training import training_main
+from SwissArmyTransformer.data_utils import TSVDataset
+from SwissArmyTransformer.model import T5Model
+
+def get_batch(data_iterator, args, timers):
+    # Items and their type.
+    keys = ['sentence', 'label']
+    datatype = torch.int64
+
+    # Broadcast data.
+    timers('data loader').start()
+    if data_iterator is not None:
+        data = next(data_iterator)
+    else:
+        data = None
+    timers('data loader').stop()
+    data = data[0].to('cuda')
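+    # Causal (lower-triangular) self-attention mask for the decoder.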
+    attention_mask = torch.ones((1, data.shape[-1], data.shape[-1]), device=data.device)
+    attention_mask.tril_()
+    attention_mask.unsqueeze_(1)
+
+    return data, torch.arange(data.shape[-1], device=data.device), attention_mask
+    
+
+
+def forward_step(data_iterator, model, args, timers):
+    """Forward step."""
+
+    # Get the batch.
+    timers('batch generator').start()
+    input_ids, position_ids, mask = get_batch(
+        data_iterator, args, timers)
+    timers('batch generator').stop()
+    # Forward model.
+    
+    enc, logits, *mems = model(
+        enc_input_ids=input_ids,
+        dec_input_ids=input_ids, 
+        enc_position_ids=position_ids,
+        dec_position_ids=position_ids,
+        dec_attention_mask=mask)
+    # logits, *mems = model(tokens, position_ids, attention_mask)
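+    # Dummy objective (mean of the decoder logits), used here only to exercise the forward/backward pass.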
+    loss = logits.mean()
+    return loss, {}
+
+def create_dataset_function(path, args):
+    
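+    # Placeholder dataset of constant token ids, sufficient to smoke-test the training loop.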
+    return torch.utils.data.TensorDataset(
+        torch.ones(100000, 20, dtype=torch.long)
+    )
+
+if __name__ == '__main__':    
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--new_hyperparam', type=str, default=None)
+    py_parser.add_argument('--sample_length', type=int, default=80)
+    py_parser.add_argument('--prefix_len', type=int, default=16)
+    py_parser.add_argument('--cache-dir', type=str, default='/root/some_cache',
+                           help='hf cache')
+    T5Model.add_model_specific_args(py_parser)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    training_main(args, model_cls=T5Model, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..8286aeee411a7621fd6404980b051f9e08b84033
--- /dev/null
+++ b/examples/t5/inference_t5.py
@@ -0,0 +1,79 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_t5.py
+@Time    :   2021/10/22 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# imports
+from functools import partial
+import os
+import sys
+import random
+import time
+from datetime import datetime
+import torch
+import torch.nn.functional as F
+import argparse
+import stat
+
+from SwissArmyTransformer import mpu, get_args, get_tokenizer, load_checkpoint, initialize_distributed, set_random_seed
+
+from SwissArmyTransformer.model import T5Model
+from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin
+from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity, get_masks_and_position_ids_default
+from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
+from SwissArmyTransformer.generation.utils import timed_name, generate_continually
+from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
+
+
+
+def main(args):
+    args.do_train = False
+    initialize_distributed(args)
+    tokenizer = get_tokenizer(args)
+    # build model 
+    model = T5Model(args)
+    # model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    model.eval()
+
+    # test correctness
+    input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization
+    input_ids = input_ids + [tokenizer.get_command("eos").Id]
+    input_ids = torch.tensor(input_ids, device='cuda', dtype=torch.long)
+    decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization
+    decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id]
+    decoder_input_ids = torch.tensor(decoder_input_ids, device='cuda', dtype=torch.long)
+
+    input_ids, _mask, enc_position_ids = get_masks_and_position_ids_default(input_ids)
+    
+    decoder_input_ids, dec_attention_mask, dec_position_ids = get_masks_and_position_ids_default(decoder_input_ids)
+    
+    encoder_outputs, decoder_outputs, *mems = model(
+        enc_input_ids=input_ids,
+        dec_input_ids=decoder_input_ids, 
+        dec_attention_mask=dec_attention_mask
+    )
+    breakpoint()
+    
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy',
+                           help='type name of sampling strategy')
+    py_parser.add_argument('--cache-dir', type=str, default='/root/some_cache',
+                           help='hf cache')
+    T5Model.add_model_specific_args(py_parser)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    with torch.no_grad():
+        main(args)
diff --git a/examples/t5/scripts/config_t5_tmp.json b/examples/t5/scripts/config_t5_tmp.json
new file mode 100644
index 0000000000000000000000000000000000000000..e87fa36ee0ddf4fdb7a0dbb3b92ca51019d912dc
--- /dev/null
+++ b/examples/t5/scripts/config_t5_tmp.json
@@ -0,0 +1,23 @@
+{
+    "train_micro_batch_size_per_gpu": 16,
+    "gradient_accumulation_steps": 1,
+    "steps_per_print": 100,
+    "gradient_clipping": 1.0,
+    "optimizer": {
+      "type": "Adam",
+      "params": {
+        "lr": 0.0002,
+        "weight_decay": 0.1,
+        "betas": [
+          0.9,
+          0.98
+        ],
+        "eps": 1e-6
+      }
+    },
+    "activation_checkpointing": {
+      "partition_activations": false,
+      "contiguous_memory_optimization": false
+    },
+    "wall_clock_breakdown": false
+  }
\ No newline at end of file
diff --git a/examples/t5/scripts/finetune_t5.sh b/examples/t5/scripts/finetune_t5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..365c6233b8fbd94add977deee25e2c4d8cb734ce
--- /dev/null
+++ b/examples/t5/scripts/finetune_t5.sh
@@ -0,0 +1,54 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=2
+MP_SIZE=1
+source $1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
+
+config_json="$main_dir/config/config_t5_large.json"
+gpt_options=" \
+       --experiment-name finetune-t5-test \
+       --tokenizer-type Fake \
+       --model-parallel-size ${MP_SIZE} \
+       --mode finetune \
+       $MODEL_ARGS \
+       --train-iters 200000 \
+       --resume-dataloader \
+       --train-data ${small_data} \
+       --split 1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --save-interval 5000 \
+       --eval-interval 1000 \
+       --save /root/checkpoints \
+       --fp16
+"
+       # --load pretrained/cogview/cogview-base
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_t5.py ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/examples/t5/scripts/generate_t5.sh b/examples/t5/scripts/generate_t5.sh
new file mode 100755
index 0000000000000000000000000000000000000000..f2432546348e684202911259b35a6697528d1491
--- /dev/null
+++ b/examples/t5/scripts/generate_t5.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/t5/t5-large
+
+source $1
+MPSIZE=1
+MAXSEQLEN=512
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=0.9
+# If TOPK/TOPP are 0, sampling falls back to greedy; top-k also overrides top-p.
+TOPK=40
+TOPP=0
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+config_json="$script_dir/config_t5_tmp.json"
+
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_t5.py \
+       --mode inference \
+       --model-parallel-size $MPSIZE \
+       $MODEL_ARGS \
+       --num-beams 4 \
+       --no-repeat-ngram-size 3 \
+       --length-penalty 0.7 \
+       --out-seq-length $MAXSEQLEN \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --output-path samples_glm \
+       --batch-size 2 \
+       --out-seq-length 200 \
+       --input-source ./input.txt \
+       --checkpoint-activations \
+       --sampling-strategy BeamSearchStrategy \
+       --load $CHECKPOINT_PATH
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b9d2ca84320ab87a6e0eaa3c3b8912b755e014
--- /dev/null
+++ b/examples/t5/test_t5.py
@@ -0,0 +1,12 @@
+from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 636894d95dcc22a78c3c6e44697609ce9ece39e6..f106e9f7750ba16b476d2a96ad178ed3e613e4b1 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def _requirements():
 
 setup(
     name="SwissArmyTransformer",
-    version='0.1.2',
+    version='0.1.3',
     description="A transformer-based framework with finetuning as the first class citizen.",
     long_description=Path("README.md").read_text(),
     long_description_content_type="text/markdown",