diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 5e5793e071e5cec6308f81093ab374f11e737961..8dfb9899e85d1d6850d6b1028287460e846d5d53 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -87,7 +87,7 @@ class BaseModel(torch.nn.Module):
 
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
-                 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
+                 'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
                  'branch_embedding_forward', 'branch_final_forward'
                  ]
         hooks = {}
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 1beace877aaa5aaa7bbe84fc8ffea500fa03bd72..a31fa6e886f53201f0e00c0c81f919d4ecbac0f0 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -56,6 +56,8 @@ def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch
                          causal_mask], axis=-1)
 
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            else:
+                extended_attention_mask = causal_mask[:, None, :, :]
         else:
             if attention_mask is None:
                 extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
@@ -108,16 +110,20 @@ class EncoderDecoderModel(torch.nn.Module):
         self.decoder.disable_untrainable_params()
 
     def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
-                decoder_position_ids=None, decoder_attention_mask=None,
+                decoder_position_ids=None, decoder_attention_mask=None, encoder_outputs=None,
                 **kw_args):
         dtype = self.encoder.transformer.word_embeddings.weight.dtype
-        batch_size, encoder_seq_length = input_ids.size()[:2]
+        if encoder_outputs is None:
+            batch_size, encoder_seq_length = input_ids.size()[:2]
+        else:
+            batch_size, encoder_seq_length = encoder_outputs.size()[:2]
         encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
                                                              device=input_ids.device, dtype=dtype)
         decoder_seq_length = decoder_input_ids.size(1)
-        encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
+        if encoder_outputs is None:
+            encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
         decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype)
+                                                             device=input_ids.device, dtype=dtype, is_decoder=True)
         decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
                                                       encoder_outputs=encoder_outputs,
                                                       cross_attention_mask=encoder_attention_mask, **kw_args)
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index bf26a3b81bfa41c1978aaf8bb9ab17a921cb0384..bf4f2c9ff0fbfdcd56d024dcb4faab3f5fe89a8a 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -1,9 +1,11 @@
 import math
 import torch
+import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
 
 
@@ -164,9 +166,22 @@ class T5AttentionMixin(BaseMixin):
         return output
 
 
+class T5DecoderFinalMixin(BaseMixin):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+    def final_forward(self, logits, **kwargs):
+        logits_parallel = copy_to_model_parallel_region(logits)
+        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        return logits_parallel
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm)
+        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm,
+                         activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -181,6 +196,9 @@ class T5Model(EncoderDecoderModel):
         self.decoder.add_mixin(
             "t5-position", T5PositionEmbeddingMixin()
         )
+        self.decoder.add_mixin(
+            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+        )
         del self.decoder.transformer.position_embeddings
         self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
 
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 7a3234b4340f45f52f74082bb34cc86a25818684..d58145552635d26e245382d3c7585ef825db723f 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,11 +198,11 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, encoder_outputs, *args, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
-            return self.hooks['cross_attention_forward'](hidden_states, encoder_outputs,
-                                                         cross_attention_mask=cross_attention_mask, **kw_args,
+            return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
+                                                         **kw_args,
                                                          layer_id=self.layer_id)
         else:
             mixed_query_layer = self.query(hidden_states)
@@ -231,9 +231,10 @@ class CrossAttention(torch.nn.Module):
 
 class MLP(torch.nn.Module):
     def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
-                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True):
+                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True, activation_func=gelu):
         super(MLP, self).__init__()
         self.layer_id = layer_id
+        self.activation_func = activation_func
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
             output_layer_init_method = init_method
@@ -263,7 +264,7 @@ class MLP(torch.nn.Module):
             output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
-            intermediate_parallel = gelu(intermediate_parallel)
+            intermediate_parallel = self.activation_func(intermediate_parallel)
             output = self.dense_4h_to_h(intermediate_parallel)
 
         if self.training:
@@ -288,6 +289,7 @@ class BaseTransformerLayer(torch.nn.Module):
             layernorm=LayerNorm,
             is_decoder=False,
             use_bias=True,
+            activation_func=gelu,
             hooks={}
     ):
         super(BaseTransformerLayer, self).__init__()
@@ -347,10 +349,11 @@ class BaseTransformerLayer(torch.nn.Module):
             output_layer_init_method=output_layer_init_method,
             bias=use_bias,
             layer_id=layer_id,
+            activation_func=activation_func,
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, encoder_outputs=None, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -372,9 +375,9 @@ class BaseTransformerLayer(torch.nn.Module):
 
         if self.is_decoder and encoder_outputs is not None:
             # Cross attention
-            attention_output = self.cross_attention(layernorm_output, encoder_outputs, **kw_args)
+            attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
             # Residual connection.
-            layernorm_input = layernorm_output + attention_output
+            layernorm_input = layernorm_input + attention_output
             # Layer norm post the cross attention
             layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
 
@@ -411,6 +414,7 @@ class BaseTransformer(torch.nn.Module):
                  parallel_output=True,
                  is_decoder=False,
                  use_bias=True,
+                 activation_func=gelu,
                  layernorm=LayerNorm,
                  hooks={}
                  ):
@@ -453,6 +457,7 @@ class BaseTransformer(torch.nn.Module):
                 sandwich_ln=sandwich_ln,
                 layernorm=layernorm,
                 use_bias=use_bias,
+                activation_func=activation_func,
                 hooks=self.hooks
             )
 
@@ -464,7 +469,7 @@ class BaseTransformer(torch.nn.Module):
 
     def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
                 output_hidden_states=False, **kw_args):
-        # sanity check 
+        # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
         assert len(attention_mask.shape) == 2 or \
@@ -504,7 +509,11 @@ class BaseTransformer(torch.nn.Module):
                     x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if 'layer_forward' in self.hooks:
+                        if branch_input is not None:
+                            x_, encoder_outputs_, output_this_layer = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id, branch_input=encoder_outputs_, **kw_args
+                            )
+                        elif 'layer_forward' in self.hooks:
                             x_, output_this_layer = self.hooks['layer_forward'](
                                 x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
                             )
@@ -518,11 +527,11 @@ class BaseTransformer(torch.nn.Module):
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
-                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args,
-                                                                                     branch_input)
+                    args = [hidden_states, attention_mask, branch_input]
+                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
                 else:
+                    args = [hidden_states, attention_mask, encoder_outputs]
                     hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)