Commit 63c5ce73 authored by duzx16

Fix bug in Transformer implementation

parent 07d5dc0f
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
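
Note on the hunk above: a bare * in a Python parameter list must be followed by at least one named keyword-only parameter, so the original forward(self, ..., *, **kw_args) signature raises a SyntaxError as soon as the module is imported; replacing it with *_ instead collects and discards any stray positional arguments. A minimal standalone sketch (not repository code) of the difference:

# def broken(hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
#     ...
# -> SyntaxError: a bare * must be followed by a named keyword-only parameter

def fixed(hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
    # Extra positional arguments are swallowed by *_; keyword arguments land in kw_args.
    return len(_), sorted(kw_args)

print(fixed('h', 'mask', 'enc', 'stray', layer_id=0))   # (1, ['layer_id'])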
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
                 hooks=hooks
             )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,10 +373,7 @@
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
-        if self.is_decoder:
-            encoder_outputs = kw_args['encoder_outputs']
-            if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+        if self.is_decoder and encoder_outputs is not None:
             # Cross attention
             attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
             # Residual connection.
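
Note on the two hunks above: encoder_outputs and cross_attention_mask are now explicit keyword arguments of BaseTransformerLayer.forward with None defaults, rather than being looked up in kw_args, and the cross-attention branch runs only when the layer is a decoder and encoder outputs were actually passed. A minimal standalone sketch of the resulting control flow (is_decoder is an ordinary argument here purely for illustration):

def layer_forward(hidden_states, mask, encoder_outputs=None,
                  cross_attention_mask=None, *, is_decoder=True, **kw_args):
    # Cross attention only runs for decoder layers that actually received
    # encoder outputs; an encoder-style call can now omit both arguments
    # without triggering a lookup of kw_args['encoder_outputs'].
    if is_decoder and encoder_outputs is not None:
        return 'self-attention + cross-attention'
    return 'self-attention only'

print(layer_forward('h', 'mask'))                          # self-attention only
print(layer_forward('h', 'mask', encoder_outputs='enc'))   # self-attention + cross-attention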