From 63c5ce734b853147d562aa832b103bd1dd80863c Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sun, 5 Dec 2021 23:13:12 +0800
Subject: [PATCH] Fix bug in Transformer implementation

---
 SwissArmyTransformer/mpu/transformer.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index e9f7ab1..dda201d 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,10 +373,7 @@ class BaseTransformerLayer(torch.nn.Module):
 
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
-        if self.is_decoder:
-            encoder_outputs = kw_args['encoder_outputs']
-            if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+        if self.is_decoder and encoder_outputs is not None:
             # Cross attention
             attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
             # Residual connection.
--
GitLab
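
Note (not part of the patch): the old CrossAttention signature `def forward(self, ..., *, **kw_args)` is a SyntaxError in Python, since a bare `*` must be followed by named keyword-only parameters; `*_` instead collects and discards any extra positional arguments. The decoder-layer change makes `encoder_outputs` and `cross_attention_mask` explicit keyword parameters defaulting to None, so cross attention runs only when encoder outputs are actually provided. A minimal standalone sketch of both patterns follows; the function names and arguments here are illustrative, not the SwissArmyTransformer API.

    # Illustrative sketch of the two signature patterns used in the patch.

    def forward(hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
        # Extra positional args land in `_` and are ignored; extra keyword
        # args remain accessible through `kw_args`.
        return hidden_states, len(_), sorted(kw_args)


    def layer_forward(hidden_states, mask, encoder_outputs=None,
                      cross_attention_mask=None, **kw_args):
        # Cross attention only when encoder outputs are supplied, mirroring
        # `if self.is_decoder and encoder_outputs is not None:` in the patch.
        if encoder_outputs is not None:
            return "cross-attention branch"
        return "self-attention only"


    if __name__ == "__main__":
        print(forward("h", "m", "enc", "extra1", "extra2", layer_id=3))
        # -> ('h', 2, ['layer_id'])
        print(layer_forward("h", "m"))                       # self-attention only
        print(layer_forward("h", "m", encoder_outputs="e"))  # cross-attention branch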