diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 5e5793e071e5cec6308f81093ab374f11e737961..8dfb9899e85d1d6850d6b1028287460e846d5d53 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -87,7 +87,7 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
-                 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
+                 'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
                  'branch_embedding_forward', 'branch_final_forward'
                  ]
         hooks = {}
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 1beace877aaa5aaa7bbe84fc8ffea500fa03bd72..a31fa6e886f53201f0e00c0c81f919d4ecbac0f0 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -56,6 +56,8 @@ def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch
                                          causal_mask], axis=-1)
             extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        else:
+            extended_attention_mask = causal_mask[:, None, :, :]
     else:
         if attention_mask is None:
             extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
@@ -108,16 +110,20 @@ class EncoderDecoderModel(torch.nn.Module):
         self.decoder.disable_untrainable_params()
 
     def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
-                decoder_position_ids=None, decoder_attention_mask=None,
+                decoder_position_ids=None, decoder_attention_mask=None, encoder_outputs=None,
                 **kw_args):
         dtype = self.encoder.transformer.word_embeddings.weight.dtype
-        batch_size, encoder_seq_length = input_ids.size()[:2]
+        if encoder_outputs is None:
+            batch_size, encoder_seq_length = input_ids.size()[:2]
+        else:
+            batch_size, encoder_seq_length = encoder_outputs.size()[:2]
         encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
                                                              device=input_ids.device, dtype=dtype)
         decoder_seq_length = decoder_input_ids.size(1)
-        encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
+        if encoder_outputs is None:
+            encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
         decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype)
+                                                             device=input_ids.device, dtype=dtype, is_decoder=True)
         decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
                                                       encoder_outputs=encoder_outputs, cross_attention_mask=encoder_attention_mask,
                                                       **kw_args)
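The new `encoder_outputs` argument lets callers run the encoder once and reuse its states across decoding steps. The sketch below is not the SwissArmyTransformer API, just a toy model illustrating the same control flow (only run the encoder when no cached states are supplied); all class and tensor names are illustrative.

```python
import torch
import torch.nn as nn

class ToySeq2Seq(nn.Module):
    """Illustrative stand-in for an encoder-decoder model, not the real classes."""
    def __init__(self, vocab=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.encoder = nn.Linear(hidden, hidden)
        self.decoder = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, vocab)

    def forward(self, input_ids=None, decoder_input_ids=None, encoder_outputs=None):
        # Same branching as the patch: the encoder only runs when the caller
        # did not supply cached encoder_outputs.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(self.embed(input_ids))
        # A real decoder cross-attends to encoder_outputs; a mean-pooled
        # summary keeps the sketch short.
        context = encoder_outputs.mean(dim=1, keepdim=True)
        hidden = self.decoder(self.embed(decoder_input_ids)) + context
        return self.head(hidden), encoder_outputs

model = ToySeq2Seq()
src = torch.randint(0, 100, (2, 8))
tgt = torch.randint(0, 100, (2, 1))

logits, cached = model(input_ids=src, decoder_input_ids=tgt)       # encoder runs once
logits, _ = model(decoder_input_ids=tgt, encoder_outputs=cached)   # encoder pass skipped
```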
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index bf26a3b81bfa41c1978aaf8bb9ab17a921cb0384..bf4f2c9ff0fbfdcd56d024dcb4faab3f5fe89a8a 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -1,9 +1,11 @@
 import math
 import torch
+import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
@@ -164,9 +166,22 @@ class T5AttentionMixin(BaseMixin):
         return output
 
 
+class T5DecoderFinalMixin(BaseMixin):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+    def final_forward(self, logits, **kwargs):
+        logits_parallel = copy_to_model_parallel_region(logits)
+        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        return logits_parallel
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm)
+        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm,
+                         activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -181,6 +196,9 @@ class T5Model(EncoderDecoderModel):
         self.decoder.add_mixin(
             "t5-position", T5PositionEmbeddingMixin()
         )
+        self.decoder.add_mixin(
+            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+        )
         del self.decoder.transformer.position_embeddings
         self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
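`T5DecoderFinalMixin` follows T5's tied-embedding output head: decoder hidden states are scaled by `hidden_size ** -0.5` before being projected onto the shared word-embedding matrix (the mixin additionally routes the tensor through `copy_to_model_parallel_region` for tensor parallelism). A minimal single-device sketch of the same computation, with made-up sizes:

```python
import torch
import torch.nn.functional as F

hidden_size, vocab_size = 512, 1000               # illustrative sizes
word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
decoder_hidden = torch.randn(2, 4, hidden_size)   # [batch, seq, hidden]

# Scale by d_model ** -0.5, then project with the tied embedding weight.
scaled = decoder_hidden * (hidden_size ** -0.5)
logits = F.linear(scaled, word_embeddings.weight)  # [batch, seq, vocab]
```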
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 7a3234b4340f45f52f74082bb34cc86a25818684..d58145552635d26e245382d3c7585ef825db723f 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,11 +198,11 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, encoder_outputs, *args, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
-            return self.hooks['cross_attention_forward'](hidden_states, encoder_outputs,
-                                                         cross_attention_mask=cross_attention_mask, **kw_args,
+            return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
+                                                         **kw_args,
                                                          layer_id=self.layer_id)
         else:
             mixed_query_layer = self.query(hidden_states)
@@ -231,9 +231,10 @@ class CrossAttention(torch.nn.Module):
 
 class MLP(torch.nn.Module):
     def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
-                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True):
+                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True, activation_func=gelu):
        super(MLP, self).__init__()
        self.layer_id = layer_id
+       self.activation_func = activation_func
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
@@ -263,7 +264,7 @@ class MLP(torch.nn.Module):
             output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
-            intermediate_parallel = gelu(intermediate_parallel)
+            intermediate_parallel = self.activation_func(intermediate_parallel)
             output = self.dense_4h_to_h(intermediate_parallel)
 
         if self.training:
@@ -288,6 +289,7 @@ class BaseTransformerLayer(torch.nn.Module):
                  layernorm=LayerNorm,
                  is_decoder=False,
                  use_bias=True,
+                 activation_func=gelu,
                  hooks={}
                  ):
         super(BaseTransformerLayer, self).__init__()
@@ -347,10 +349,11 @@ class BaseTransformerLayer(torch.nn.Module):
             output_layer_init_method=output_layer_init_method,
             bias=use_bias,
             layer_id=layer_id,
+            activation_func=activation_func,
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, encoder_outputs=None, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -372,9 +375,9 @@ class BaseTransformerLayer(torch.nn.Module):
 
         if self.is_decoder and encoder_outputs is not None:
             # Cross attention
-            attention_output = self.cross_attention(layernorm_output, encoder_outputs, **kw_args)
+            attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
             # Residual connection.
-            layernorm_input = layernorm_output + attention_output
+            layernorm_input = layernorm_input + attention_output
             # Layer norm post the cross attention
             layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
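The last hunk above also fixes the decoder's cross-attention residual connection: the cross-attention output is now added to the running residual stream (`layernorm_input`) rather than to the already-normalized tensor that fed the attention. A toy pre-LayerNorm decoder block showing the intended dataflow (stand-in linear layers instead of real attention; all names here are illustrative):

```python
import torch
import torch.nn as nn

hidden = 32
ln1, ln2, ln3 = nn.LayerNorm(hidden), nn.LayerNorm(hidden), nn.LayerNorm(hidden)
self_attn = nn.Linear(hidden, hidden)    # stand-in for self-attention
cross_attn = nn.Linear(hidden, hidden)   # stand-in for cross-attention
mlp = nn.Linear(hidden, hidden)

def decoder_block(hidden_states, encoder_summary):
    layernorm_output = ln1(hidden_states)
    layernorm_input = hidden_states + self_attn(layernorm_output)   # residual stream

    layernorm_output = ln2(layernorm_input)
    attention_output = cross_attn(layernorm_output + encoder_summary)
    # The fix: accumulate onto the residual stream, not onto layernorm_output.
    layernorm_input = layernorm_input + attention_output

    layernorm_output = ln3(layernorm_input)
    return layernorm_input + mlp(layernorm_output)

x = torch.randn(2, 4, hidden)
enc = torch.randn(2, 1, hidden)   # pretend pooled encoder states
out = decoder_block(x, enc)       # [2, 4, 32]
```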
@@ -411,6 +414,7 @@ class BaseTransformer(torch.nn.Module):
                  parallel_output=True,
                  is_decoder=False,
                  use_bias=True,
+                 activation_func=gelu,
                  layernorm=LayerNorm,
                  hooks={}
                  ):
@@ -453,6 +457,7 @@ class BaseTransformer(torch.nn.Module):
                 sandwich_ln=sandwich_ln,
                 layernorm=layernorm,
                 use_bias=use_bias,
+                activation_func=activation_func,
                 hooks=self.hooks
             )
@@ -464,7 +469,7 @@ class BaseTransformer(torch.nn.Module):
     def forward(self, input_ids, position_ids, attention_mask, *,
                 branch_input=None, encoder_outputs=None, output_hidden_states=False, **kw_args):
-        # sanity check 
+        # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
         assert len(attention_mask.shape) == 2 or \
@@ -504,7 +509,11 @@ class BaseTransformer(torch.nn.Module):
                 x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
                 output_per_layers_part = []
                 for i, layer in enumerate(layers_):
-                    if 'layer_forward' in self.hooks:
+                    if branch_input is not None:
+                        x_, encoder_outputs_, output_this_layer = self.hooks['layer_forward'](
+                            x_, mask, layer_id=layer.layer_id, branch_input=encoder_outputs_, **kw_args
+                        )
+                    elif 'layer_forward' in self.hooks:
                         x_, output_this_layer = self.hooks['layer_forward'](
                             x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
                         )
@@ -518,11 +527,11 @@ class BaseTransformer(torch.nn.Module):
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
-                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args,
-                                                                                     branch_input)
+                    args = [hidden_states, attention_mask, branch_input]
+                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
                 else:
+                    args = [hidden_states, attention_mask, encoder_outputs]
                     hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
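The `activation_func` argument threaded through `MLP`, `BaseTransformerLayer`, and `BaseTransformer` keeps `gelu` as the default while letting `T5Model` substitute `relu`. A toy dense block with the same pluggable-activation pattern (hypothetical helper, illustrative shapes):

```python
import torch
import torch.nn.functional as F

def toy_mlp(hidden_states, w_in, w_out, activation_func=F.gelu):
    # dense -> activation -> dense, with the activation injectable by the caller
    intermediate = hidden_states @ w_in
    intermediate = activation_func(intermediate)
    return intermediate @ w_out

h = torch.randn(2, 4, 8)
w_in, w_out = torch.randn(8, 32), torch.randn(32, 8)
out_default = toy_mlp(h, w_in, w_out)                        # GELU (default)
out_t5 = toy_mlp(h, w_in, w_out, activation_func=F.relu)     # ReLU, as T5Model passes
```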