Commit bd2d86f9 authored by Ming Ding

remove layer_forward enc-dec params for simplicity

parent 049b6ffc
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
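For context, the branch above dispatches to a registered hook when one is present. The toy sketch below only illustrates that name-keyed dispatch pattern; the TinyLayer class, its constructor, and the lambda hook are hypothetical and are not this repository's API.

# Toy illustration of hook dispatch (hypothetical class, not the repository's
# CrossAttention/BaseTransformerLayer): if a callable is registered under a
# known name, the module calls it instead of its default path.
import torch

class TinyLayer(torch.nn.Module):
    def __init__(self, hooks=None):
        super().__init__()
        self.hooks = hooks if hooks is not None else {}

    def forward(self, hidden_states, **kw_args):
        if 'cross_attention_forward' in self.hooks:
            return self.hooks['cross_attention_forward'](hidden_states, **kw_args)
        return hidden_states  # default path; identity here for brevity

layer = TinyLayer(hooks={'cross_attention_forward': lambda h, **kw: h * 2})
print(layer(torch.ones(2, 4, 8)).mean().item())  # 2.0 -> the hook ran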
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,9 +373,12 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
-        if self.is_decoder and encoder_outputs is not None:
+        # only for Encoder-Decoder, omit this for BERT-like or GPT-like models
+        if self.is_decoder and \
+                'encoder_outputs' in kw_args and kw_args['encoder_outputs'] is not None:
+            assert 'cross_attention_mask' in kw_args
             # Cross attention
-            attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
+            attention_output = self.cross_attention(layernorm_output, **kw_args)
             # Residual connection.
             layernorm_input = layernorm_input + attention_output
             # Layer norm post the cross attention
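The net effect of this hunk is that encoder-decoder tensors now ride along inside kw_args instead of being explicit parameters of layer.forward. The self-contained sketch below mimics that contract with a stub layer; the StubLayer class and all tensor shapes are assumptions for illustration, not code from this repository.

# Sketch of the new calling contract: decoder layers only run cross attention
# when encoder_outputs arrives inside kw_args, so BERT-/GPT-like callers can
# simply omit it. (StubLayer stands in for the real BaseTransformerLayer.)
import torch

class StubLayer(torch.nn.Module):
    def __init__(self, is_decoder):
        super().__init__()
        self.is_decoder = is_decoder

    def forward(self, hidden_states, mask, **kw_args):
        if self.is_decoder and kw_args.get('encoder_outputs') is not None:
            assert 'cross_attention_mask' in kw_args
            # stand-in for self.cross_attention(layernorm_output, **kw_args)
            hidden_states = hidden_states + kw_args['encoder_outputs'].mean()
        return hidden_states

h = torch.zeros(2, 8, 16)
mask = torch.ones(1, 1, 8, 8)
enc = torch.ones(2, 8, 16)

gpt_like = StubLayer(is_decoder=False)
decoder = StubLayer(is_decoder=True)
print(gpt_like(h, mask).abs().sum().item())   # 0.0: cross-attention branch skipped
print(decoder(h, mask, encoder_outputs=enc,
              cross_attention_mask=mask).abs().sum().item())  # > 0: branch taken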
File moved
@@ -3,4 +3,5 @@ tokenizer = T5Tokenizer.from_pretrained("t5-large")
 model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
 input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
 decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
\ No newline at end of file
+output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+breakpoint()
\ No newline at end of file
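At the breakpoint, one plausible next step (an assumption, not part of the committed script) is to inspect the teacher-forced logits or let T5 fill in the masked spans itself via the standard transformers API:

# Possible continuation at the breakpoint, reusing the names defined above:
pred_ids = output.logits.argmax(dim=-1)          # [1, decoder_seq_len]
print(tokenizer.decode(pred_ids[0]))

gen_ids = model.generate(input_ids, max_length=20)
print(tokenizer.decode(gen_ids[0], skip_special_tokens=False))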