Commit cc852e00 authored by duzx16

Implement T5 model

parent 7aee8e24
@@ -33,6 +33,8 @@ def add_model_config_args(parser):
                        help='num of transformer attention heads')
     group.add_argument('--hidden-size', type=int, default=1024,
                        help='transformer hidden size')
+    group.add_argument('--inner-hidden-size', type=int, default=None)
+    group.add_argument('--hidden-size-per-attention-head', type=int, default=None)
     group.add_argument('--num-layers', type=int, default=24,
                        help='num decoder layers')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
...
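These two flags decouple the feed-forward width and the per-attention-head dimension from --hidden-size. A minimal sketch of how they compose, assuming argparse-style parsing; the concrete numbers follow T5-11B's published configuration and are only an illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--hidden-size', type=int, default=1024)
parser.add_argument('--num-attention-heads', type=int, default=16)
parser.add_argument('--inner-hidden-size', type=int, default=None)               # FFN width; None falls back to 4 * hidden-size
parser.add_argument('--hidden-size-per-attention-head', type=int, default=None)  # per-head dim; None falls back to hidden-size / heads

args = parser.parse_args(['--hidden-size', '1024', '--num-attention-heads', '128',
                          '--inner-hidden-size', '65536',
                          '--hidden-size-per-attention-head', '128'])
inner_hidden_size = args.inner_hidden_size or 4 * args.hidden_size
head_dim = args.hidden_size_per_attention_head or args.hidden_size // args.num_attention_heads
print(inner_hidden_size, head_dim)  # 65536 128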
@@ -2,4 +2,5 @@ from .base_model import BaseModel
 from .cached_autoregressive_model import CachedAutoregressiveModel
 from .cuda2d_model import Cuda2dModel
 from .glm_model import GLMModel
 from .encoder_decoder_model import EncoderDecoderModel
+from .t5_model import T5Model
@@ -13,20 +13,23 @@ import math
 import random
 import torch
-from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu import BaseTransformer, LayerNorm


 class BaseMixin(torch.nn.Module):
     def __init__(self):
         super(BaseMixin, self).__init__()
         # define new params

     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass

     # can also define hook-functions here
     # ...


 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None, parallel_output=True):
+    def __init__(self, args, transformer=None, **kwargs):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -42,14 +45,16 @@ class BaseModel(torch.nn.Module):
                 embedding_dropout_prob=args.hidden_dropout,
                 attention_dropout_prob=args.attention_dropout,
                 output_dropout_prob=args.hidden_dropout,
+                inner_hidden_size=args.inner_hidden_size,
+                hidden_size_per_attention_head=args.hidden_size_per_attention_head,
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=parallel_output,
-                hooks=self.hooks
+                hooks=self.hooks,
+                **kwargs
             )

     def reinit(self):  # will be called when loading model
         # if some mixins are loaded, overrides this function
         for m in self.mixins.values():
             m.reinit(self.transformer)
@@ -58,11 +63,11 @@ class BaseModel(torch.nn.Module):
         assert name not in self.mixins
         assert isinstance(new_mixin, BaseMixin)

         self.mixins[name] = new_mixin  # will auto-register parameters
         object.__setattr__(new_mixin, 'transformer', self.transformer)  # cannot use pytorch set_attr

         if reinit:
             new_mixin.reinit(self.transformer, **self.mixins)  # also pass current mixins
         self.collect_hooks_()

     def del_mixin(self, name):
@@ -82,15 +87,15 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                  'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
                  'branch_embedding_forward', 'branch_final_forward'
                  ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
                     if name in hooks:  # conflict
                         raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                     hooks[name] = getattr(m, name)
                     hook_origins[name] = mixin_name
@@ -104,4 +109,4 @@ class BaseModel(torch.nn.Module):
         return hooks

     def disable_untrainable_params(self):
         pass
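For orientation, a standalone sketch (not part of the commit) of the hook-collection idea behind collect_hooks_: any mixin method whose name matches a known hook overrides the default behaviour, and two mixins supplying the same hook is an error.

class DemoPositionMixin:
    def position_embedding_forward(self, position_ids, **kw_args):
        return None                      # e.g. skip absolute position embeddings

class DemoFinalMixin:
    def final_forward(self, logits, **kw_args):
        return logits

def collect_hooks(mixins, names=('position_embedding_forward', 'final_forward')):
    hooks, hook_origins = {}, {}
    for mixin_name, m in mixins.items():
        for name in names:
            if hasattr(m, name):
                if name in hooks:        # conflict
                    raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                hooks[name] = getattr(m, name)
                hook_origins[name] = mixin_name
    return hooks

print(sorted(collect_hooks({'pos': DemoPositionMixin(), 'final': DemoFinalMixin()})))
# ['final_forward', 'position_embedding_forward']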
@@ -14,108 +14,90 @@ import random
 import torch
 import argparse
 from .base_model import BaseModel, BaseMixin
-from .common_layers import CrossAttention, LayerNorm
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
+from .common_layers import LayerNorm


-class CrossAttentionMixin(BaseMixin):
-    def __init__(self, num_layers, hidden_size, num_attention_heads,
-                 attention_dropout_prob, output_dropout_prob,
-                 init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
-        super().__init__()
-
-        self.cross_attentions = torch.nn.ModuleList(
-            [CrossAttention(
-                hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
-                output_layer_init_method=output_layer_init_method
-            ) for layer_id in range(num_layers)]
-        )  # Just copy args
-        self.cross_lns = torch.nn.ModuleList(
-            [LayerNorm(hidden_size, 1e-5)
-             for layer_id in range(num_layers)]
-        )
-
-    def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
-        layer = self.transformer.layers[layer_id]
-        encoder_outputs = kw_args['encoder_outputs']
-        '''
-            hidden_states: [batch, seq_len, hidden_size]
-            mask: [(1, 1), seq_len, seq_len]
-            encoder_outputs: [batch, enc_seq_len, enc_hidden_size]
-        '''
-        # Layer norm at the begining of the transformer layer.
-        layernorm_output = layer.input_layernorm(hidden_states)
-        attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args)
-        # Third LayerNorm
-        if layer.sandwich_ln:
-            attention_output = layer.third_layernorm(attention_output)
-        # Residual connection.
-        hidden_states = hidden_states + attention_output
-
-        # Cross attention.
-        layernorm_output = self.cross_lns[layer_id](hidden_states)
-        cross_attn_output = self.cross_attentions[layer_id](
-            layernorm_output,
-            torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
-            encoder_outputs
-        )
-        hidden_states = hidden_states + cross_attn_output
-        # Layer norm post the layer attention.
-        layernorm_output = layer.post_attention_layernorm(hidden_states)
-        # MLP.
-        mlp_output = layer.mlp(layernorm_output, **kw_args)
-        # Fourth LayerNorm
-        if layer.sandwich_ln:
-            mlp_output = layer.fourth_layernorm(mlp_output)
-        output = hidden_states + mlp_output
-
-        return output, output_this_layer
-
-
-class DecoderModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        dec_args = argparse.Namespace(**vars(args))
-        dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
-        override_attrs = ['num_layers', 'vocab_size',
-                          'hidden_size', 'num_attention_heads',
-                          'max_sequence_length', 'sandwich_ln'  # TODO
-                          ]
-        for name in override_attrs:
-            dec_attr = getattr(dec_args, 'dec_' + name, None)
-            if dec_attr is not None:  # else use encoder-config
-                setattr(dec_args, name, dec_attr)
-        super().__init__(dec_args, transformer=transformer)
-        self.add_mixin('cross_attention',
-                       CrossAttentionMixin(
-                           dec_args.num_layers,
-                           dec_args.hidden_size, dec_args.num_attention_heads,
-                           dec_args.attention_dropout, dec_args.hidden_dropout,
-                           self.transformer.init_method,
-                           enc_hidden_size=dec_args.enc_hidden_size,
-                           inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None),
-                           output_layer_init_method=self.transformer.output_layer_init_method
-                       )
-                       )
+def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32, is_decoder=False):
+    """
+    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+    Arguments:
+        attention_mask (:obj:`torch.Tensor`):
+            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+        input_shape (:obj:`Tuple[int]`):
+            The shape of the input to the model.
+        device: (:obj:`torch.device`):
+            The device of the input to the model.
+        dtype:
+        is_decoder:
+
+    Returns:
+        :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+    """
+    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+    # ourselves in which case we just need to make it broadcastable to all heads.
+    if attention_mask is None or attention_mask.dim() == 2:
+        batch_size, seq_length = input_shape
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if is_decoder:
+            seq_ids = torch.arange(seq_length, device=device)
+            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+            # causal and attention masks must have same type with pytorch version < 1.3
+            causal_mask = causal_mask.to(dtype)
+            if attention_mask is not None:
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                         causal_mask], axis=-1)
+
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        else:
+            if attention_mask is None:
+                extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+    elif attention_mask.dim() == 3:
+        extended_attention_mask = attention_mask[:, None, :, :]
+    else:
+        raise ValueError(
+            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+        )
+    return extended_attention_mask
+
+
+class EncoderFinalMixin(BaseMixin):
+    def final_forward(self, logits, **kwargs):
+        logits = copy_to_model_parallel_region(logits)
+        return logits
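A tiny standalone check of the decoder branch of get_extended_attention_mask above, with a length-3 sequence whose last position is padding:

import torch

attention_mask = torch.tensor([[1., 1., 0.]])        # [batch=1, seq=3]; last token is padding
seq_ids = torch.arange(3)
causal_mask = (seq_ids[None, None, :].repeat(1, 3, 1) <= seq_ids[None, :, None]).float()
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended[0, 0])
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 0.]])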
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None):
+    def __init__(self, args, encoder=None, decoder=None, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
             self.encoder = encoder
         else:
-            self.encoder = BaseModel(args)
+            self.encoder = BaseModel(args, **kwargs)
+        self.encoder.add_mixin("final", EncoderFinalMixin())

         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
         else:
-            self.decoder = DecoderModel(args)
+            dec_args = argparse.Namespace(**vars(args))
+            dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
+            override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads',
+                              'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
+            for name in override_attrs:
+                dec_attr = getattr(dec_args, 'dec_' + name, None)
+                if dec_attr is not None:  # else use encoder-config
+                    setattr(dec_args, name, dec_attr)
+            self.decoder = BaseModel(args, is_decoder=True, parallel_output=parallel_output, **kwargs)

     def reinit(self):
         self.encoder.reinit()
@@ -125,23 +107,29 @@ class EncoderDecoderModel(torch.nn.Module):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()

-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *,
-                branch_input=None, **kw_args):
-        mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype)
-        enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input,
-                                            **kw_args)
-        dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask,
-                                              encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args)
-        return enc_outputs, dec_outputs, *dec_mems
+    def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
+                decoder_position_ids=None, decoder_attention_mask=None,
+                **kw_args):
+        dtype = self.encoder.transformer.word_embeddings.weight.dtype
+        batch_size, encoder_seq_length = input_ids.size()[:2]
+        encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
+                                                             device=input_ids.device, dtype=dtype)
+        decoder_seq_length = decoder_input_ids.size(1)
+        encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
+        decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
+                                                             device=input_ids.device, dtype=dtype)
+        decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
+                                                      encoder_outputs=encoder_outputs,
+                                                      cross_attention_mask=encoder_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *decoder_mems

     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart')
-        group.add_argument("--dec_num_layers", type=int, default=None)
-        group.add_argument("--dec_vocab_size", type=int, default=None)
-        group.add_argument("--dec_hidden_size", type=int, default=None)
-        group.add_argument("--dec_num_attention_heads", type=int, default=None)
-        group.add_argument("--dec_max_sequence_length", type=int, default=None)
-        group.add_argument("--dec_sandwich_ln", action='store_true')
-        group.add_argument("--dec_inner_hidden_size", type=int, default=None)
+        group.add_argument("--dec-num-layers", type=int, default=None)
+        group.add_argument("--dec-hidden-size", type=int, default=None)
+        group.add_argument("--dec-num-attention-heads", type=int, default=None)
+        group.add_argument("--dec-max-sequence-length", type=int, default=None)
+        group.add_argument("--dec-inner-hidden-size", type=int, default=None)
+        group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None)
         return parser
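For clarity, a standalone sketch of the dec_* override pattern used above: any dec_<name> attribute that is not None replaces the corresponding encoder value in the copied namespace.

import argparse

args = argparse.Namespace(num_layers=24, hidden_size=1024,
                          dec_num_layers=12, dec_hidden_size=None)
dec_args = argparse.Namespace(**vars(args))
for name in ['num_layers', 'hidden_size']:
    dec_attr = getattr(dec_args, 'dec_' + name, None)
    if dec_attr is not None:  # else fall back to the encoder config
        setattr(dec_args, name, dec_attr)
print(dec_args.num_layers, dec_args.hidden_size)  # 12 1024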
t5_model.py (new file):

import math

import torch

from .mixins import BaseMixin
from .encoder_decoder_model import EncoderDecoderModel
from SwissArmyTransformer.mpu import get_model_parallel_world_size
from SwissArmyTransformer.mpu.transformer import standard_attention
from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim


class T5PositionEmbeddingMixin(BaseMixin):
    def position_embedding_forward(self, position_ids, **kw_args):
        return None


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into float16 if necessary
        if self.weight.dtype == torch.float16:
            hidden_states = hidden_states.to(torch.float16)
        return self.weight * hidden_states
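A standalone check of what this norm computes: only RMS scaling, with no mean subtraction and no bias, unlike a standard LayerNorm.

import torch

x = torch.randn(2, 5, 8) + 3.0                      # deliberately non-zero mean
variance = x.pow(2).mean(-1, keepdim=True)          # raw second moment, no mean subtraction
normed = x * torch.rsqrt(variance + 1e-6)           # what T5LayerNorm does with weight == 1
print(normed.pow(2).mean(-1).mean().item())         # ~1.0: unit root-mean-square
print(normed.mean().item())                         # ~0.95: the mean is not removed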
class T5AttentionMixin(BaseMixin):
    def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False):
        super().__init__()
        self.relative_attention_num_buckets = relative_attention_num_buckets
        world_size = get_model_parallel_world_size()
        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
        self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets,
                                                          self.num_attention_heads_per_partition)
        self.is_decoder = is_decoder

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
        return relative_buckets
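A quick hand check of the unidirectional case used by the decoder (num_buckets=32, max_distance=128): the first 16 distances get their own buckets, then buckets grow logarithmically and saturate at 31. The scalar helper below mirrors the tensor code above.

import math

def bucket(distance, num_buckets=32, max_distance=128):
    # distance = how far in the past the attended-to token is (query_position - memory_position)
    max_exact = num_buckets // 2
    if distance < max_exact:
        return distance
    log_bucket = max_exact + int(math.log(distance / max_exact)
                                 / math.log(max_distance / max_exact) * (num_buckets - max_exact))
    return min(log_bucket, num_buckets - 1)

print([bucket(d) for d in (0, 1, 15, 16, 32, 64, 128, 1000)])
# [0, 1, 15, 16, 21, 26, 31, 31]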
    def compute_bias(self, query_length, key_length, cross_attention=False):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        # shape (query_length, key_length, num_heads)
        if cross_attention:
            values = self.cross_relative_attention_bias(relative_position_bucket)
        else:
            values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
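A small worked example (hypothetical sizes) of how the bias table is split across model-parallel ranks and reshaped by compute_bias: with 32 buckets, 16 heads and a world size of 2, each rank holds a 32 x 8 embedding and produces a (1, heads_per_partition, query_length, key_length) additive bias.

import torch

relative_attention_num_buckets = 32
num_attention_heads, world_size = 16, 2
num_heads_per_partition = num_attention_heads // world_size

bias_table = torch.nn.Embedding(relative_attention_num_buckets, num_heads_per_partition)
buckets = torch.randint(relative_attention_num_buckets, (5, 7))   # stand-in for the bucketed relative positions
values = bias_table(buckets).permute([2, 0, 1]).unsqueeze(0)
print(tuple(values.shape))  # (1, 8, 5, 7)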
    def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args):
        attn_module = self.transformer.layers[layer_id].attention
        seq_length = hidden_states.size(1)
        memory_length = mems[layer_id].size(1) if mems else 0
        mixed_raw_layer = attn_module.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)

        dropout_fn = attn_module.attention_dropout if attn_module.training else None

        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
        value_layer = attn_module._transpose_for_scores(mixed_value_layer)

        position_bias = self.compute_bias(seq_length, memory_length + seq_length)

        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
                                           log_attention_weights=position_bias, scaling_attention_score=False)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)
        output = attn_module.dense(context_layer)

        if attn_module.training:
            output = attn_module.output_dropout(output)
        return output, None

    def cross_attention_forward(self, hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args):
        attn_module = self.transformer.layers[layer_id].cross_attention
        mixed_query_layer = attn_module.query(hidden_states)
        mixed_x_layer = attn_module.key_value(encoder_outputs)
        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)

        dropout_fn = attn_module.attention_dropout if attn_module.training else None

        # Reshape and transpose [b, np, s, hn]
        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
        value_layer = attn_module._transpose_for_scores(mixed_value_layer)

        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn,
                                           scaling_attention_score=False)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output. [b, s, h]
        output = attn_module.dense(context_layer)

        if attn_module.training:
            output = attn_module.output_dropout(output)
        return output


class T5Model(EncoderDecoderModel):
    def __init__(self, args, **kwargs):
        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm)
        self.encoder.add_mixin(
            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
        )
        self.encoder.add_mixin(
            "t5-position", T5PositionEmbeddingMixin()
        )
        del self.encoder.transformer.position_embeddings
        num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None else args.num_attention_heads
        self.decoder.add_mixin(
            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads, is_decoder=True)
        )
        self.decoder.add_mixin(
            "t5-position", T5PositionEmbeddingMixin()
        )
        del self.decoder.transformer.position_embeddings
        self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings

    @classmethod
    def add_model_specific_args(cls, parser):
        super().add_model_specific_args(parser)
        parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
@@ -44,14 +44,13 @@ class LayerNorm(FusedLayerNorm):

 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                       attention_dropout=None, log_attention_weights=None):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training.
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.
-    attention_scores = torch.matmul(
-        query_layer / math.sqrt(query_layer.shape[-1]),
-        key_layer.transpose(-1, -2)
-    )
+    if scaling_attention_score:
+        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
+    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

     if log_attention_weights is not None:
         attention_scores += log_attention_weights
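A minimal standalone sketch of the two paths this change enables: the default path keeps the 1/sqrt(d) scaling, while T5-style attention disables it and injects a relative-position bias through log_attention_weights.

import math
import torch

q = torch.randn(1, 2, 4, 8)          # [batch, heads, seq, head_dim]
k = torch.randn(1, 2, 4, 8)
position_bias = torch.randn(1, 2, 4, 4)

scores_default = torch.matmul(q / math.sqrt(q.shape[-1]), k.transpose(-1, -2))
scores_t5 = torch.matmul(q, k.transpose(-1, -2)) + position_bias   # scaling_attention_score=False
print(scores_default.shape, scores_t5.shape)                       # torch.Size([1, 2, 4, 4]) twice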
@@ -73,7 +72,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,

 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, layer_id, output_layer_init_method=None,
+                 init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True,
                  hooks={}):
         super(SelfAttention, self).__init__()
         # Set output layer initialization if not provided.
@@ -83,25 +82,31 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
-        self.hidden_size_per_partition = divide(hidden_size, world_size)
-        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3 * hidden_size,
+            3 * inner_hidden_size,
             stride=3,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
         self.dense = RowParallelLinear(
-            hidden_size,
+            inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
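A worked example of the decoupling this enables (the figures follow T5-11B's published configuration and are only illustrative): the attention width no longer has to equal hidden_size.

hidden_size = 1024
num_attention_heads = 128
hidden_size_per_attention_head = 128                 # would be 8 if derived as hidden_size // heads

inner_hidden_size = num_attention_heads * hidden_size_per_attention_head
print(inner_hidden_size)         # 16384: width of the query/key/value projections
print(3 * inner_hidden_size)     # 49152: output size of the fused query_key_value linear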
@@ -115,7 +120,7 @@ class SelfAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)

-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
             return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
@@ -142,9 +147,91 @@ class SelfAttention(torch.nn.Module):
             return output, None


+class CrossAttention(torch.nn.Module):
+    """Parallel cross-attention layer for Transformer"""
+
+    def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
+                 layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, hooks={}):
+        super().__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.hooks = hooks
+        self.layer_id = layer_id
+        # Per attention head and per partition values.
+        world_size = get_model_parallel_world_size()
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
+        # Strided linear layer.
+        self.query = ColumnParallelLinear(hidden_size, inner_hidden_size,
+                                          gather_output=False,
+                                          init_method=init_method, bias=bias)
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size,
+                                              stride=2,
+                                              gather_output=False,
+                                              init_method=init_method, bias=bias)
+        # Dropout. Note that for a single iteration, this layer will generate
+        # different outputs on different number of parallel partitions but
+        # on average it should not be partition dependent.
+        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+        # Output.
+        self.dense = RowParallelLinear(
+            inner_hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            init_method=output_layer_init_method, bias=bias)
+        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
+
+    def _transpose_for_scores(self, tensor):
+        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
+        size [b, np, s, hn].
+        """
+        new_tensor_shape = tensor.size()[:-1] + \
+                           (self.num_attention_heads_per_partition,
+                            self.hidden_size_per_attention_head)
+        tensor = tensor.view(*new_tensor_shape)
+        return tensor.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, encoder_outputs, *args, cross_attention_mask=None, **kw_args):
+        # hidden_states: [b, s, h]
+        if 'cross_attention_forward' in self.hooks:
+            return self.hooks['cross_attention_forward'](hidden_states, encoder_outputs,
+                                                         cross_attention_mask=cross_attention_mask, **kw_args,
+                                                         layer_id=self.layer_id)
+        else:
+            mixed_query_layer = self.query(hidden_states)
+            mixed_x_layer = self.key_value(encoder_outputs)
+            (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
+
+            dropout_fn = self.attention_dropout if self.training else None
+
+            # Reshape and transpose [b, np, s, hn]
+            query_layer = self._transpose_for_scores(mixed_query_layer)
+            key_layer = self._transpose_for_scores(mixed_key_layer)
+            value_layer = self._transpose_for_scores(mixed_value_layer)
+
+            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+            # [b, s, hp]
+            context_layer = context_layer.view(*new_context_layer_shape)
+
+            # Output. [b, s, h]
+            output = self.dense(context_layer)
+
+            if self.training:
+                output = self.output_dropout(output)
+            return output
+
+
 class MLP(torch.nn.Module):
-    def __init__(self, hidden_size, output_dropout_prob, init_method,
-                 output_layer_init_method=None, layer_id=None, hooks={}):
+    def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
+                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True):
         super(MLP, self).__init__()
         self.layer_id = layer_id
         # Set output layer initialization if not provided.
@@ -152,18 +239,22 @@ class MLP(torch.nn.Module):
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
             hidden_size,
-            4 * hidden_size,
+            inner_hidden_size,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            4 * hidden_size,
+            inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
@@ -190,8 +281,13 @@ class BaseTransformerLayer(torch.nn.Module):
                  layernorm_epsilon,
                  init_method,
                  layer_id,
+                 inner_hidden_size=None,
+                 hidden_size_per_attention_head=None,
                  output_layer_init_method=None,
                  sandwich_ln=True,
+                 layernorm=LayerNorm,
+                 is_decoder=False,
+                 use_bias=True,
                  hooks={}
                  ):
         super(BaseTransformerLayer, self).__init__()
@@ -199,10 +295,11 @@ class BaseTransformerLayer(torch.nn.Module):
         if output_layer_init_method is None:
             output_layer_init_method = init_method
         self.layer_id = layer_id
+        self.is_decoder = is_decoder
         self.hooks = hooks

         # Layernorm on the input data.
-        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

         # Self attention.
         self.attention = SelfAttention(
@@ -212,28 +309,48 @@ class BaseTransformerLayer(torch.nn.Module):
             output_dropout_prob,
             init_method,
             layer_id,
+            hidden_size_per_attention_head=hidden_size_per_attention_head,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             hooks=hooks
         )

         # Layernorm on the input data.
-        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
         self.sandwich_ln = sandwich_ln
         if sandwich_ln:
-            self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
-            self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+            self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+            self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+        # Cross attention.
+        if self.is_decoder:
+            self.cross_attention = CrossAttention(
+                hidden_size,
+                num_attention_heads,
+                attention_dropout_prob,
+                output_dropout_prob,
+                init_method,
+                layer_id,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
+                output_layer_init_method=output_layer_init_method,
+                bias=use_bias,
+                hooks=hooks
+            )
+            self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

         # MLP
         self.mlp = MLP(
             hidden_size,
             output_dropout_prob,
             init_method,
+            inner_hidden_size=inner_hidden_size,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             layer_id=layer_id,
             hooks=hooks
         )
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -252,6 +369,15 @@ class BaseTransformerLayer(torch.nn.Module):
         layernorm_input = hidden_states + attention_output
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

+        if self.is_decoder and encoder_outputs is not None:
+            # Cross attention
+            attention_output = self.cross_attention(layernorm_output, encoder_outputs, **kw_args)
+            # Residual connection.
+            layernorm_input = layernorm_output + attention_output
+            # Layer norm post the cross attention
+            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
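A standalone sketch of the resulting decoder-layer data flow, with plain Linear/LayerNorm stand-ins for the real submodules (the stand-in cross attention ignores the encoder outputs): after cross attention, the residual base is the post-attention LayerNorm output, as in the code above.

import torch
from torch import nn

hidden = 8
ln_in, ln_post_attn, ln_post_cross = nn.LayerNorm(hidden), nn.LayerNorm(hidden), nn.LayerNorm(hidden)
self_attn = nn.Linear(hidden, hidden)
cross_attn = nn.Linear(hidden, hidden)
mlp = nn.Linear(hidden, hidden)

x = torch.randn(2, 5, hidden)            # [batch, seq, hidden]

attn_out = self_attn(ln_in(x))
residual = x + attn_out                  # residual after self attention
normed = ln_post_attn(residual)

cross_out = cross_attn(normed)           # the real layer also consumes encoder_outputs here
residual = normed + cross_out            # residual base is the normed tensor
normed = ln_post_cross(residual)

out = residual + mlp(normed)
print(out.shape)                         # torch.Size([2, 5, 8])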
@@ -279,13 +405,19 @@ class BaseTransformer(torch.nn.Module):
                  checkpoint_num_layers=1,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
+                 inner_hidden_size=None,
+                 hidden_size_per_attention_head=None,
                  sandwich_ln=True,
                  parallel_output=True,
+                 is_decoder=False,
+                 use_bias=True,
+                 layernorm=LayerNorm,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()

         # recording parameters
+        self.is_decoder = is_decoder
         self.parallel_output = parallel_output
         self.checkpoint_activations = checkpoint_activations
         self.checkpoint_num_layers = checkpoint_num_layers
@@ -314,8 +446,13 @@ class BaseTransformer(torch.nn.Module):
                 layernorm_epsilon,
                 self.init_method,
                 layer_id,
+                inner_hidden_size=inner_hidden_size,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
                 output_layer_init_method=self.output_layer_init_method,
+                is_decoder=self.is_decoder,
                 sandwich_ln=sandwich_ln,
+                layernorm=layernorm,
+                use_bias=use_bias,
                 hooks=self.hooks
             )
@@ -323,10 +460,10 @@ class BaseTransformer(torch.nn.Module):
             [get_layer(layer_id) for layer_id in range(num_layers)])

         # Final layer norm before output.
-        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
-                **kw_args):
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
+                output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
@@ -349,7 +486,8 @@ class BaseTransformer(torch.nn.Module):
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
             position_embeddings = self.position_embeddings(position_ids)
-        hidden_states = hidden_states + position_embeddings
+        if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)

         hidden_states_outputs = [hidden_states] if output_hidden_states else []
@@ -363,21 +501,15 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2:  # have branch_input
-                        branch_ = inputs[2]
+                    x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
+                        if 'layer_forward' in self.hooks:
                             x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                                x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            x_, output_this_layer = layer(x_, mask, encoder_outputs_, **kw_args)
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
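A standalone sketch of why the chunked closure takes the extra tensor positionally: everything a checkpointed chunk needs must be passed as an argument to torch.utils.checkpoint.checkpoint, so the encoder outputs travel through args alongside the hidden states and the mask.

import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))

def custom(start, end):
    def custom_forward(x_, mask, encoder_outputs_):
        for layer in layers[start:end]:
            x_ = layer(x_) * mask + encoder_outputs_   # toy stand-in for a decoder layer
        return x_
    return custom_forward

x = torch.randn(2, 8, requires_grad=True)
mask = torch.ones(2, 8)
encoder_outputs = torch.zeros(2, 8)

args = [x, mask, encoder_outputs]
out = checkpoint(custom(0, 2), *args)
out = checkpoint(custom(2, 4), out, mask, encoder_outputs)
print(out.shape)  # torch.Size([2, 8])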
@@ -386,7 +518,7 @@ class BaseTransformer(torch.nn.Module):
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
-                args = [hidden_states, attention_mask]
+                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:
                     hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args,
                                                                                      branch_input)
@@ -398,7 +530,7 @@ class BaseTransformer(torch.nn.Module):
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask]
+                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:  # customized layer_forward with branch_input
                     hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args,
                                                                                                  layer_id=torch.tensor(
...