From 6a2fa4083ccafc260943cd8338da6bc6fe40b129 Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Sun, 5 Dec 2021 07:51:24 +0000 Subject: [PATCH] tidy up enc-dec model --- SwissArmyTransformer/model/base_model.py | 3 +- .../model/encoder_decoder_model.py | 93 ++++--------------- SwissArmyTransformer/model/t5_model.py | 5 +- SwissArmyTransformer/mpu/transformer.py | 70 +++++++------- 4 files changed, 59 insertions(+), 112 deletions(-) diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py index 8dfb989..0b593bb 100644 --- a/SwissArmyTransformer/model/base_model.py +++ b/SwissArmyTransformer/model/base_model.py @@ -13,8 +13,7 @@ import math import random import torch -from SwissArmyTransformer.mpu import BaseTransformer, LayerNorm - +from SwissArmyTransformer.mpu import BaseTransformer class BaseMixin(torch.nn.Module): def __init__(self): diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py index e7395b7..0c6336c 100644 --- a/SwissArmyTransformer/model/encoder_decoder_model.py +++ b/SwissArmyTransformer/model/encoder_decoder_model.py @@ -17,60 +17,6 @@ from .base_model import BaseModel, BaseMixin from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region -def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32, is_decoder=False): - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (:obj:`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (:obj:`Tuple[int]`): - The shape of the input to the model. - device: (:obj:`torch.device`): - The device of the input to the model. - dtype: - is_decoder: - - Returns: - :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
-    if attention_mask is None or attention_mask.dim() == 2:
-        batch_size, seq_length = input_shape
-        # Provided a padding mask of dimensions [batch_size, seq_length]
-        # - if the model is a decoder, apply a causal mask in addition to the padding mask
-        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if is_decoder:
-            seq_ids = torch.arange(seq_length, device=device)
-            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
-            # causal and attention masks must have same type with pytorch version < 1.3
-            causal_mask = causal_mask.to(dtype)
-            if attention_mask is not None:
-                if causal_mask.shape[1] < attention_mask.shape[1]:
-                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
-                    causal_mask = torch.cat(
-                        [torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
-                         causal_mask], axis=-1)
-
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            else:
-                extended_attention_mask = causal_mask[:, None, :, :]
-        else:
-            if attention_mask is None:
-                extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-    elif attention_mask.dim() == 3:
-        extended_attention_mask = attention_mask[:, None, :, :]
-    else:
-        raise ValueError(
-            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
-        )
-    return extended_attention_mask
-
-
 class EncoderFinalMixin(BaseMixin):
     def final_forward(self, logits, **kwargs):
         logits = copy_to_model_parallel_region(logits)
@@ -78,7 +24,7 @@ class EncoderFinalMixin(BaseMixin):
 
 
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None, parallel_output=False, **kwargs):
+    def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
@@ -86,6 +32,7 @@ class EncoderDecoderModel(torch.nn.Module):
         else:
             self.encoder = BaseModel(args, **kwargs)
         self.encoder.add_mixin("final", EncoderFinalMixin())
+
         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
@@ -100,6 +47,10 @@ class EncoderDecoderModel(torch.nn.Module):
                 setattr(dec_args, name, dec_attr)
             self.decoder = BaseModel(args, is_decoder=True, parallel_output=parallel_output, **kwargs)
 
+        self.tie_word_embeddings = tie_word_embeddings
+        if tie_word_embeddings:
+            self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
+
     def reinit(self):
         self.encoder.reinit()
         self.decoder.reinit()
@@ -108,25 +59,19 @@ class EncoderDecoderModel(torch.nn.Module):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()
 
-    def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
-                decoder_position_ids=None, decoder_attention_mask=None, encoder_outputs=None,
-                **kw_args):
-        dtype = self.encoder.transformer.word_embeddings.weight.dtype
-        if encoder_outputs is None:
-            batch_size, encoder_seq_length = input_ids.size()[:2]
-        else:
-            batch_size, encoder_seq_length = encoder_outputs.size()[:2]
-        encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype)
-        decoder_seq_length = decoder_input_ids.size(1)
-        if encoder_outputs is None:
-            encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
-        decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype, is_decoder=True)
-        decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
-                                                      encoder_outputs=encoder_outputs,
-                                                      cross_attention_mask=encoder_attention_mask, **kw_args)
-        return encoder_outputs, decoder_outputs, *decoder_mems
+    def encode(self, input_ids, position_ids, attention_mask=None, **kw_args):
+        encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args)
+        return encoder_outputs
+
+    def decode(self, input_ids, position_ids, attention_mask, encoder_outputs, cross_attention_mask=None, **kw_args):
+        # If no context, please explicitly pass `encoder_outputs=None`.
+        return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, enc_attention_mask=None, cross_attention_mask=None, **kw_args):
+        # Please use self.decoder for auto-regressive generation.
+        encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
 
     @classmethod
     def add_model_specific_args(cls, parser):
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index bf4f2c9..944a133 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -180,8 +180,8 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm,
-                         activation_func=torch.nn.functional.relu)
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -200,7 +200,6 @@ class T5Model(EncoderDecoderModel):
             "t5-final", T5DecoderFinalMixin(args.hidden_size)
         )
         del self.decoder.transformer.position_embeddings
-        self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
 
     @classmethod
     def add_model_specific_args(cls, parser):
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 538d965..e9f7ab1 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -353,7 +353,7 @@ class 
BaseTransformerLayer(torch.nn.Module): hooks=hooks ) - def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args): + def forward(self, hidden_states, mask, **kw_args): ''' hidden_states: [batch, seq_len, hidden_size] mask: [(1, 1), seq_len, seq_len] @@ -373,13 +373,16 @@ class BaseTransformerLayer(torch.nn.Module): # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) - if self.is_decoder and encoder_outputs is not None: - # Cross attention - attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args) - # Residual connection. - layernorm_input = layernorm_input + attention_output - # Layer norm post the cross attention - layernorm_output = self.post_cross_attention_layernorm(layernorm_input) + if self.is_decoder: + encoder_outputs = kw_args['encoder_outputs'] + if encoder_outputs is not None: + cross_attention_mask=kw_args['cross_attention_mask'] + # Cross attention + attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args) + # Residual connection. + layernorm_input = layernorm_input + attention_output + # Layer norm post the cross attention + layernorm_output = self.post_cross_attention_layernorm(layernorm_input) # MLP. mlp_output = self.mlp(layernorm_output, **kw_args) @@ -467,12 +470,15 @@ class BaseTransformer(torch.nn.Module): # Final layer norm before output. self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None, + def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False, **kw_args): - breakpoint() # sanity check assert len(input_ids.shape) == 2 batch_size, query_length = input_ids.shape + if attention_mask is None: + attention_mask = torch.ones(1, 1, device=input_ids.device).type_as( + next(self.parameters()) + ) # None means full attention assert len(attention_mask.shape) == 2 or \ len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor) @@ -507,34 +513,37 @@ class BaseTransformer(torch.nn.Module): def custom(start, end): def custom_forward(*inputs): layers_ = self.layers[start:end] - x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2] + x_, mask = inputs[0], inputs[1] + if len(inputs) > 2: # have branch_input + branch_ = inputs[2] output_per_layers_part = [] for i, layer in enumerate(layers_): - if branch_input is not None: - x_, encoder_outputs_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, branch_input=encoder_outputs_, **kw_args + if len(inputs) > 2: + x_, branch_, output_this_layer = self.hooks['layer_forward']( + x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args ) elif 'layer_forward' in self.hooks: x_, output_this_layer = self.hooks['layer_forward']( - x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args + x_, mask, layer_id=layer.layer_id, **kw_args ) else: - x_, output_this_layer = layer(x_, mask, encoder_outputs_, **kw_args) + x_, output_this_layer = layer(x_, mask, **kw_args) output_per_layers_part.append(output_this_layer) return x_, output_per_layers_part - return custom_forward - l, num_layers = 0, len(self.layers) - chunk_length = self.checkpoint_num_layers + # prevent to lose requires_grad in checkpointing. 
+ # To save memory when only finetuning the final layers, don't use checkpointing. if self.training: hidden_states.requires_grad_(True) + + l, num_layers = 0, len(self.layers) + chunk_length = self.checkpoint_num_layers while l < num_layers: + args = [hidden_states, attention_mask] if branch_input is not None: - args = [hidden_states, attention_mask, branch_input] - hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args) + hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input) else: - args = [hidden_states, attention_mask, encoder_outputs] hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args) if output_hidden_states: hidden_states_outputs.append(hidden_states) @@ -542,16 +551,11 @@ class BaseTransformer(torch.nn.Module): l += chunk_length else: for i, layer in enumerate(self.layers): - args = [hidden_states, attention_mask, encoder_outputs] - if branch_input is not None: # customized layer_forward with branch_input - hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, - layer_id=torch.tensor( - i), - branch_input=branch_input, - **kw_args) - elif 'layer_forward' in self.hooks: # customized layer_forward - hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), - **kw_args) + args = [hidden_states, attention_mask] + if branch_input is not None: # customized layer_forward with branch_input + hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args) + elif 'layer_forward' in self.hooks: # customized layer_forward + hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args) else: hidden_states, output_this_layer = layer(*args, **kw_args) if output_hidden_states: -- GitLab
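
Reviewer note (not part of the patch): the sketch below illustrates how the refactored EncoderDecoderModel interface is meant to be driven, based only on the signatures introduced above (encode/decode plus the keyword-only mask arguments of forward). The helper names, tensor shapes, and the way the causal mask is built here are illustrative assumptions rather than code from this repository; the behavioural points taken from the patch are that get_extended_attention_mask is gone, callers now supply masks directly, and a None self-attention mask is expanded to full attention inside BaseTransformer.forward.

import torch

def seq2seq_forward(model, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids):
    # Hypothetical helper, not part of this patch. `model` is an EncoderDecoderModel
    # (or T5Model) instance constructed elsewhere from the usual argparse namespace.
    dec_len = dec_input_ids.shape[1]

    # The caller builds the decoder's causal mask itself now; a [1, 1, s, s] tensor
    # matches the 4-D mask shape BaseTransformer.forward already asserts on.
    dec_attention_mask = torch.tril(
        torch.ones(1, 1, dec_len, dec_len, device=dec_input_ids.device)
    )

    # Teacher-forced pass; enc_attention_mask and cross_attention_mask are left at their
    # None defaults (the new BaseTransformer code expands a None self-attention mask to
    # full attention).
    encoder_outputs, logits, *mems = model(
        enc_input_ids, enc_position_ids,
        dec_input_ids, dec_position_ids, dec_attention_mask,
    )
    return logits

def encode_then_decode(model, enc_input_ids, enc_position_ids,
                       dec_input_ids, dec_position_ids, dec_attention_mask):
    # Two-step variant, e.g. for generation: encode once, then reuse encoder_outputs
    # across repeated decode() calls on the growing decoder prefix.
    encoder_outputs = model.encode(enc_input_ids, enc_position_ids, attention_mask=None)
    logits, *mems = model.decode(dec_input_ids, dec_position_ids, dec_attention_mask,
                                 encoder_outputs=encoder_outputs)
    return logits, mems

A related design choice visible in the diff: the word-embedding tying that T5Model used to do by hand now lives in EncoderDecoderModel.__init__ behind tie_word_embeddings=True, so any encoder-decoder subclass shares input embeddings between encoder and decoder by default.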