diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index 5eafe851c4d6064d802532c2b822efe4cefc42a1..804bae08edbd705097b8fb2828443b87066e8a9c 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -300,25 +300,28 @@ def get_args(args_list=None):
         print('using world size: {} and model-parallel size: {} '.format(
             args.world_size, args.model_parallel_size))
 
-    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
-        with open(args.deepspeed_config) as file:
-            deepspeed_config = json.load(file)
-        if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
-            args.fp16 = True
-        else:
-            args.fp16 = False
+    if hasattr(args, "deepspeed") and args.deepspeed:
         if args.checkpoint_activations:
             args.deepspeed_activation_checkpointing = True
-        if "train_micro_batch_size_per_gpu" in deepspeed_config:
-            args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
-        if "gradient_accumulation_steps" in deepspeed_config:
-            args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
         else:
-            args.gradient_accumulation_steps = None
-        if "optimizer" in deepspeed_config:
-            optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
-            args.lr = optimizer_params_config.get("lr", args.lr)
-            args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
+            args.deepspeed_activation_checkpointing = False
+        if args.deepspeed_config is not None:
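+            # mirror selected DeepSpeed config values (fp16, micro batch size, grad accumulation, optimizer hyper-parameters) onto args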
+            with open(args.deepspeed_config) as file:
+                deepspeed_config = json.load(file)
+            if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
+                args.fp16 = True
+            else:
+                args.fp16 = False
+            if "train_micro_batch_size_per_gpu" in deepspeed_config:
+                args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
+            if "gradient_accumulation_steps" in deepspeed_config:
+                args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
+            else:
+                args.gradient_accumulation_steps = None
+            if "optimizer" in deepspeed_config:
+                optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
+                args.lr = optimizer_params_config.get("lr", args.lr)
+                args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
     return args
 
 
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 0c6336c66eb887792e50872d9af768445c5ce030..341a800083d5a317dfc55cde1592915a8296c4ca 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -64,11 +64,22 @@ class EncoderDecoderModel(torch.nn.Module):
         return encoder_outputs
     
     def decode(self, input_ids, position_ids, attention_mask, encoder_outputs,cross_attention_mask=None, **kw_args):
+        if attention_mask is None:
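+            # no decoder mask supplied: build a causal (lower-triangular) mask so position i
+            # attends only to positions <= i; shape [batch, 1, seq_len, seq_len], cast to the embedding dtype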
+            batch_size, seq_length = input_ids.size()[:2]
+            seq_ids = torch.arange(seq_length, device=input_ids.device)
+            attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype)
+            attention_mask = attention_mask[:, None, :, :]
         # If no context, please explicitly pass ``encoder_outputs=None''
         return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
     
-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids,dec_attention_mask, *, enc_attention_mask=None, cross_attention_mask=None, **kw_args):
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
         # Please use self.decoder for auto-regressive generation.
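+        # default to full attention over the encoder inputs and reuse that mask for cross-attention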
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
         encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
         decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
         return encoder_outputs, decoder_outputs, *mems
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 944a133371fe591e8f20c6f47c9350040df519a5..7143950772587d88a45960668215a7c604ea1f36 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
-from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
 
@@ -28,9 +28,11 @@ class T5LayerNorm(torch.nn.Module):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into float16 if necessary
+        # convert into float16 or bfloat16 if necessary
         if self.weight.dtype == torch.float16:
             hidden_states = hidden_states.to(torch.float16)
+        elif self.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.bfloat16)
         return self.weight * hidden_states
 
 
@@ -111,7 +113,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args):
+    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
         attn_module = self.transformer.layers[layer_id].attention
         seq_length = hidden_states.size(1)
         memory_length = mems[layer_id].size(1) if mems else 0
@@ -126,7 +128,8 @@ class T5AttentionMixin(BaseMixin):
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
 
-        position_bias = self.compute_bias(seq_length, memory_length + seq_length)
+        if position_bias is None:
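+            # no bias passed in by the caller: compute the T5 relative-position bias for this sequence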
+            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
         context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
                                            log_attention_weights=position_bias, scaling_attention_score=False)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
@@ -137,9 +140,10 @@ class T5AttentionMixin(BaseMixin):
         if attn_module.training:
             output = attn_module.output_dropout(output)
 
-        return output, None
+        return output, position_bias
 
-    def cross_attention_forward(self, hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args):
+    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
+                                **kw_args):
         attn_module = self.transformer.layers[layer_id].cross_attention
         mixed_query_layer = attn_module.query(hidden_states)
         mixed_x_layer = attn_module.key_value(encoder_outputs)
@@ -151,7 +155,7 @@ class T5AttentionMixin(BaseMixin):
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
 
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn,
+        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
                                            scaling_attention_score=False)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
@@ -180,8 +184,10 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False, 
-        layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
+        self.init_method_std = args.init_method_std
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu,
+                         init_method=self._init_weights)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -201,7 +207,61 @@ class T5Model(EncoderDecoderModel):
         )
         del self.decoder.transformer.position_embeddings
 
+    def _init_weights(self, weight, module, name):
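+        # T5-style init: weights are drawn from N(0, std) with std shrinking as the inverse square
+        # root of the fan-in of the projection (query projections get an extra 1/sqrt(d_head) factor),
+        # scaled by --init-method-std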
+        init_method_std = self.init_method_std
+        if isinstance(module, MLP):
+            if name == "dense_h_to_4h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense_4h_to_h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, SelfAttention):
+            if name == "query_key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+                torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, CrossAttention):
+            if name == "query":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        else:
+            raise NotImplementedError(module)
+
     @classmethod
     def add_model_specific_args(cls, parser):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
+        parser.add_argument("--init-method-std", type=float, default=0.02)
+
+    def encode(self, input_ids, attention_mask=None, **kw_args):
+        return super().encode(input_ids, None, attention_mask, **kw_args)
+
+    def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+        return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
+                              cross_attention_mask=cross_attention_mask, **kw_args)
+
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None,
+                dec_attention_mask=None, cross_attention_mask=None, **kw_args):
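+        # T5 uses relative position bias instead of position embeddings, so no position ids are passed;
+        # when no encoder mask is given, a full-attention (all-ones) mask is built and reused for cross-attention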
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length,
+                                            dtype=self.encoder.transformer.word_embeddings.weight.dtype,
+                                            device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
+        encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask,
+                                             encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask,
+                                             **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
diff --git a/SwissArmyTransformer/mpu/layers.py b/SwissArmyTransformer/mpu/layers.py
index 4af3f33180aae46bba28efa213a9a941df5cd3ce..8d7cf370e75f2629ae96b1b2f442f4dfe6bae70f 100755
--- a/SwissArmyTransformer/mpu/layers.py
+++ b/SwissArmyTransformer/mpu/layers.py
@@ -37,7 +37,7 @@ from .utils import VocabUtility
 
 def _initialize_affine_weight(weight, output_size, input_size,
                               per_partition_size, partition_dim, init_method,
-                              stride=1, return_master_weight=False):
+                              stride=1, return_master_weight=False, module=None, name=None):
     """Initialize affine weight for model parallel.
 
     Build the master weight on all processes and scatter
@@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
     if world_size == 1:
-        init_method(weight)
+        init_method(weight, module=module, name=name)
         if return_master_weight:
             return weight
         return None
@@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     master_weight = torch.empty(output_size, input_size,
                                 dtype=weight.dtype,
                                 requires_grad=False)
-    init_method(master_weight)
+    init_method(master_weight, module=module, name=name)
 
     # Split and copy
     per_partition_per_stride_size = divide(per_partition_size, stride)
@@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module):
     """
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
         super(ColumnParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.output_size_per_partition, 0, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
@@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module):
     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
         super(RowParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.input_size_per_partition, 1, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index c553cc21daf4e96c7ce79be41dd08e1cb5cb1c47..8afb1f9dc66151584fbb77efcb89b328a61d1b98 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -22,11 +22,12 @@ import torch
 import torch.nn.functional as F
 from apex.normalization.fused_layer_norm import FusedLayerNorm
 
+from SwissArmyTransformer import mpu
 from .initialize import get_model_parallel_world_size
 from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
 from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
-from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
+from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
 
 from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
 from .utils import split_tensor_along_last_dim
@@ -62,7 +63,10 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     attention_probs = F.softmax(attention_scores, dim=-1)
 
     if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
+        if mpu.get_cuda_rng_tracker is not None:
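+            # the tracker is monkey-patched to None when DeepSpeed activation checkpointing is disabled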
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = attention_dropout(attention_probs)
+        else:
             attention_probs = attention_dropout(attention_probs)
 
     context_layer = torch.matmul(attention_probs, value_layer)
@@ -82,31 +86,36 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
         if hidden_size_per_attention_head is None:
             self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
         else:
             self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
         self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3 * inner_hidden_size,
+            3 * self.inner_hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="query_key_value"
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
 
         self.dense = RowParallelLinear(
-            inner_hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense"
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -165,21 +174,22 @@ class CrossAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
         if hidden_size_per_attention_head is None:
             self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
         else:
             self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
         self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
         # Strided linear layer.
-        self.query = ColumnParallelLinear(hidden_size, inner_hidden_size,
+        self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size,
                                           gather_output=False,
-                                          init_method=init_method, bias=bias)
-        self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size,
+                                          init_method=init_method, bias=bias, module=self, name="query")
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size,
                                               stride=2,
                                               gather_output=False,
-                                              init_method=init_method, bias=bias)
+                                              init_method=init_method, bias=bias, module=self, name="key_value")
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
@@ -187,10 +197,10 @@ class CrossAttention(torch.nn.Module):
 
         # Output.
         self.dense = RowParallelLinear(
-            inner_hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method, bias=bias)
+            init_method=output_layer_init_method, bias=bias, module=self, name="dense")
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
     def _transpose_for_scores(self, tensor):
@@ -245,22 +255,28 @@ class MLP(torch.nn.Module):
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        self.hidden_size = hidden_size
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
-            hidden_size,
-            inner_hidden_size,
+            self.hidden_size,
+            self.inner_hidden_size,
             gather_output=False,
             init_method=init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense_h_to_4h"
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            inner_hidden_size,
-            hidden_size,
+            self.inner_hidden_size,
+            self.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense_4h_to_h"
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -358,7 +374,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, *args, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -424,6 +440,7 @@ class BaseTransformer(torch.nn.Module):
                  use_bias=True,
                  activation_func=gelu,
                  layernorm=LayerNorm,
+                 init_method=None,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
@@ -446,8 +463,12 @@ class BaseTransformer(torch.nn.Module):
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
         # create all layers
-        self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
-        self.init_method = unscaled_init_method(init_method_std)
+        if init_method is None:
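+            # default: N(0, init_method_std) everywhere, with output layers scaled down by 1/sqrt(2 * num_layers)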
+            self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
+            self.init_method = unscaled_init_method(init_method_std)
+        else:
+            self.output_layer_init_method = init_method
+            self.init_method = init_method
 
         def get_layer(layer_id):
             return BaseTransformerLayer(
@@ -483,7 +504,7 @@ class BaseTransformer(torch.nn.Module):
         if attention_mask is None:
             attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
                 next(self.parameters())
-            ) # None means full attention
+            )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
     
@@ -575,6 +596,7 @@ class BaseTransformer(torch.nn.Module):
             
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
+            output_this_layer = []
             while l < num_layers:
                 args = [hidden_states, attention_mask]
                 # flatten kw_args and output_cross_layer
@@ -601,6 +623,7 @@ class BaseTransformer(torch.nn.Module):
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
+            output_this_layer = []
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask]
 
diff --git a/SwissArmyTransformer/mpu/utils.py b/SwissArmyTransformer/mpu/utils.py
index c83f501889ccabb33cce33260f7d4fee9eadcab7..2f4f29f08227269634af831b91d2fe9b301b4e9e 100755
--- a/SwissArmyTransformer/mpu/utils.py
+++ b/SwissArmyTransformer/mpu/utils.py
@@ -75,7 +75,7 @@ def sqrt(x):
 
 def unscaled_init_method(sigma):
     """Init method based on N(0, sigma)."""
-    def init_(tensor):
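+    # extra keyword hints (module=, name=) from _initialize_affine_weight are accepted and ignored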
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
 
     return init_
@@ -83,7 +83,7 @@ def unscaled_init_method(sigma):
 def scaled_init_method(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=std)
 
     return init_
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index 7e0485274deeae3f095456f73a0d9c842a026e46..2279ee051c622f17991f0337206933d2a05c3653 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -29,7 +29,8 @@ def _export_vocab_size_to_args(args, original_num_tokens):
     print_rank_0('> padded vocab (size: {}) with {} dummy '
                  'tokens (new size: {})'.format(
         before, after - before, after))
-    args.vocab_size = after
+    if not args.vocab_size:
+        args.vocab_size = after
     print_rank_0("prepare tokenizer done")
     return tokenizer
 
@@ -63,6 +64,10 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             elif args.tokenizer_type == "glm_ChineseSPTokenizer":
                 from .glm import ChineseSPTokenizer
                 get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type.startswith('hf'):
+            from .hf_tokenizer import HFT5Tokenizer
+            if args.tokenizer_type == "hf_T5Tokenizer":
+                get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
diff --git a/SwissArmyTransformer/tokenization/glm/tokenization.py b/SwissArmyTransformer/tokenization/glm/tokenization.py
index 67be818a0f1d42b0d35696c24577b5e5391ff3b4..9b9a8abc0202b46a95efc8a8338e90e0da9fd60f 100644
--- a/SwissArmyTransformer/tokenization/glm/tokenization.py
+++ b/SwissArmyTransformer/tokenization/glm/tokenization.py
@@ -312,11 +312,11 @@ class Tokenizer(object):
         tokenization.tokenization = [self.IdToToken(idx) for idx in tokenization.tokenization]
         return tokenization
 
-    def IdToToken(self, Id):
+    def IdToToken(self, idx):
         """convert Id to token accounting for command tokens"""
-        if isinstance(Id, CommandToken):
-            return Id.token
-        return self.tokens[Id]
+        if isinstance(idx, CommandToken):
+            return idx.token
+        return self.tokens[idx]
 
     def TokenToId(self, token):
         """convert token to Id accounting for command tokens"""
@@ -324,16 +324,16 @@ class Tokenizer(object):
             return token.Id
         return self.vocab[token]
 
-    def DecodeIds(self, Ids):
+    def DecodeIds(self, ids):
         """
         convert Ids to tokens accounting for command tokens, tokens
         are joined and returned as a string.
         """
         rtn_strs = []
         current_str = []
-        if isinstance(Ids, Tokenization):
-            Ids = Ids.tokenization
-        for Id in Ids:
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        for Id in ids:
             if isinstance(Id, CommandToken):
                 rtn_strs.append(self._decode(current_str))
                 current_str = []
@@ -353,11 +353,11 @@ class Tokenizer(object):
         output = self.clean_up_tokenization(output)
         return output
 
-    def DecodeTokens(self, Tokens):
+    def DecodeTokens(self, tokens):
         """
         convert tokens to a string accounting for command and type tokens.
         """
-        Ids = [self.TokenToId(token) for token in Tokens]
+        Ids = [self.TokenToId(token) for token in tokens]
         return self.DecodeIds(Ids)
 
 
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e40adb47a4ac8cb876132dc5cefc27f824442baf
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -0,0 +1,80 @@
+from transformers import T5Tokenizer
+from .glm.tokenization import Tokenization, CommandToken
+
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small",
+    "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base",
+    "t5-large": "/mnt/t5",
+    "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b",
+    "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b"
+}
+
+class HFTokenizer:
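+    """Adapter that exposes a HuggingFace tokenizer through the SwissArmyTransformer
+    tokenizer interface (EncodeAsIds / DecodeIds / command tokens)."""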
+    def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None):
+        if model_type_or_path in PRETRAINED_VOCAB_FILES_MAP:
+            model_type_or_path = PRETRAINED_VOCAB_FILES_MAP[model_type_or_path]
+        self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir)
+        self.num_tokens = len(self.text_tokenizer)
+        self._command_tokens = []
+        self.command_name_map = {}
+        self.command_token_map = {}
+        self.command_id_map = {}
+
+    def __len__(self):
+        return len(self.text_tokenizer)
+
+    @property
+    def command_tokens(self):
+        return self._command_tokens
+
+    @command_tokens.setter
+    def command_tokens(self, command_tokens):
+        self._command_tokens = command_tokens
+        self.command_name_map = {tok.name: tok for tok in self.command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self.command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self.command_tokens}
+
+    def get_command(self, name):
+        """get command token corresponding to `name`"""
+        return self.command_name_map[name]
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        ids = self.text_tokenizer.encode(processed_text, add_special_tokens=False)
+        tokenization = Tokenization(ids, processed_text, text)
+        return tokenization
+
+    def DecodeIds(self, ids):
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        return self.text_tokenizer.decode(ids)
+
+    def DecodeTokens(self, tokens):
+        return self.text_tokenizer.convert_tokens_to_string(tokens)
+
+    def IdToToken(self, Id):
+        if isinstance(Id, CommandToken):
+            return Id.token
+        return self.text_tokenizer.convert_ids_to_tokens(Id)
+
+    def TokenToId(self, token):
+        if isinstance(token, CommandToken):
+            return token.Id
+        return self.text_tokenizer.convert_tokens_to_ids(token)
+
+
+class HFT5Tokenizer(HFTokenizer):
+    def __init__(self, model_type_or_path=None, cache_dir=None):
+        super().__init__(T5Tokenizer, model_type_or_path=model_type_or_path, cache_dir=cache_dir)
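+        # T5 has no dedicated start-of-sequence token; decoding starts from <pad>, so 'sop' maps to it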
+        command_tokens = [
+            CommandToken('eos', '</s>', self.TokenToId("</s>")),
+            CommandToken('pad', '<pad>', self.TokenToId("<pad>")),
+            CommandToken('sop', '<pad>', self.TokenToId("<pad>"))
+        ]
+        for i in range(100):
+            command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>')))
+        self.command_tokens = command_tokens
+
diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index d9059ac4d1f282e3792b3a877af76d81c8531b7c..beacffc39d8dfce4609c25fcf40c37c1d78e6c6e 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -144,6 +144,8 @@ def get_model(args, model_cls):
 
     if args.fp16:
         model.half()
+    elif args.bf16:
+        model.bfloat16()
     model.cuda(torch.cuda.current_device())
 
     return model
@@ -546,7 +548,8 @@ def initialize_distributed(args):
     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
         set_deepspeed_activation_checkpointing(args)  # TODO manual model-parallel seed
-
+    else:
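+        # no checkpointing-aware RNG tracker; standard_attention will fall back to plain dropout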
+        mpu.get_cuda_rng_tracker = None
 
 def set_random_seed(seed):
     """Set random seed for reproducability."""
diff --git a/examples/t5/config/config_t5_large.json b/examples/t5/config/config_t5_large.json
new file mode 100644
index 0000000000000000000000000000000000000000..25d7bf7ac71485cda4fba8500944712a06b775ab
--- /dev/null
+++ b/examples/t5/config/config_t5_large.json
@@ -0,0 +1,34 @@
+{
+  "train_micro_batch_size_per_gpu": 16,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": 2,
+    "contiguous_gradients": false,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 50000000,
+    "allgather_bucket_size": 500000000
+  },
+  "bfloat16": {
+    "enabled": true
+  },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.0002,
+      "weight_decay": 0.1,
+      "betas": [
+        0.9,
+        0.98
+      ],
+      "eps": 1e-6
+    }
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "contiguous_memory_optimization": false
+  },
+  "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/examples/t5/config/model_t5_large.sh b/examples/t5/config/model_t5_large.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0c20d559447a87030f5d07a701d7a125b4f79c97
--- /dev/null
+++ b/examples/t5/config/model_t5_large.sh
@@ -0,0 +1,15 @@
+MODEL_TYPE="t5-large"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --vocab-size 32128 \
+            --num-layers 24 \
+            --hidden-size 1024 \
+            --inner-hidden-size 4096 \
+            --num-attention-heads 16 \
+            --hidden-size-per-attention-head 64 \
+            --max-sequence-length 513 \
+            --relative-attention-num-buckets 32 \
+            --layernorm-epsilon 1e-6 \
+            --tokenizer-type hf_T5Tokenizer \
+            --tokenizer-model-type t5-large \
+            --load ${CHECKPOINT_PATH}/glm-large-en-blank"
\ No newline at end of file
diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..b902f248e0c6ebe6dfccba3805d4112923eff6fa
--- /dev/null
+++ b/examples/t5/inference_t5.py
@@ -0,0 +1,218 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_t5.py
+@Time    :   2021/10/22 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+from functools import partial
+import os
+import sys
+import random
+import time
+from datetime import datetime
+import torch
+import torch.nn.functional as F
+import argparse
+import stat
+
+from SwissArmyTransformer import mpu, get_args, get_tokenizer, load_checkpoint, initialize_distributed, set_random_seed
+
+from SwissArmyTransformer.model import T5Model
+from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin
+from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity
+from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
+from SwissArmyTransformer.generation.utils import timed_name, generate_continually
+from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
+
+
+def decoder_shift_right(input_ids, args):
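+    # teacher forcing: shift the labels one position to the right and prepend the decoder start token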
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+    shifted_input_ids[..., 0] = args.decoder_start_token_id
+    return shifted_input_ids
+
+
+def get_batch(data, args):
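+    # broadcast the batch across model-parallel ranks and cast the attention mask to the model dtype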
+    keys = ['text', 'loss_mask', 'target', 'attention_mask']
+    datatype = torch.int64
+
+    # Broadcast data.
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['text'].long()
+    labels = data_b['target'].long()
+    decoder_tokens = decoder_shift_right(labels, args)
+    attention_mask = data_b['attention_mask'].long()
+    loss_mask = data_b['loss_mask'].float()
+
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+    elif args.bf16:
+        attention_mask = attention_mask.bfloat16()
+    return tokens, decoder_tokens, labels, loss_mask, attention_mask
+
+
+def get_masks_and_position_ids_glm(seq, mask_position, context_length):
+    tokens = seq.unsqueeze(0)
+
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask[..., :context_length] = 1
+    attention_mask.unsqueeze_(1)
+
+    position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
+    torch.arange(0, context_length, out=position_ids[0, :context_length])
+    position_ids[0, context_length:] = mask_position
+    torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:])
+
+    position_ids = position_ids.unsqueeze(0)
+    return tokens, attention_mask, position_ids
+
+
+def main(args):
+    args.do_train = False
+    initialize_distributed(args)
+    tokenizer = get_tokenizer(args)
+    # load_checkpoint(model, args)
+    set_random_seed(args.seed)
+
+    # Model, optimizer, and learning rate.
+    model_cls = T5Model
+    model, optimizer = setup_model_and_optimizer(args, model_cls=model_cls)
+
+    missing_keys, unexpected_keys = model.module.load_state_dict(
+        torch.load("/dataset/fd5061f6/yanan/huggingface_models/t5-large/model_states.pt")["module"])
+    optimizer.refresh_fp32_params()
+    model.eval()
+    input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization
+    input_ids = input_ids + [tokenizer.get_command("eos").Id]
+    input_ids = torch.LongTensor([input_ids])
+    decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization
+    decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id]
+    decoder_input_ids = torch.LongTensor([decoder_input_ids])
+    data = {'text': input_ids, 'loss_mask': input_ids.new_ones(input_ids.shape), 'target': decoder_input_ids,
+            'attention_mask': input_ids.new_ones(input_ids.shape)}
+    tokens, decoder_tokens, labels, loss_mask, attention_mask = get_batch(data, args)
+    encoder_outputs, logits, *_ = model(enc_input_ids=tokens, dec_input_ids=decoder_tokens,
+                                        enc_attention_mask=attention_mask)
+    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
+    loss_mask = loss_mask.view(-1)
+    loss = torch.sum(losses.view(-1) * loss_mask)
+    if loss_mask.sum().item() > 0:
+        loss = loss / loss_mask.sum()
+    loss.backward()
+
+    breakpoint()
+
+    end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
+    # define function for each query
+    if args.sampling_strategy == 'BaseStrategy':
+        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
+    elif args.sampling_strategy == 'BeamSearchStrategy':
+        strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True,
+                                      end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size,
+                                      min_tgt_length=args.min_tgt_length)
+    else:
+        raise ValueError(f'unknown strategy {args.sampling_strategy}')
+
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split('\t')
+        # add MASK
+        generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
+        if 'MASK]' not in raw_text:
+            raw_text += ' ' + generation_mask
+        seq = tokenizer.EncodeAsIds(raw_text).tokenization
+        seq = [tokenizer.get_command('ENC').Id] + seq
+        if not raw_text.endswith('MASK]'):
+            seq = seq + [tokenizer.get_command('eos').Id]
+        print('raw text: {}\n'.format(raw_text))
+        if len(seq) > args.max_sequence_length:
+            raise ValueError('text too long.')
+
+        # generation
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = [seq]
+        # continually detect the first mark position
+        while True:
+            seq = output_list[0]  # TODO find the best one
+            # detect
+            mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+            mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+            mask_position = len(seq)
+            for token in mask_tokens:
+                try:
+                    mask_position = min(mask_position, seq.index(token))
+                except ValueError:
+                    pass
+            if mask_position == len(seq):
+                break
+
+            get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
+            output_list = []
+            for tim in range(max(args.batch_size // mbz, 1)):
+                input_seq = torch.cuda.LongTensor(
+                    seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
+                    device=args.device)
+                output = filling_sequence(model, input_seq,
+                                          batch_size=min(args.batch_size, mbz),
+                                          strategy=strategy,
+                                          log_attention_weights=None,
+                                          get_masks_and_position_ids=get_func
+                                          )[0]  # we don't use mems, fill back
+                if isinstance(output, torch.Tensor):  # different strategies
+                    output = list(output)
+
+                output_list.extend(output)
+
+            # clip -1s and fill back generated things into seq
+            for i in range(len(output_list)):
+                output = output_list[i].tolist()
+                try:
+                    unfinished = output.index(-1)
+                except ValueError:
+                    unfinished = len(output)
+                if output[unfinished - 1] in end_tokens:
+                    unfinished -= 1
+                bog = output.index(tokenizer.get_command('sop').Id)
+                output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
+
+        # decoding
+        txts = []
+        for seq in output_list:
+            decode_tokens = tokenizer.DecodeIds(seq)
+            txts.append(decode_tokens)
+
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id + '.txt')
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.txt', args.output_path)
+            print(txts[0])  # print the first.
+        with open(full_path, 'w') as fout:
+            for txt in txts:
+                fout.write(txt + '\n')
+        os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
+
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy',
+                           help='type name of sampling strategy')
+    T5Model.add_model_specific_args(py_parser)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    with torch.no_grad():
+        main(args)
diff --git a/examples/t5/scripts/generate_t5.sh b/examples/t5/scripts/generate_t5.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c4bc602d8babfca6a51b29c99e9ccfdd1d209558
--- /dev/null
+++ b/examples/t5/scripts/generate_t5.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
+
+source $1
+MPSIZE=1
+MAXSEQLEN=512
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=0.9
+# If TOPK/TOPP are 0, sampling falls back to greedy decoding; a non-zero top-k overrides top-p
+TOPK=40
+TOPP=0
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+config_json="$script_dir/../config/config_t5_large.json"
+
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_t5.py \
+       --deepspeed \
+       --deepspeed-config ${config_json} \
+       --mode inference \
+       --model-parallel-size $MPSIZE \
+       $MODEL_ARGS \
+       --num-beams 4 \
+       --no-repeat-ngram-size 3 \
+       --length-penalty 0.7 \
+       --out-seq-length $MAXSEQLEN \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --output-path samples_glm \
+       --batch-size 2 \
+       --input-source ./input.txt \
+       --checkpoint-activations \
+       --sampling-strategy BeamSearchStrategy
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..692ed918f39ddebbcbc23cbf194e0999e2a44aa6
--- /dev/null
+++ b/examples/t5/test_t5.py
@@ -0,0 +1,10 @@
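+# quick HuggingFace reference run on the same masked example used in examples/t5/inference_t5.py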
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
+model = model.to('cuda')
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to('cuda')
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to('cuda')
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+breakpoint()
\ No newline at end of file