diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index cfe97ca6c120316d44c79d988f348ef78450f26e..ddcc8fdeb8f9368d55a4df59f4dd589b4ab0e024 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -6,7 +6,8 @@ from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
 
 
 class T5PositionEmbeddingMixin(BaseMixin):
@@ -142,7 +143,7 @@ class T5AttentionMixin(BaseMixin):
 
         kw_args['output_cross_layer']['position_bias'] = position_bias
 
-        return output 
+        return output
 
     def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
                                 **kw_args):
@@ -173,17 +174,63 @@ class T5AttentionMixin(BaseMixin):
 
 
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            # separate (untied) LM head; vocab-parallel so final_forward can reuse its .weight
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
 
 
+def t5_gelu(x):
+    """
+    Tanh approximation of the GELU activation used in the Google BERT repo (identical to OpenAI GPT)
+    and in T5 v1.1's gated feed-forward. See the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        # One extra up-projection per layer: the linear branch of the gate
+        # (the existing dense_h_to_4h provides the GELU branch).
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        # Gated-GELU feed-forward (as in T5 v1.1): GELU(dense_h_to_4h(x)) * gated_h_to_4h(x), projected back down.
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -205,10 +252,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
-    
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
 
     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -246,6 +302,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 3b4ae73cc6609178af67564198be9b15038ab9d2..d9d3181abbdd5ab288d1e3c81f7d8be31fd59969 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -158,7 +158,7 @@ class SelfAttention(torch.nn.Module):
             if self.training:
                 output = self.output_dropout(output)
 
-            return output
+            return output, None
 
 
 class CrossAttention(torch.nn.Module):
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
index a91956cbb69f31c079a06fc97e35ca492a40605c..41b9d2ca84320ab87a6e0eaa3c3b8912b755e014 100644
--- a/examples/t5/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -1,7 +1,12 @@
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
-tokenizer = T5Tokenizer.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-breakpoint()
\ No newline at end of file
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()  # drop into pdb before the forward/backward pass
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1
\ No newline at end of file
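
# ---------------------------------------------------------------------------
# Minimal standalone sketch (plain PyTorch, illustrative shapes only) of the two
# computations this diff introduces: the gated-GELU feed-forward used by
# T5 v1.1-style checkpoints, and the tied vs. untied output projection in
# T5DecoderFinalMixin.final_forward. The weight names below are stand-ins, not
# code from SwissArmyTransformer.
import math
import torch
import torch.nn.functional as F

hidden_size, inner_hidden_size, vocab_size = 8, 32, 100
x = torch.randn(2, 5, hidden_size)  # (batch, seq, hidden)

def gelu_new(t):
    # tanh approximation of GELU, matching t5_gelu above
    return 0.5 * t * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (t + 0.044715 * torch.pow(t, 3.0))))

# Gated-GELU MLP: GELU branch (dense_h_to_4h) gated by a linear branch (gated_h_to_4h).
w_gelu = torch.randn(inner_hidden_size, hidden_size)    # stands in for dense_h_to_4h.weight
w_linear = torch.randn(inner_hidden_size, hidden_size)  # stands in for gated_h_to_4h.weight
w_out = torch.randn(hidden_size, inner_hidden_size)     # stands in for dense_4h_to_h.weight
h = gelu_new(F.linear(x, w_gelu)) * F.linear(x, w_linear)
mlp_out = F.linear(h, w_out)

# Output projection in final_forward:
embedding = torch.randn(vocab_size, hidden_size)  # word_embeddings.weight or lm_head.weight
tie_word_embeddings = True
if tie_word_embeddings:
    # tied embeddings: rescale by hidden_size**-0.5, then reuse the input embedding matrix
    logits = F.linear(x * hidden_size ** -0.5, embedding)
else:
    # untied (--no-share-embeddings): a separate lm_head weight, no rescaling
    logits = F.linear(x, embedding)

print(mlp_out.shape, logits.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 100])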