diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index ed25fe48aaa3ef83b6c784429b229d40a22b7020..8caed663b27e4b6c1b4382090ddd754a1c875d3e 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -21,20 +21,22 @@ class CachedAutoregressiveMixin(BaseMixin):
         super().__init__()     
            
     @non_conflict
-    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
-        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
-        b, nh, seq_len, hidden_size = k.shape
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+                     **kw_args):
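+        # cross-attention keys/values come from the fixed encoder output and do not grow across decoding steps, so only self-attention is cached here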
+        if not cross_attention:
+            mem = mems[kw_args['layer_id']] if mems is not None else None  # [batch, memory_length, 2 * nh * hidden_size]
+            b, nh, seq_len, hidden_size = k.shape
 
-        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
-        kw_args['output_this_layer']['mem_kv'] = cache_kv
-        
-        if mem is not None: # the first time, mem is None
-            # might change batch_size
-            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4) 
-            memk, memv = mem[0], mem[1]
-            k = torch.cat((memk, k), dim=2)
-            v = torch.cat((memv, v), dim=2)
-        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
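+            # flatten (k, v) into [batch, seq_len, 2 * nh * hidden_size] so it can be returned as this layer's cache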
+            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+            kw_args['output_this_layer']['mem_kv'] = cache_kv
+
+            if mem is not None: # the first time, mem is None
+                # the batch size may have changed since the cache was built; expand broadcasts mem to the current batch
+                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+                memk, memv = mem[0], mem[1]
+                k = torch.cat((memk, k), dim=2)
+                v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
 
 
 class CachedAutoregressiveModel(BaseModel):
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index a5f13760351ce27da585f96960493421b4378dc5..d1a50951ea352babcd763a0178ed2de2013695dc 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -3,10 +3,12 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
 
 
 class T5PositionEmbeddingMixin(BaseMixin):
@@ -94,7 +96,7 @@ class T5AttentionMixin(BaseMixin):
         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, cross_attention=False):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
@@ -106,84 +108,88 @@ class T5AttentionMixin(BaseMixin):
         )
         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         # shape (query_length, key_length, num_heads)
-        if cross_attention:
-            values = self.cross_relative_attention_bias(relative_position_bucket)
-        else:
-            values = self.relative_attention_bias(relative_position_bucket)
+        values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
-        attn_module = self.transformer.layers[layer_id].attention
-        seq_length = hidden_states.size(1)
-        memory_length = mems[layer_id].size(1) if mems else 0
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        if position_bias is None:
-            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
-                                           log_attention_weights=position_bias, scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        kw_args['output_cross_layer']['position_bias'] = position_bias
-
-        return output 
-
-    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
-                                **kw_args):
-        attn_module = self.transformer.layers[layer_id].cross_attention
-        mixed_query_layer = attn_module.query(hidden_states)
-        mixed_x_layer = attn_module.key_value(encoder_outputs)
-        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
-                                           scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [b, s, h]
-        output = attn_module.dense(context_layer)
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        return output
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
+                     cross_attention=False, **kw_args):
+        log_attention_weights = None
+        if not cross_attention:
+            if position_bias is None:
+                seq_length = q.size(2)
+                key_length = k.size(2)
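+                # build the bias over the full key length, then keep only the rows for the current queries (q can be shorter than k during cached decoding)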
+                position_bias = self.compute_bias(key_length, key_length)
+                position_bias = position_bias[:, :, -seq_length:, :]
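+            # pass the bias to subsequent layers via the cross-layer outputs so it is computed only once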
+            kw_args['output_cross_layer']['position_bias'] = position_bias
+            log_attention_weights = position_bias
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
+                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
 
 
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
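+        # untied checkpoints (e.g. T5 v1.1 / LM-adapted) use a separate lm_head instead of reusing the input embedding matrix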
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
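+            # with tied embeddings, T5 rescales by d_model ** -0.5 before projecting with the shared embedding weight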
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
 
 
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
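+        # one extra gate projection per layer (corresponding to wi_1 in Hugging Face's gated T5 MLP)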
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
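+        # gated-GELU feed-forward: GELU(dense_h_to_4h(x)) * gated_h_to_4h(x), then project back with dense_4h_to_h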
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -205,10 +211,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
-    
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
 
     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -246,6 +261,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
@@ -254,7 +271,7 @@ class T5Model(EncoderDecoderModel):
         return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
                               cross_attention_mask=cross_attention_mask, **kw_args)
 
-    def forward(self, enc_input_ids, dec_input_ids, dec_attention_mask, *, enc_attention_mask=None,
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
                 cross_attention_mask=None, **kw_args):
         batch_size, seq_length = enc_input_ids.size()[:2]
         if enc_attention_mask is None:
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 98e42a5ac7e4a9dc3050fc677e90dc6d8f11de6b..289056948feff75964127fd5a2e4b0f4e8ed3944 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -218,6 +218,10 @@ class CrossAttention(torch.nn.Module):
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
         else:
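+            # let mixins override the attention computation via the 'attention_fn' hook, falling back to standard_attention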
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_query_layer = self.query(hidden_states)
             mixed_x_layer = self.key_value(encoder_outputs)
             (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
@@ -228,7 +232,8 @@ class CrossAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
+                                         cross_attention=True, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             # [b, s, hp]
@@ -394,7 +399,7 @@ class BaseTransformerLayer(torch.nn.Module):
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask = kw_args['cross_attention_mask']
+                assert 'cross_attention_mask' in kw_args
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, **kw_args)
                 # Residual connection.
@@ -504,7 +509,7 @@ class BaseTransformer(torch.nn.Module):
             )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-    
+
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -526,7 +531,7 @@ class BaseTransformer(torch.nn.Module):
             output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
         else:
             output_cross_layer = {}
-            
+
         output_per_layers = []
         if self.checkpoint_activations:
             # define custom_forward for checkpointing
@@ -534,7 +539,7 @@ class BaseTransformer(torch.nn.Module):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
                     x_, mask = inputs[0], inputs[1]
-                    
+
                     # recover kw_args and output_cross_layer
                     flat_inputs = inputs[2:]
                     kw_args, output_cross_layer = {}, {}
@@ -543,19 +548,19 @@ class BaseTransformer(torch.nn.Module):
                     for k, idx in cross_layer_index.items():
                         output_cross_layer[k] = flat_inputs[idx]
                     # -----------------
-                    
+
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
                         if 'layer_forward' in self.hooks:
                             layer_ret = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, 
-                                **kw_args, **output_cross_layer, 
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
                                 output_this_layer={}, output_cross_layer={}
                             )
                         else:
                             layer_ret = layer(
-                                x_, mask, layer_id=layer.layer_id, 
-                                **kw_args, **output_cross_layer, 
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
                                 output_this_layer={}, output_cross_layer={}
                             )
                         if torch.is_tensor(layer_ret): # only hidden_states
@@ -563,7 +568,7 @@ class BaseTransformer(torch.nn.Module):
                         elif len(layer_ret) == 2: # hidden_states & output_this_layer
                             x_, output_this_layer = layer_ret
                             output_cross_layer = {}
-                        elif len(layer_ret) == 3: 
+                        elif len(layer_ret) == 3:
                             x_, output_this_layer, output_cross_layer = layer_ret
                         assert isinstance(output_this_layer, dict)
                         assert isinstance(output_cross_layer, dict)
@@ -590,7 +595,7 @@ class BaseTransformer(torch.nn.Module):
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
-            
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             output_this_layer = []
@@ -625,20 +630,20 @@ class BaseTransformer(torch.nn.Module):
                 args = [hidden_states, attention_mask]
 
                 if 'layer_forward' in self.hooks: # customized layer_forward
-                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), 
-                        **kw_args, 
-                        **output_cross_layer, 
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                        **kw_args,
+                        **output_cross_layer,
                         output_this_layer={}, output_cross_layer={}
                     )
                 else:
-                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, 
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
                         output_this_layer={}, output_cross_layer={})
                 if torch.is_tensor(layer_ret): # only hidden_states
                     hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
                 elif len(layer_ret) == 2: # hidden_states & output_this_layer
                     hidden_states, output_this_layer = layer_ret
                     output_cross_layer = {}
-                elif len(layer_ret) == 3: 
+                elif len(layer_ret) == 3:
                     hidden_states, output_this_layer, output_cross_layer = layer_ret
 
                 if output_hidden_states:
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
index a91956cbb69f31c079a06fc97e35ca492a40605c..41b9d2ca84320ab87a6e0eaa3c3b8912b755e014 100644
--- a/examples/t5/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -1,7 +1,12 @@
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
-tokenizer = T5Tokenizer.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-breakpoint()
\ No newline at end of file
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1  # no-op final statement, handy as a breakpoint target after backward()
\ No newline at end of file