diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index e9f7ab13dd2bb0f0e15633f11312a577ff546143..dda201dadc9e99b4b473ae567c65f91242ef72ed 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,10 +373,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
-        if self.is_decoder:
-            encoder_outputs = kw_args['encoder_outputs']
-            if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+        if self.is_decoder and encoder_outputs is not None:
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
                 # Residual connection.