From cc852e00817460b110ff5b3639e5c56b0b72ba68 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Thu, 2 Dec 2021 21:16:24 +0800
Subject: [PATCH] Implement T5 model

---
 SwissArmyTransformer/arguments.py        |   2 +
 SwissArmyTransformer/model/__init__.py   |   3 +-
 SwissArmyTransformer/model/base_model.py |  31 +--
 .../model/encoder_decoder_model.py        | 192 ++++++++--------
 SwissArmyTransformer/model/t5_model.py    | 190 ++++++++++++++++
 SwissArmyTransformer/mpu/transformer.py   | 212 ++++++++++++++----
 6 files changed, 474 insertions(+), 156 deletions(-)
 create mode 100644 SwissArmyTransformer/model/t5_model.py

diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index dfa75c9..97e95ee 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -33,6 +33,8 @@ def add_model_config_args(parser):
                        help='num of transformer attention heads')
     group.add_argument('--hidden-size', type=int, default=1024,
                        help='tansformer hidden size')
+    group.add_argument('--inner-hidden-size', type=int, default=None)
+    group.add_argument('--hidden-size-per-attention-head', type=int, default=None)
     group.add_argument('--num-layers', type=int, default=24,
                        help='num decoder layers')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
diff --git a/SwissArmyTransformer/model/__init__.py b/SwissArmyTransformer/model/__init__.py
index 32f46e4..4fbcd53 100755
--- a/SwissArmyTransformer/model/__init__.py
+++ b/SwissArmyTransformer/model/__init__.py
@@ -2,4 +2,5 @@ from .base_model import BaseModel
 from .cached_autoregressive_model import CachedAutoregressiveModel
 from .cuda2d_model import Cuda2dModel
 from .glm_model import GLMModel
-from .encoder_decoder_model import EncoderDecoderModel
\ No newline at end of file
+from .encoder_decoder_model import EncoderDecoderModel
+from .t5_model import T5Model
diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index c9e1c90..5e5793e 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -13,20 +13,23 @@ import math
 import random
 import torch
 
-from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu import BaseTransformer, LayerNorm
+
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
         super(BaseMixin, self).__init__()
         # define new params
+
     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
     # can also define hook-functions here
     # ...
+
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None, parallel_output=True):
+    def __init__(self, args, transformer=None, **kwargs):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -42,14 +45,16 @@ class BaseModel(torch.nn.Module):
                 embedding_dropout_prob=args.hidden_dropout,
                 attention_dropout_prob=args.attention_dropout,
                 output_dropout_prob=args.hidden_dropout,
+                inner_hidden_size=args.inner_hidden_size,
+                hidden_size_per_attention_head=args.hidden_size_per_attention_head,
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=parallel_output,
-                hooks=self.hooks
+                hooks=self.hooks,
+                **kwargs
             )
 
-    def reinit(self): # will be called when loading model
+    def reinit(self):  # will be called when loading model
         # if some mixins are loaded, overrides this function
         for m in self.mixins.values():
             m.reinit(self.transformer)
@@ -58,11 +63,11 @@ class BaseModel(torch.nn.Module):
         assert name not in self.mixins
         assert isinstance(new_mixin, BaseMixin)
 
-        self.mixins[name] = new_mixin # will auto-register parameters
-        object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
+        self.mixins[name] = new_mixin  # will auto-register parameters
+        object.__setattr__(new_mixin, 'transformer', self.transformer)  # cannot use pytorch set_attr
 
         if reinit:
-            new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
+            new_mixin.reinit(self.transformer, **self.mixins)  # also pass current mixins
         self.collect_hooks_()
 
     def del_mixin(self, name):
@@ -82,15 +87,15 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
-                'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                'branch_embedding_forward', 'branch_final_forward'
-                ]
+                 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
+                 'branch_embedding_forward', 'branch_final_forward'
+                 ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
-                    if name in hooks: # conflict
+                    if name in hooks:  # conflict
                         raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                     hooks[name] = getattr(m, name)
                     hook_origins[name] = mixin_name
@@ -104,4 +109,4 @@
         return hooks
 
     def disable_untrainable_params(self):
-        pass
\ No newline at end of file
+        pass
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 7e868f2..1beace8 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -14,108 +14,90 @@ import random
 import torch
 import argparse
 from .base_model import BaseModel, BaseMixin
-from .common_layers import CrossAttention, LayerNorm
-
-
-class CrossAttentionMixin(BaseMixin):
-    def __init__(self, num_layers, hidden_size, num_attention_heads,
-                 attention_dropout_prob, output_dropout_prob,
-                 init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
-        super().__init__()
-
-        self.cross_attentions = torch.nn.ModuleList(
-            [CrossAttention(
-                hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
-                output_layer_init_method=output_layer_init_method
-            ) for layer_id in range(num_layers)]
-        )  # Just copy args
-        self.cross_lns = torch.nn.ModuleList(
-            [LayerNorm(hidden_size, 1e-5)
-             for layer_id in range(num_layers)]
-        )
-
-    def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
-        layer = self.transformer.layers[layer_id]
-        encoder_outputs = kw_args['encoder_outputs']
-        '''
-            hidden_states: [batch, seq_len, hidden_size]
-            mask: [(1, 1), seq_len, seq_len]
-            encoder_outputs: [batch, enc_seq_len, enc_hidden_size]
-        '''
-        # Layer norm at the begining of the transformer layer.
-        layernorm_output = layer.input_layernorm(hidden_states)
-        attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args)
-        # Third LayerNorm
-        if layer.sandwich_ln:
-            attention_output = layer.third_layernorm(attention_output)
-        # Residual connection.
-        hidden_states = hidden_states + attention_output
-
-        # Cross attention.
-        layernorm_output = self.cross_lns[layer_id](hidden_states)
-        cross_attn_output = self.cross_attentions[layer_id](
-            layernorm_output,
-            torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
-            encoder_outputs
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
+from .common_layers import LayerNorm
+
+
+def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32, is_decoder=False):
+    """
+    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+    Arguments:
+        attention_mask (:obj:`torch.Tensor`):
+            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+        input_shape (:obj:`Tuple[int]`):
+            The shape of the input to the model.
+        device: (:obj:`torch.device`):
+            The device of the input to the model.
+        dtype:
+        is_decoder:
+
+    Returns:
+        :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+    """
+    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+    # ourselves in which case we just need to make it broadcastable to all heads.
+    if attention_mask is None or attention_mask.dim() == 2:
+        batch_size, seq_length = input_shape
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if is_decoder:
+            seq_ids = torch.arange(seq_length, device=device)
+            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+            # causal and attention masks must have same type with pytorch version < 1.3
+            causal_mask = causal_mask.to(dtype)
+            if attention_mask is not None:
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                         causal_mask], axis=-1)
+
+            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        else:
+            if attention_mask is None:
+                extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+    elif attention_mask.dim() == 3:
+        extended_attention_mask = attention_mask[:, None, :, :]
+    else:
+        raise ValueError(
+            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
         )
-        hidden_states = hidden_states + cross_attn_output
-
-        # Layer norm post the layer attention.
-        layernorm_output = layer.post_attention_layernorm(hidden_states)
-        # MLP.
-        mlp_output = layer.mlp(layernorm_output, **kw_args)
-
-        # Fourth LayerNorm
-        if layer.sandwich_ln:
-            mlp_output = layer.fourth_layernorm(mlp_output)
-        output = hidden_states + mlp_output
+    return extended_attention_mask
 
-        return output, output_this_layer
-
-class DecoderModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        dec_args = argparse.Namespace(**vars(args))
-        dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
-        override_attrs = ['num_layers', 'vocab_size',
-                          'hidden_size', 'num_attention_heads',
-                          'max_sequence_length', 'sandwich_ln'  # TODO
-                          ]
-        for name in override_attrs:
-            dec_attr = getattr(dec_args, 'dec_' + name, None)
-            if dec_attr is not None:  # else use encoder-config
-                setattr(dec_args, name, dec_attr)
-
-        super().__init__(dec_args, transformer=transformer)
-        self.add_mixin('cross_attention',
-                       CrossAttentionMixin(
-                           dec_args.num_layers,
-                           dec_args.hidden_size, dec_args.num_attention_heads,
-                           dec_args.attention_dropout, dec_args.hidden_dropout,
-                           self.transformer.init_method,
-                           enc_hidden_size=dec_args.enc_hidden_size,
-                           inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None),
-                           output_layer_init_method=self.transformer.output_layer_init_method
-                       )
-                       )
+class EncoderFinalMixin(BaseMixin):
+    def final_forward(self, logits, **kwargs):
+        logits = copy_to_model_parallel_region(logits)
+        return logits
 
 
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None):
+    def __init__(self, args, encoder=None, decoder=None, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
             self.encoder = encoder
         else:
-            self.encoder = BaseModel(args)
-
+            self.encoder = BaseModel(args, **kwargs)
+        self.encoder.add_mixin("final", EncoderFinalMixin())
         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
         else:
-            self.decoder = DecoderModel(args)
+            dec_args = argparse.Namespace(**vars(args))
+            dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
+            override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads',
+                              'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
+            for name in override_attrs:
+                dec_attr = getattr(dec_args, 'dec_' + name, None)
+                if dec_attr is not None:  # else use encoder-config
+                    setattr(dec_args, name, dec_attr)
+            self.decoder = BaseModel(dec_args, is_decoder=True, parallel_output=parallel_output, **kwargs)
 
     def reinit(self):
         self.encoder.reinit()
@@ -125,23 +107,29 @@ class EncoderDecoderModel(torch.nn.Module):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()
 
-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *,
-                branch_input=None, **kw_args):
-        mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype)
-        enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input,
-                                            **kw_args)
-        dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask,
-                                              encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args)
-        return enc_outputs, dec_outputs, *dec_mems
+    def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
+                decoder_position_ids=None, decoder_attention_mask=None,
+                **kw_args):
+        dtype = self.encoder.transformer.word_embeddings.weight.dtype
+        batch_size, encoder_seq_length = input_ids.size()[:2]
+        encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
+                                                             device=input_ids.device, dtype=dtype)
+        decoder_seq_length = decoder_input_ids.size(1)
+        encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
+        decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
+                                                             device=input_ids.device, dtype=dtype)
+        decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
+                                                      encoder_outputs=encoder_outputs,
+                                                      cross_attention_mask=encoder_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *decoder_mems
 
     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart')
-        group.add_argument("--dec_num_layers", type=int, default=None)
-        group.add_argument("--dec_vocab_size", type=int, default=None)
-        group.add_argument("--dec_hidden_size", type=int, default=None)
-        group.add_argument("--dec_num_attention_heads", type=int, default=None)
-        group.add_argument("--dec_max_sequence_length", type=int, default=None)
-        group.add_argument("--dec_sandwich_ln", action='store_true')
-        group.add_argument("--dec_inner_hidden_size", type=int, default=None)
+        group.add_argument("--dec-num-layers", type=int, default=None)
+        group.add_argument("--dec-hidden-size", type=int, default=None)
+        group.add_argument("--dec-num-attention-heads", type=int, default=None)
+        group.add_argument("--dec-max-sequence-length", type=int, default=None)
+        group.add_argument("--dec-inner-hidden-size", type=int, default=None)
+        group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None)
         return parser
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
new file mode 100644
index 0000000..bf26a3b
--- /dev/null
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -0,0 +1,190 @@
+import math
+import torch
+from .mixins import BaseMixin
+from .encoder_decoder_model import EncoderDecoderModel
+from SwissArmyTransformer.mpu import get_model_parallel_world_size
+from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+
+
+class T5PositionEmbeddingMixin(BaseMixin):
+    def position_embedding_forward(self, position_ids, **kw_args):
+        return None
+
+
+class T5LayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # layer norm should always be calculated in float32
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into float16 if necessary
+        if self.weight.dtype == torch.float16:
+            hidden_states = hidden_states.to(torch.float16)
+        return self.weight * hidden_states
+
+
+class T5AttentionMixin(BaseMixin):
+    def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False):
+        super().__init__()
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        world_size = get_model_parallel_world_size()
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets,
+                                                          self.num_attention_heads_per_partition)
+        self.is_decoder = is_decoder
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+        This should allow for more graceful generalization to longer sequences than the model has been trained on.
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_postion_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_postion_if_large = torch.min(
+            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length, cross_attention=False):
+        """Compute binned relative position bias"""
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_num_buckets,
+        )
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        # shape (query_length, key_length, num_heads)
+        if cross_attention:
+            values = self.cross_relative_attention_bias(relative_position_bucket)
+        else:
+            values = self.relative_attention_bias(relative_position_bucket)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
+    def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args):
+        attn_module = self.transformer.layers[layer_id].attention
+        seq_length = hidden_states.size(1)
+        memory_length = mems[layer_id].size(1) if mems else 0
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        (mixed_query_layer,
+         mixed_key_layer,
+         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+        dropout_fn = attn_module.attention_dropout if attn_module.training else None
+
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+
+        position_bias = self.compute_bias(seq_length, memory_length + seq_length)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
+                                            log_attention_weights=position_bias, scaling_attention_score=False)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        output = attn_module.dense(context_layer)
+
+        if attn_module.training:
+            output = attn_module.output_dropout(output)
+
+        return output, None
+
+    def cross_attention_forward(self, hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args):
+        attn_module = self.transformer.layers[layer_id].cross_attention
+        mixed_query_layer = attn_module.query(hidden_states)
+        mixed_x_layer = attn_module.key_value(encoder_outputs)
+        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
+
+        dropout_fn = attn_module.attention_dropout if attn_module.training else None
+        # Reshape and transpose [b, np, s, hn]
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+
+        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn,
+                                            scaling_attention_score=False)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        # [b, s, hp]
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # Output. [b, s, h]
+        output = attn_module.dense(context_layer)
+        if attn_module.training:
+            output = attn_module.output_dropout(output)
+
+        return output
+
+
+class T5Model(EncoderDecoderModel):
+    def __init__(self, args, **kwargs):
+        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm)
+        self.encoder.add_mixin(
+            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
+        )
+        self.encoder.add_mixin(
+            "t5-position", T5PositionEmbeddingMixin()
+        )
+        del self.encoder.transformer.position_embeddings
+        num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None else args.num_attention_heads
+        self.decoder.add_mixin(
+            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads, is_decoder=True)
+        )
+        self.decoder.add_mixin(
+            "t5-position", T5PositionEmbeddingMixin()
+        )
+        del self.decoder.transformer.position_embeddings
+        self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
+
+    @classmethod
+    def add_model_specific_args(cls, parser):
+        super().add_model_specific_args(parser)
+        parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 5dfac59..7a3234b 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -44,14 +44,13 @@ class LayerNorm(FusedLayerNorm):
 
 
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                       attention_dropout=None, log_attention_weights=None):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training.
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.
-    attention_scores = torch.matmul(
-        query_layer / math.sqrt(query_layer.shape[-1]),
-        key_layer.transpose(-1, -2)
-    )
+    if scaling_attention_score:
+        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
+    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
@@ -73,7 +72,7 @@
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, layer_id, output_layer_init_method=None,
+                 init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None,
                  bias=True, hooks={}):
         super(SelfAttention, self).__init__()
         # Set output layer initialization if not provided.
@@ -83,25 +82,31 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
-        self.hidden_size_per_partition = divide(hidden_size, world_size)
-        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3 * hidden_size,
+            3 * inner_hidden_size,
             stride=3,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
         self.dense = RowParallelLinear(
-            hidden_size,
+            inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -115,7 +120,7 @@ class SelfAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
             return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
@@ -142,9 +147,91 @@ class SelfAttention(torch.nn.Module):
             return output, None
 
 
+class CrossAttention(torch.nn.Module):
+    """Parallel cross-attention layer for Transformer"""
+
+    def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
+                 layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, hooks={}):
+        super().__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.hooks = hooks
+        self.layer_id = layer_id
+        # Per attention head and per partition values.
+        world_size = get_model_parallel_world_size()
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
+        # Strided linear layer.
+        self.query = ColumnParallelLinear(hidden_size, inner_hidden_size,
+                                          gather_output=False,
+                                          init_method=init_method, bias=bias)
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size,
+                                              stride=2,
+                                              gather_output=False,
+                                              init_method=init_method, bias=bias)
+        # Dropout. Note that for a single iteration, this layer will generate
+        # different outputs on different number of parallel partitions but
+        # on average it should not be partition dependent.
+        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+
+        # Output.
+        self.dense = RowParallelLinear(
+            inner_hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            init_method=output_layer_init_method, bias=bias)
+        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
+
+    def _transpose_for_scores(self, tensor):
+        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
+        size [b, np, s, hn].
+        """
+        new_tensor_shape = tensor.size()[:-1] + \
+                           (self.num_attention_heads_per_partition,
+                            self.hidden_size_per_attention_head)
+        tensor = tensor.view(*new_tensor_shape)
+        return tensor.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, encoder_outputs, *args, cross_attention_mask=None, **kw_args):
+        # hidden_states: [b, s, h]
+        if 'cross_attention_forward' in self.hooks:
+            return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
+                                                         **kw_args, layer_id=self.layer_id)
+        else:
+            mixed_query_layer = self.query(hidden_states)
+            mixed_x_layer = self.key_value(encoder_outputs)
+            (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
+
+            dropout_fn = self.attention_dropout if self.training else None
+            # Reshape and transpose [b, np, s, hn]
+            query_layer = self._transpose_for_scores(mixed_query_layer)
+            key_layer = self._transpose_for_scores(mixed_key_layer)
+            value_layer = self._transpose_for_scores(mixed_value_layer)
+
+            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+            # [b, s, hp]
+            context_layer = context_layer.view(*new_context_layer_shape)
+
+            # Output. [b, s, h]
+            output = self.dense(context_layer)
+            if self.training:
+                output = self.output_dropout(output)
+
+            return output
+
+
 class MLP(torch.nn.Module):
-    def __init__(self, hidden_size, output_dropout_prob, init_method,
-                 output_layer_init_method=None, layer_id=None, hooks={}):
+    def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
+                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True):
         super(MLP, self).__init__()
         self.layer_id = layer_id
         # Set output layer initialization if not provided.
@@ -152,18 +239,22 @@ class MLP(torch.nn.Module):
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
             hidden_size,
-            4 * hidden_size,
+            inner_hidden_size,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            4 * hidden_size,
+            inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
@@ -190,8 +281,13 @@ class BaseTransformerLayer(torch.nn.Module):
                  layernorm_epsilon,
                  init_method,
                  layer_id,
+                 inner_hidden_size=None,
+                 hidden_size_per_attention_head=None,
                  output_layer_init_method=None,
                  sandwich_ln=True,
+                 layernorm=LayerNorm,
+                 is_decoder=False,
+                 use_bias=True,
                  hooks={}
                  ):
         super(BaseTransformerLayer, self).__init__()
@@ -199,10 +295,11 @@
         if output_layer_init_method is None:
             output_layer_init_method = init_method
         self.layer_id = layer_id
+        self.is_decoder = is_decoder
         self.hooks = hooks
 
         # Layernorm on the input data.
-        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
         # Self attention.
         self.attention = SelfAttention(
@@ -212,28 +309,48 @@ class BaseTransformerLayer(torch.nn.Module):
             output_dropout_prob,
             init_method,
             layer_id,
+            hidden_size_per_attention_head=hidden_size_per_attention_head,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             hooks=hooks
         )
 
         # Layernorm on the input data.
-        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
         self.sandwich_ln = sandwich_ln
         if sandwich_ln:
-            self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
-            self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+            self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+            self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+        # Cross attention.
+        if self.is_decoder:
+            self.cross_attention = CrossAttention(
+                hidden_size,
+                num_attention_heads,
+                attention_dropout_prob,
+                output_dropout_prob,
+                init_method,
+                layer_id,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
+                output_layer_init_method=output_layer_init_method,
+                bias=use_bias,
+                hooks=hooks
+            )
+            self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
         # MLP
         self.mlp = MLP(
             hidden_size,
             output_dropout_prob,
             init_method,
+            inner_hidden_size=inner_hidden_size,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             layer_id=layer_id,
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -252,6 +369,15 @@ class BaseTransformerLayer(torch.nn.Module):
         layernorm_input = hidden_states + attention_output
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        if self.is_decoder and encoder_outputs is not None:
+            # Cross attention
+            attention_output = self.cross_attention(layernorm_output, encoder_outputs, **kw_args)
+            # Residual connection.
+            layernorm_input = layernorm_output + attention_output
+            # Layer norm post the cross attention
+            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
@@ -279,13 +405,19 @@ class BaseTransformer(torch.nn.Module):
                  checkpoint_num_layers=1,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
+                 inner_hidden_size=None,
+                 hidden_size_per_attention_head=None,
                  sandwich_ln=True,
                  parallel_output=True,
+                 is_decoder=False,
+                 use_bias=True,
+                 layernorm=LayerNorm,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
 
         # recording parameters
+        self.is_decoder = is_decoder
         self.parallel_output = parallel_output
         self.checkpoint_activations = checkpoint_activations
         self.checkpoint_num_layers = checkpoint_num_layers
@@ -314,8 +446,13 @@ class BaseTransformer(torch.nn.Module):
                 layernorm_epsilon,
                 self.init_method,
                 layer_id,
+                inner_hidden_size=inner_hidden_size,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
                 output_layer_init_method=self.output_layer_init_method,
+                is_decoder=self.is_decoder,
                 sandwich_ln=sandwich_ln,
+                layernorm=layernorm,
+                use_bias=use_bias,
                 hooks=self.hooks
             )
 
@@ -323,10 +460,10 @@
             [get_layer(layer_id) for layer_id in range(num_layers)])
 
         # Final layer norm before output.
-        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
-                **kw_args):
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
+                output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
@@ -349,7 +486,8 @@ class BaseTransformer(torch.nn.Module):
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
             position_embeddings = self.position_embeddings(position_ids)
-        hidden_states = hidden_states + position_embeddings
+        if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
         hidden_states_outputs = [hidden_states] if output_hidden_states else []
@@ -363,21 +501,15 @@ class BaseTransformer(torch.nn.Module):
         def custom(start, end):
             def custom_forward(*inputs):
                 layers_ = self.layers[start:end]
-                x_, mask = inputs[0], inputs[1]
-                if len(inputs) > 2: # have branch_input
-                    branch_ = inputs[2]
+                x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
                 output_per_layers_part = []
                 for i, layer in enumerate(layers_):
-                    if len(inputs) > 2:
-                        x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                            x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                        )
-                    elif 'layer_forward' in self.hooks:
+                    if 'layer_forward' in self.hooks:
                         x_, output_this_layer = self.hooks['layer_forward'](
-                            x_, mask, layer_id=layer.layer_id, **kw_args
+                            x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
                         )
                     else:
-                        x_, output_this_layer = layer(x_, mask, **kw_args)
+                        x_, output_this_layer = layer(x_, mask, encoder_outputs_, **kw_args)
                     output_per_layers_part.append(output_this_layer)
                 return x_, output_per_layers_part
 
@@ -386,7 +518,7 @@
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
-                args = [hidden_states, attention_mask]
+                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:
                     hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args,
                                                                                      branch_input)
@@ -398,7 +530,7 @@ class BaseTransformer(torch.nn.Module):
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask]
+                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None: # customized layer_forward with branch_input
                     hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(
--
GitLab
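
Appendix (not part of the patch): the relative-position bucketing in T5AttentionMixin._relative_position_bucket / compute_bias is the main T5-specific piece of attention math introduced above. Below is a minimal standalone sketch of that scheme in plain PyTorch, useful for inspecting bucket boundaries outside the repository. The helper names (relative_position_bucket, compute_position_bias) and the toy sizes in the __main__ block are invented for this illustration; only the defaults num_buckets=32 and max_distance=128 mirror the code in the patch, and the bias embedding here is randomly initialized rather than loaded from a trained T5 checkpoint.

import math

import torch


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # relative_position = key_position - query_position. Nearby offsets get their own
    # bucket; distant offsets share logarithmically sized buckets, clamped at max_distance.
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # Logarithmic buckets for offsets in [max_exact, max_distance); anything larger is clamped.
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    return relative_buckets + torch.where(is_small, relative_position, relative_position_if_large)


def compute_position_bias(query_length, key_length, num_heads, num_buckets=32, bidirectional=True):
    # Bucket the (query_length, key_length) matrix of offsets and look the buckets up in a
    # (randomly initialized) per-head embedding, mirroring compute_bias in the mixin above.
    context_position = torch.arange(query_length, dtype=torch.long)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
    buckets = relative_position_bucket(memory_position - context_position,
                                       bidirectional=bidirectional, num_buckets=num_buckets)
    bias_embedding = torch.nn.Embedding(num_buckets, num_heads)
    return bias_embedding(buckets).permute(2, 0, 1).unsqueeze(0)  # (1, num_heads, q_len, k_len)


if __name__ == '__main__':
    bias = compute_position_bias(query_length=8, key_length=8, num_heads=4)
    print(bias.shape)  # torch.Size([1, 4, 8, 8])

With bidirectional=True (the encoder case) half of the buckets encode the sign of the offset; with bidirectional=False (the decoder case) all positive offsets, i.e. future tokens, collapse into bucket 0, which is why the decoder mixin above is constructed with is_decoder=True.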