diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index dda201dadc9e99b4b473ae567c65f91242ef72ed..855e093ada69e1168b565f78f15eeaee313326e5 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *_, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,9 +373,12 @@
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
-        if self.is_decoder and encoder_outputs is not None:
+        # only for Encoder-Decoder, omit this for BERT-like or GPT-like models
+        if self.is_decoder and \
+                'encoder_outputs' in kw_args and kw_args['encoder_outputs'] is not None:
+            assert 'cross_attention_mask' in kw_args
             # Cross attention
-            attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
+            attention_output = self.cross_attention(layernorm_output, **kw_args)
             # Residual connection.
             layernorm_input = layernorm_input + attention_output
             # Layer norm post the cross attention
diff --git a/inference_t5.py b/examples/t5/inference_t5.py
similarity index 100%
rename from inference_t5.py
rename to examples/t5/inference_t5.py
diff --git a/test_t5.py b/examples/t5/test_t5.py
similarity index 96%
rename from test_t5.py
rename to examples/t5/test_t5.py
index e464bd956056caa4b728d43b7304bd42439af6a2..805f22811c7dfdcd198203c35c1c816680b1a834 100644
--- a/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -3,4 +3,5 @@ tokenizer = T5Tokenizer.from_pretrained("t5-large")
 model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
 input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
 decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
\ No newline at end of file
+output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+breakpoint()
\ No newline at end of file
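
Note on the new calling convention (a minimal sketch, not part of the patch; `layer`, `x`, and the mask/encoder tensors below are hypothetical stand-ins): after this change, `encoder_outputs` and `cross_attention_mask` reach `BaseTransformerLayer.forward` through `**kw_args` rather than as named parameters, so encoder-decoder callers pass them as keywords while BERT-like or GPT-like callers omit them entirely.

    # BERT-like / GPT-like usage is unchanged: no cross-attention kwargs,
    # so the `self.is_decoder and 'encoder_outputs' in kw_args` branch is skipped.
    x = layer(x, attention_mask)

    # Encoder-decoder usage passes both tensors as keywords; forward()
    # finds encoder_outputs in kw_args, asserts cross_attention_mask is
    # present, and forwards kw_args to self.cross_attention.
    x = layer(
        x,
        attention_mask,
        encoder_outputs=encoder_outputs,
        cross_attention_mask=cross_attention_mask,
    )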