diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index dda201dadc9e99b4b473ae567c65f91242ef72ed..855e093ada69e1168b565f78f15eeaee313326e5 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
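
With the throwaway *_ dropped from the signature, CrossAttention.forward now expects cross_attention_mask and encoder_outputs to arrive as keyword arguments expanded out of the caller's kw_args. A minimal sketch of that binding (illustrative only; cross_attn stands for an already-constructed CrossAttention module, and the tensor shapes are assumptions based on the [b, s, h] convention noted in the code):

    # kw_args as assembled by the enclosing decoder layer (see the next hunk)
    kw_args = {
        'encoder_outputs': encoder_outputs,            # assumed [b, src_len, h]
        'cross_attention_mask': cross_attention_mask,  # assumed broadcastable to [tgt_len, src_len]
    }
    attention_output = cross_attn(layernorm_output, **kw_args)
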
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,9 +373,12 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
-        if self.is_decoder and encoder_outputs is not None:
+        # only for encoder-decoder models; BERT-like or GPT-like models skip this branch
+        if self.is_decoder and \
+            'encoder_outputs' in kw_args and kw_args['encoder_outputs'] is not None:
+                assert 'cross_attention_mask' in kw_args, 'encoder_outputs was given without cross_attention_mask'
                 # Cross attention
-                attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
+                attention_output = self.cross_attention(layernorm_output, **kw_args)
                 # Residual connection.
                 layernorm_input = layernorm_input + attention_output
                 # Layer norm post the cross attention
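
Taken together with the previous hunk, the caller-facing contract of BaseTransformerLayer.forward becomes: the self-attention inputs stay positional, while the encoder-decoder extras travel through **kw_args. A sketch of the assumed call pattern (layer, decoder_states, and the mask variables are placeholders, not names from this patch):

    # Encoder-decoder models: pass both tensors by keyword; the layer hands them
    # on to cross_attention via **kw_args and asserts the mask is present.
    out = layer(decoder_states, decoder_mask,
                encoder_outputs=encoder_outputs,
                cross_attention_mask=cross_attention_mask)

    # BERT-like or GPT-like models: positional arguments only; the cross-attention
    # branch is skipped because 'encoder_outputs' never enters kw_args.
    out = layer(hidden_states, attention_mask)
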
diff --git a/inference_t5.py b/examples/t5/inference_t5.py
similarity index 100%
rename from inference_t5.py
rename to examples/t5/inference_t5.py
diff --git a/test_t5.py b/examples/t5/test_t5.py
similarity index 96%
rename from test_t5.py
rename to examples/t5/test_t5.py
index e464bd956056caa4b728d43b7304bd42439af6a2..805f22811c7dfdcd198203c35c1c816680b1a834 100644
--- a/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -3,4 +3,5 @@ tokenizer = T5Tokenizer.from_pretrained("t5-large")
 model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
 input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
 decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
\ No newline at end of file
+output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+breakpoint()  # drop into pdb to inspect the reference Hugging Face output
\ No newline at end of file
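
The added breakpoint() stops right after the Hugging Face forward pass. A hedged example of what one might run at that pdb prompt, not part of the patch (output and tokenizer come from the script above, assuming a transformers version whose forward call returns a Seq2SeqLMOutput):

    predicted_ids = output.logits.argmax(dim=-1)   # greedy token ids, [batch, decoder_len]
    print(tokenizer.decode(predicted_ids[0]))      # eyeball the predicted sentinel spans
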