Commit 63c5ce73 authored by duzx16

Fix bug in Transformer implementation

parent 07d5dc0f
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)

-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
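The first hunk fixes what is effectively a syntax error: in Python, a bare * in a parameter list must be followed by at least one named keyword-only parameter, so it cannot sit directly in front of **kw_args. The replacement *_ presumably keeps the original intent of shielding the keyword arguments while simply absorbing any stray positional arguments. A minimal sketch, separate from the repository code and using hypothetical names:

# The old form does not even compile.
try:
    compile("def forward(self, hidden_states, mask, enc, *, **kw_args): pass", "<demo>", "exec")
except SyntaxError as err:
    print("old signature:", err.msg)   # e.g. "named arguments must follow bare *"

# The new form parses; *_ silently swallows extra positional arguments.
def new_forward(hidden_states, *_, **kw_args):
    return hidden_states, kw_args

print(new_forward("h", "unexpected positional", layer_id=0))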
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )

-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
         '''
         hidden_states: [batch, seq_len, hidden_size]
         mask: [(1, 1), seq_len, seq_len]
@@ -373,10 +373,7 @@
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

-        if self.is_decoder:
-            encoder_outputs = kw_args['encoder_outputs']
-            if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+        if self.is_decoder and encoder_outputs is not None:
             # Cross attention
             attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
             # Residual connection.
......
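Taken together, the last two hunks move encoder_outputs and cross_attention_mask out of kw_args into explicit keyword parameters with None defaults, and gate the cross-attention branch on encoder_outputs actually being provided rather than on self.is_decoder alone. A minimal, simplified sketch of the resulting control flow, with a hypothetical layer and stand-in sublayers instead of the repository's real CrossAttention:

import torch

class ToyDecoderLayer(torch.nn.Module):
    # Hypothetical, stripped-down layer; only the branching logic mirrors the commit.
    def __init__(self, hidden_size, is_decoder=True):
        super().__init__()
        self.is_decoder = is_decoder
        self.attention = torch.nn.Linear(hidden_size, hidden_size)        # stand-in for self attention
        self.cross_attention = torch.nn.Linear(hidden_size, hidden_size)  # stand-in for cross attention

    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
        output = self.attention(hidden_states)
        # Cross attention only runs when encoder outputs are actually supplied,
        # so a decoder layer can still be called without them.
        if self.is_decoder and encoder_outputs is not None:
            # Pool the encoder states so the stand-in stays shape-compatible.
            output = output + self.cross_attention(encoder_outputs).mean(dim=1, keepdim=True)
        return output

layer = ToyDecoderLayer(hidden_size=8)
x = torch.randn(2, 4, 8)
print(layer(x, mask=None).shape)                                        # decoder-only call, no encoder outputs
print(layer(x, mask=None, encoder_outputs=torch.randn(2, 6, 8)).shape)  # with cross attention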