diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 531ec54b0062fc616d46a42652a8d486a4659e33..7143950772587d88a45960668215a7c604ea1f36 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
-from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
 
@@ -28,9 +28,11 @@ class T5LayerNorm(torch.nn.Module):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into float16 if necessary
+        # convert into float16 or bfloat16 if necessary
         if self.weight.dtype == torch.float16:
             hidden_states = hidden_states.to(torch.float16)
+        elif self.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.bfloat16)
         return self.weight * hidden_states
 
 
@@ -111,7 +113,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args):
+    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
         attn_module = self.transformer.layers[layer_id].attention
         seq_length = hidden_states.size(1)
         memory_length = mems[layer_id].size(1) if mems else 0
@@ -126,7 +128,8 @@ class T5AttentionMixin(BaseMixin):
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
 
-        position_bias = self.compute_bias(seq_length, memory_length + seq_length)
+        if position_bias is None:
+            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
         context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
                                            log_attention_weights=position_bias, scaling_attention_score=False)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
@@ -137,9 +140,10 @@ class T5AttentionMixin(BaseMixin):
         if attn_module.training:
             output = attn_module.output_dropout(output)
 
-        return output, None
+        return output, position_bias
 
-    def cross_attention_forward(self, hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args):
+    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
+                                **kw_args):
         attn_module = self.transformer.layers[layer_id].cross_attention
         mixed_query_layer = attn_module.query(hidden_states)
         mixed_x_layer = attn_module.key_value(encoder_outputs)
@@ -151,7 +155,7 @@ class T5AttentionMixin(BaseMixin):
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
 
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn,
+        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
                                            scaling_attention_score=False)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
@@ -180,8 +184,10 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
+        self.init_method_std = args.init_method_std
         super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
-        layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu,
+                         init_method=self._init_weights)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -201,10 +207,42 @@ class T5Model(EncoderDecoderModel):
         )
         del self.decoder.transformer.position_embeddings
 
+    def _init_weights(self, weight, module, name):
+        init_method_std = self.init_method_std
+        if isinstance(module, MLP):
+            if name == "dense_h_to_4h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense_4h_to_h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, SelfAttention):
+            if name == "query_key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+                torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, CrossAttention):
+            if name == "query":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        else:
+            raise NotImplementedError(module)
+
     @classmethod
     def add_model_specific_args(cls, parser):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
+        parser.add_argument("--init-method-std", type=float, default=0.02)
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
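
For reference, a minimal sketch of the fan-in scaling that T5Model._init_weights encodes, using the attributes introduced above (hidden_size, inner_hidden_size, hidden_size_per_attention_head); the concrete sizes are illustrative, not from a real config:

    import torch

    init_method_std = 0.02
    hidden_size = 512                 # d_model
    inner_hidden_size = 2048          # MLP inner size; attention uses num_heads * d_kv instead
    d_kv = 64                         # hidden_size_per_attention_head

    def t5_normal_(weight, fan_in):
        # N(0, std) with std shrunk by the effective fan-in, as in _init_weights
        return torch.nn.init.normal_(weight, mean=0.0, std=init_method_std * fan_in ** -0.5)

    w_in = torch.empty(inner_hidden_size, hidden_size)
    t5_normal_(w_in, hidden_size)               # dense_h_to_4h, key_value, and the k/v rows of query_key_value
    w_out = torch.empty(hidden_size, inner_hidden_size)
    t5_normal_(w_out, inner_hidden_size)        # dense_4h_to_h and the attention output "dense" layers
    w_q = torch.empty(inner_hidden_size, hidden_size)
    t5_normal_(w_q, hidden_size * d_kv)         # CrossAttention.query and the query rows of query_key_value
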
diff --git a/SwissArmyTransformer/mpu/layers.py b/SwissArmyTransformer/mpu/layers.py
index 4af3f33180aae46bba28efa213a9a941df5cd3ce..8d7cf370e75f2629ae96b1b2f442f4dfe6bae70f 100755
--- a/SwissArmyTransformer/mpu/layers.py
+++ b/SwissArmyTransformer/mpu/layers.py
@@ -37,7 +37,7 @@ from .utils import VocabUtility
 
 def _initialize_affine_weight(weight, output_size, input_size,
                               per_partition_size, partition_dim, init_method,
-                              stride=1, return_master_weight=False):
+                              stride=1, return_master_weight=False, module=None, name=None):
     """Initialize affine weight for model parallel.
 
     Build the master weight on all processes and scatter
@@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
     if world_size == 1:
-        init_method(weight)
+        init_method(weight, module=module, name=name)
         if return_master_weight:
             return weight
         return None
@@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     master_weight = torch.empty(output_size, input_size,
                                 dtype=weight.dtype,
                                 requires_grad=False)
-    init_method(master_weight)
+    init_method(master_weight, module=module, name=name)
 
     # Split and copy
     per_partition_per_stride_size = divide(per_partition_size, stride)
@@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module):
     """
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
         super(ColumnParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.output_size_per_partition, 0, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
@@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module):
     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
         super(RowParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.input_size_per_partition, 1, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
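
Because ColumnParallelLinear and RowParallelLinear now forward module= and name= into _initialize_affine_weight, an init_method must tolerate the extra keywords. A small sketch of the two styles this enables (std values are illustrative); the updated helpers in mpu/utils.py below take the first form:

    import torch

    # 1) Absorb the extra keywords, as unscaled_init_method / scaled_init_method now do.
    def plain_init_(tensor, **kwargs):
        return torch.nn.init.normal_(tensor, mean=0.0, std=0.02)

    # 2) Branch on which submodule/projection is being initialized, as T5Model._init_weights does;
    #    module and name are whatever the parallel linear layer was constructed with (both default to None).
    def name_aware_init_(tensor, module=None, name=None):
        std = 0.02
        if module is not None and name == "dense_4h_to_h":
            std = 0.02 * module.inner_hidden_size ** -0.5
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
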
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 855e093ada69e1168b565f78f15eeaee313326e5..cab65a4623060d3f362a6edc31d8d99a51e734ee 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -22,11 +22,12 @@ import torch
 import torch.nn.functional as F
 from apex.normalization.fused_layer_norm import FusedLayerNorm
 
+from SwissArmyTransformer import mpu
 from .initialize import get_model_parallel_world_size
 from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
 from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
-from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
+from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
 
 from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
 from .utils import split_tensor_along_last_dim
@@ -62,7 +63,10 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     attention_probs = F.softmax(attention_scores, dim=-1)
 
     if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
+        if mpu.get_cuda_rng_tracker is not None:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = attention_dropout(attention_probs)
+        else:
             attention_probs = attention_dropout(attention_probs)
 
     context_layer = torch.matmul(attention_probs, value_layer)
@@ -82,31 +86,36 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
         if hidden_size_per_attention_head is None:
             self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
         else:
             self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
         self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3 * inner_hidden_size,
+            3 * self.inner_hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="query_key_value"
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
 
         self.dense = RowParallelLinear(
-            inner_hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense"
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -122,7 +131,7 @@ class SelfAttention(torch.nn.Module):
 
     def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, *args, **kw_args, layer_id=self.layer_id)
         else:
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
@@ -144,7 +153,7 @@ class SelfAttention(torch.nn.Module):
             if self.training:
                 output = self.output_dropout(output)
 
-            return output, None
+            return output,  # one-element tuple: callers star-unpack any extra per-layer outputs
 
 
 class CrossAttention(torch.nn.Module):
@@ -160,21 +169,22 @@ class CrossAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
         if hidden_size_per_attention_head is None:
             self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
         else:
             self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
         self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
         # Strided linear layer.
-        self.query = ColumnParallelLinear(hidden_size, inner_hidden_size,
+        self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size,
                                           gather_output=False,
-                                          init_method=init_method, bias=bias)
-        self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size,
+                                          init_method=init_method, bias=bias, module=self, name="query")
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size,
                                               stride=2,
                                               gather_output=False,
-                                              init_method=init_method, bias=bias)
+                                              init_method=init_method, bias=bias, module=self, name="key_value")
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
@@ -182,10 +192,10 @@ class CrossAttention(torch.nn.Module):
 
         # Output.
         self.dense = RowParallelLinear(
-            inner_hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method, bias=bias)
+            init_method=output_layer_init_method, bias=bias, module=self, name="dense")
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
     def _transpose_for_scores(self, tensor):
@@ -198,7 +208,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -240,22 +250,28 @@ class MLP(torch.nn.Module):
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        self.hidden_size = hidden_size
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
-            hidden_size,
-            inner_hidden_size,
+            self.hidden_size,
+            self.inner_hidden_size,
             gather_output=False,
             init_method=init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense_h_to_4h"
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            inner_hidden_size,
-            hidden_size,
+            self.inner_hidden_size,
+            self.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense_4h_to_h"
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -353,7 +369,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, *args, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -362,7 +378,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm at the begining of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
+        attention_output, *output_this_layer = self.attention(layernorm_output1, mask, *args, **kw_args)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -374,15 +390,14 @@ class BaseTransformerLayer(torch.nn.Module):
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
         # only for Encoder-Decoder, omit this for BERT-like or GPT-like models
-        if self.is_decoder and \
-            'encoder_outputs' in kw_args and kw_args['encoder_outputs'] is not None:
-                assert 'cross_attention_mask' in kw_args
-                # Cross attention
-                attention_output = self.cross_attention(layernorm_output, **kw_args)
-                # Residual connection.
-                layernorm_input = layernorm_input + attention_output
-                # Layer norm post the cross attention
-                layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+        if self.is_decoder and encoder_outputs is not None:
+            assert 'cross_attention_mask' in kw_args
+            # Cross attention
+            attention_output = self.cross_attention(layernorm_output, encoder_outputs=encoder_outputs, **kw_args)
+            # Residual connection.
+            layernorm_input = layernorm_input + attention_output
+            # Layer norm post the cross attention
+            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
 
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
@@ -394,7 +409,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, output_this_layer  # temporally, output_this_layer is only from attention
+        return output, *output_this_layer  # temporarily, output_this_layer is only from attention
 
 
 class BaseTransformer(torch.nn.Module):
@@ -419,6 +434,7 @@ class BaseTransformer(torch.nn.Module):
                  use_bias=True,
                  activation_func=gelu,
                  layernorm=LayerNorm,
+                 init_method=None,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
@@ -441,8 +457,12 @@ class BaseTransformer(torch.nn.Module):
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
         # create all layers
-        self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
-        self.init_method = unscaled_init_method(init_method_std)
+        if init_method is None:
+            self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
+            self.init_method = unscaled_init_method(init_method_std)
+        else:
+            self.output_layer_init_method = init_method
+            self.init_method = init_method
 
         def get_layer(layer_id):
             return BaseTransformerLayer(
@@ -470,7 +490,7 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, 
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
                 output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
@@ -478,7 +498,7 @@ class BaseTransformer(torch.nn.Module):
         if attention_mask is None:
             attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
                 next(self.parameters())
-            ) # None means full attention
+            )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
         assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
@@ -513,23 +533,19 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]
+                    x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
+                    output_this_layer = inputs[3:]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
-                            x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                        if 'layer_forward' in self.hooks:
+                            x_, *output_this_layer = self.hooks['layer_forward'](
+                                x_, mask, encoder_outputs_, *output_this_layer, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            x_, *output_this_layer = layer(x_, mask, encoder_outputs_, *output_this_layer, **kw_args)
                         output_per_layers_part.append(output_this_layer)
-                    return x_, output_per_layers_part
+                    return x_, output_per_layers_part, *output_this_layer
+
                 return custom_forward
 
             # prevent to lose requires_grad in checkpointing.
@@ -539,25 +555,24 @@ class BaseTransformer(torch.nn.Module):
 
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
+            output_this_layer = []
             while l < num_layers:
-                args = [hidden_states, attention_mask]
-                if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
-                else:
-                    hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
+                args = [hidden_states, attention_mask, encoder_outputs]
+                hidden_states, output_per_layers_part, *output_this_layer = checkpoint(custom(l, l + chunk_length),
+                                                                                       *args, *output_this_layer)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
+            output_this_layer = []
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask]
-                if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
-                elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
+                args = [hidden_states, attention_mask, encoder_outputs]
+                if 'layer_forward' in self.hooks:  # customized layer_forward
+                    hidden_states, *output_this_layer = self.hooks['layer_forward'](*args, *output_this_layer,
+                                                                                    layer_id=torch.tensor(i), **kw_args)
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_args)
+                    hidden_states, *output_this_layer = layer(*args, *output_this_layer, **kw_args)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
                 output_per_layers.append(output_this_layer)
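
The star-unpacking above defines a small contract: any extra values a layer (or an attention_forward / layer_forward hook) returns after the hidden states are fed back positionally on the next call, which is how T5AttentionMixin computes position_bias once in the first layer and reuses it afterwards. A reduced, self-contained sketch of that contract (the hook body and tensor shapes are illustrative):

    import torch

    def attention_forward(hidden_states, mask, position_bias=None, *, layer_id=None):
        if position_bias is None:                                 # computed once, in the first layer
            seq_len = hidden_states.size(1)
            position_bias = torch.zeros(1, 8, seq_len, seq_len)   # (1, num_heads, q_len, k_len)
        # ... attention that uses position_bias as log_attention_weights ...
        return hidden_states, position_bias                       # extras are re-fed to later layers

    hidden = torch.randn(2, 16, 512)
    mask = torch.ones(1, 1)
    output_this_layer = []                                         # extras carried between layers
    for layer_id in range(3):
        hidden, *output_this_layer = attention_forward(hidden, mask, *output_this_layer,
                                                       layer_id=layer_id)
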
diff --git a/SwissArmyTransformer/mpu/utils.py b/SwissArmyTransformer/mpu/utils.py
index c83f501889ccabb33cce33260f7d4fee9eadcab7..2f4f29f08227269634af831b91d2fe9b301b4e9e 100755
--- a/SwissArmyTransformer/mpu/utils.py
+++ b/SwissArmyTransformer/mpu/utils.py
@@ -75,7 +75,7 @@ def sqrt(x):
 
 def unscaled_init_method(sigma):
     """Init method based on N(0, sigma)."""
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
 
     return init_
@@ -83,7 +83,7 @@ def unscaled_init_method(sigma):
 def scaled_init_method(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=std)
 
     return init_
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index ffdf5c354cfe6985d1aff9415079ab1c02dcacd7..2279ee051c622f17991f0337206933d2a05c3653 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -67,7 +67,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
         elif args.tokenizer_type.startswith('hf'):
             from .hf_tokenizer import HFT5Tokenizer
             if args.tokenizer_type == "hf_T5Tokenizer":
-                get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type)
+                get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
index d67197662cf200acbda1299bce1e9199c1e34862..e40adb47a4ac8cb876132dc5cefc27f824442baf 100644
--- a/SwissArmyTransformer/tokenization/hf_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -2,8 +2,18 @@ from transformers import T5Tokenizer
 from .glm.tokenization import Tokenization, CommandToken
 
 
+PRETRAINED_VOCAB_FILES_MAP = {
+    "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small",
+    "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base",
+    "t5-large": "/mnt/t5",
+    "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b",
+    "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b"
+}
+
 class HFTokenizer:
     def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None):
+        if model_type_or_path in PRETRAINED_VOCAB_FILES_MAP:
+            model_type_or_path = PRETRAINED_VOCAB_FILES_MAP[model_type_or_path]
         self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir)
         self.num_tokens = len(self.text_tokenizer)
         self._command_tokens = []
@@ -11,6 +21,9 @@ class HFTokenizer:
         self.command_token_map = {}
         self.command_id_map = {}
 
+    def __len__(self):
+        return len(self.text_tokenizer)
+
     @property
     def command_tokens(self):
         return self._command_tokens
@@ -59,7 +72,9 @@ class HFT5Tokenizer(HFTokenizer):
         command_tokens = [
             CommandToken('eos', '</s>', self.TokenToId("</s>")),
             CommandToken('pad', '<pad>', self.TokenToId("<pad>")),
+            CommandToken('sop', '<pad>', self.TokenToId("<pad>"))
         ]
         for i in range(100):
             command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>')))
         self.command_tokens = command_tokens
+
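
A usage sketch for the updated tokenizer wrapper; the cache path is a placeholder, and identifiers found in PRETRAINED_VOCAB_FILES_MAP are remapped to the hard-coded local directories before from_pretrained is called:

    from SwissArmyTransformer.tokenization.hf_tokenizer import HFT5Tokenizer

    # cache_dir is now threaded through get_tokenizer via args.cache_dir; here it is passed directly.
    # On machines without the hard-coded paths, pass an identifier or directory that is not in the map.
    tok = HFT5Tokenizer("t5-large", cache_dir="/tmp/hf_cache")
    print(len(tok))                # __len__ delegates to the underlying HF tokenizer
    print(tok.command_tokens[:3])  # eos, pad, and the newly added sop (also mapped to <pad>)
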
diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index d9059ac4d1f282e3792b3a877af76d81c8531b7c..beacffc39d8dfce4609c25fcf40c37c1d78e6c6e 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -144,6 +144,8 @@ def get_model(args, model_cls):
 
     if args.fp16:
         model.half()
+    elif args.bf16:
+        model.bfloat16()
     model.cuda(torch.cuda.current_device())
 
     return model
@@ -546,7 +548,8 @@ def initialize_distributed(args):
     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
         set_deepspeed_activation_checkpointing(args)  # TODO manual model-parallel seed
-
+    else:
+        mpu.get_cuda_rng_tracker = None
 
 def set_random_seed(seed):
     """Set random seed for reproducability."""