diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index 5eafe851c4d6064d802532c2b822efe4cefc42a1..804bae08edbd705097b8fb2828443b87066e8a9c 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -300,25 +300,28 @@ def get_args(args_list=None):
         print('using world size: {} and model-parallel size: {} '.format(
             args.world_size, args.model_parallel_size))
 
-    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
-        with open(args.deepspeed_config) as file:
-            deepspeed_config = json.load(file)
-        if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
-            args.fp16 = True
-        else:
-            args.fp16 = False
+    if hasattr(args, "deepspeed") and args.deepspeed:
         if args.checkpoint_activations:
             args.deepspeed_activation_checkpointing = True
-        if "train_micro_batch_size_per_gpu" in deepspeed_config:
-            args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
-        if "gradient_accumulation_steps" in deepspeed_config:
-            args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
         else:
-            args.gradient_accumulation_steps = None
-        if "optimizer" in deepspeed_config:
-            optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
-            args.lr = optimizer_params_config.get("lr", args.lr)
-            args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
+            args.deepspeed_activation_checkpointing = False
+        if args.deepspeed_config is not None:
+            with open(args.deepspeed_config) as file:
+                deepspeed_config = json.load(file)
+            if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
+                args.fp16 = True
+            else:
+                args.fp16 = False
+            if "train_micro_batch_size_per_gpu" in deepspeed_config:
+                args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
+            if "gradient_accumulation_steps" in deepspeed_config:
+                args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
+            else:
+                args.gradient_accumulation_steps = None
+            if "optimizer" in deepspeed_config:
+                optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
+                args.lr = optimizer_params_config.get("lr", args.lr)
+                args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
     return args
 
 
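Note on the arguments.py hunk: the DeepSpeed activation-checkpointing flag is now resolved whenever `--deepspeed` is set, even if no `--deepspeed-config` JSON is supplied, while the config-derived overrides (fp16, micro batch size, gradient accumulation, optimizer lr/weight decay) only run when the file is present. A minimal sketch of that control flow, using a hypothetical `apply_deepspeed_overrides` helper and a stand-in `argparse.Namespace` instead of the real `get_args()` output:

```python
import json
from argparse import Namespace

def apply_deepspeed_overrides(args):
    # Mirrors the reworked branch above: activation checkpointing is resolved
    # whenever --deepspeed is on, even without a --deepspeed-config file.
    if getattr(args, "deepspeed", False):
        args.deepspeed_activation_checkpointing = bool(args.checkpoint_activations)
        if args.deepspeed_config is not None:
            with open(args.deepspeed_config) as f:
                ds_config = json.load(f)
            args.fp16 = bool(ds_config.get("fp16", {}).get("enabled", False))
            if "train_micro_batch_size_per_gpu" in ds_config:
                args.batch_size = ds_config["train_micro_batch_size_per_gpu"]
            args.gradient_accumulation_steps = ds_config.get("gradient_accumulation_steps")
            opt_params = ds_config.get("optimizer", {}).get("params", {})
            args.lr = opt_params.get("lr", args.lr)
            args.weight_decay = opt_params.get("weight_decay", args.weight_decay)
    return args

# DeepSpeed enabled, activation checkpointing on, no JSON config supplied:
# the flag below is now set instead of the whole block being skipped.
args = apply_deepspeed_overrides(Namespace(
    deepspeed=True, deepspeed_config=None, checkpoint_activations=True,
    fp16=False, batch_size=4, gradient_accumulation_steps=None,
    lr=1e-4, weight_decay=0.01))
assert args.deepspeed_activation_checkpointing is True
```
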
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 0c6336c66eb887792e50872d9af768445c5ce030..341a800083d5a317dfc55cde1592915a8296c4ca 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -64,11 +64,22 @@ class EncoderDecoderModel(torch.nn.Module):
         return encoder_outputs
     
     def decode(self, input_ids, position_ids, attention_mask, encoder_outputs,cross_attention_mask=None, **kw_args):
+        if attention_mask is None:
+            batch_size, seq_length = input_ids.size()[:2]
+            seq_ids = torch.arange(seq_length, device=input_ids.device)
+            attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype)
+            attention_mask = attention_mask[:, None, :, :]
         # If no context, please explicitly pass ``encoder_outputs=None''
         return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
     
-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids,dec_attention_mask, *, enc_attention_mask=None, cross_attention_mask=None, **kw_args):
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
         # Please use self.decoder for auto-regressive generation.
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
         encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
         decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
         return encoder_outputs, decoder_outputs, *mems
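Note on the encoder_decoder_model.py hunks: the attention masks become optional keyword arguments; `decode()` falls back to a causal mask and `forward()` defaults the encoder mask to a broadcastable all-ones tensor that is also reused as the cross-attention mask. The sketch below reproduces those defaults in isolation; the helper names are illustrative (not functions in the repo), and the dtype/device handling is simplified compared to the real code, which takes the dtype from the word-embedding weights.

```python
import torch

def default_encoder_mask(enc_input_ids, dtype=torch.float32):
    # All-ones mask of shape [1, 1, 1, enc_seq], broadcast over batch and heads.
    seq_length = enc_input_ids.size(1)
    return torch.ones(1, 1, 1, seq_length, dtype=dtype, device=enc_input_ids.device)

def default_causal_mask(dec_input_ids, dtype=torch.float32):
    # Lower-triangular mask of shape [batch, 1, dec_seq, dec_seq]:
    # position j may attend to position i only if i <= j.
    batch_size, seq_length = dec_input_ids.size()[:2]
    seq_ids = torch.arange(seq_length, device=dec_input_ids.device)
    mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    return mask.to(dtype)[:, None, :, :]

dec = torch.zeros(2, 4, dtype=torch.long)
print(default_causal_mask(dec)[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```
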
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 944a133371fe591e8f20c6f47c9350040df519a5..531ec54b0062fc616d46a42652a8d486a4659e33 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -180,7 +180,7 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False, 
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
@@ -205,3 +205,25 @@ class T5Model(EncoderDecoderModel):
     def add_model_specific_args(cls, parser):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
+
+    def encode(self, input_ids, attention_mask=None, **kw_args):
+        return super().encode(input_ids, None, attention_mask, **kw_args)
+
+    def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+        return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
+                              cross_attention_mask=cross_attention_mask, **kw_args)
+
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None,
+                dec_attention_mask=None, cross_attention_mask=None, **kw_args):
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length,
+                                            dtype=self.encoder.transformer.word_embeddings.weight.dtype,
+                                            device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
+        encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask,
+                                             encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask,
+                                             **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
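Note on the t5_model.py hunk: `T5Model` now overrides `encode`/`decode`/`forward` so callers pass only input ids plus optional keyword masks; position ids are dropped (passed as `None`) because T5 uses the relative attention bias from `T5AttentionMixin` rather than absolute position embeddings. A hypothetical caller sketch; `model` is assumed to be an already constructed `T5Model`, and `run_t5`/`enc_padding_mask` are placeholder names, with the mask (if given) broadcastable to `[batch, 1, dec_seq, enc_seq]` like the all-ones default built in `forward()`:

```python
import torch

def run_t5(model, enc_input_ids, dec_input_ids, enc_padding_mask=None):
    # Masks are keyword-only and optional. If enc_padding_mask is None,
    # forward() builds an all-ones encoder mask and reuses it as the
    # cross-attention mask; dec_attention_mask is omitted here, so decode()
    # falls back to the causal mask shown for EncoderDecoderModel above.
    encoder_outputs, decoder_outputs, *mems = model(
        enc_input_ids, dec_input_ids,
        enc_attention_mask=enc_padding_mask,
        cross_attention_mask=enc_padding_mask)
    return decoder_outputs
```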