Commit 7c573bbd authored by duzx16

Use assert to ensure cross_attention_mask is provided

Add a default value (None) for dec_attention_mask in T5Model.forward
parent be5e6c90
@@ -254,7 +254,7 @@ class T5Model(EncoderDecoderModel):
         return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
                               cross_attention_mask=cross_attention_mask, **kw_args)
 
-    def forward(self, enc_input_ids, dec_input_ids, dec_attention_mask, *, enc_attention_mask=None,
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
                 cross_attention_mask=None, **kw_args):
         batch_size, seq_length = enc_input_ids.size()[:2]
         if enc_attention_mask is None:
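A minimal standalone sketch of what the keyword-only default in this hunk means for callers. The function below is an illustrative stand-in, not the library's actual T5Model.forward, and the mask construction is an assumed placeholder:

import torch

def forward(enc_input_ids, dec_input_ids, *, enc_attention_mask=None,
            dec_attention_mask=None, cross_attention_mask=None, **kw_args):
    # dec_attention_mask is now keyword-only and defaults to None,
    # so callers that previously had to pass it positionally may omit it.
    batch_size, seq_length = enc_input_ids.size()[:2]
    if enc_attention_mask is None:
        # Assumed placeholder: an all-ones encoder mask in the usual
        # [batch, 1, 1, seq] broadcast layout.
        enc_attention_mask = torch.ones(
            batch_size, 1, 1, seq_length,
            dtype=torch.long, device=enc_input_ids.device)
    return enc_attention_mask, dec_attention_mask, cross_attention_mask

# With the old signature dec_attention_mask was a required positional
# argument; with the new one the call below is valid and the decoder
# mask simply stays None.
enc_ids = torch.randint(0, 100, (2, 8))
dec_ids = torch.randint(0, 100, (2, 4))
enc_mask, dec_mask, cross_mask = forward(enc_ids, dec_ids)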
@@ -394,7 +394,7 @@ class BaseTransformerLayer(torch.nn.Module):
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask = kw_args['cross_attention_mask']
+                assert 'cross_attention_mask' in kw_args
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, **kw_args)
                 # Residual connection.
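A hedged sketch of the pattern behind this hunk: the layer never reads cross_attention_mask as a local variable (cross attention receives it through **kw_args), so an assert that the key exists is enough to fail fast when a caller forgets it. The toy layer below only mimics that control flow; every name not present in the diff is made up for illustration:

import torch

class ToyDecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.is_decoder = True
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def cross_attention(self, hidden_states, **kw_args):
        # Stand-in for real cross attention: it pulls the mask out of
        # kw_args itself, which is why the caller only needs the assert.
        mask = kw_args['cross_attention_mask']
        return self.proj(hidden_states) * mask.unsqueeze(-1)

    def forward(self, layernorm_output, **kw_args):
        if self.is_decoder:
            encoder_outputs = kw_args.get('encoder_outputs')
            if encoder_outputs is not None:
                # Verify the mask was forwarded instead of binding it
                # to a local that is never used afterwards.
                assert 'cross_attention_mask' in kw_args
                return self.cross_attention(layernorm_output, **kw_args)
        return layernorm_output

layer = ToyDecoderLayer()
x = torch.randn(2, 4, 16)
out = layer(x, encoder_outputs=torch.randn(2, 8, 16),
            cross_attention_mask=torch.ones(2, 4))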