Commit 6a2fa408 authored by Ming Ding

tidy up enc-dec model

parent 3ab8731b
@@ -13,8 +13,7 @@ import math
 import random
 import torch
-from SwissArmyTransformer.mpu import BaseTransformer, LayerNorm
+from SwissArmyTransformer.mpu import BaseTransformer
 class BaseMixin(torch.nn.Module):
     def __init__(self):
...
@@ -17,60 +17,6 @@ from .base_model import BaseModel, BaseMixin
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region

-def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32, is_decoder=False):
-    """
-    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
-    Arguments:
-        attention_mask (:obj:`torch.Tensor`):
-            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
-        input_shape (:obj:`Tuple[int]`):
-            The shape of the input to the model.
-        device (:obj:`torch.device`):
-            The device of the input to the model.
-        dtype:
-        is_decoder:
-    Returns:
-        :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
-    """
-    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-    # ourselves in which case we just need to make it broadcastable to all heads.
-    if attention_mask is None or attention_mask.dim() == 2:
-        batch_size, seq_length = input_shape
-        # Provided a padding mask of dimensions [batch_size, seq_length]
-        # - if the model is a decoder, apply a causal mask in addition to the padding mask
-        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if is_decoder:
-            seq_ids = torch.arange(seq_length, device=device)
-            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
-            # causal and attention masks must have same type with pytorch version < 1.3
-            causal_mask = causal_mask.to(dtype)
-            if attention_mask is not None:
-                if causal_mask.shape[1] < attention_mask.shape[1]:
-                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
-                    causal_mask = torch.cat(
-                        [torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
-                         causal_mask], axis=-1)
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            else:
-                extended_attention_mask = causal_mask[:, None, :, :]
-        else:
-            if attention_mask is None:
-                extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-    elif attention_mask.dim() == 3:
-        extended_attention_mask = attention_mask[:, None, :, :]
-    else:
-        raise ValueError(
-            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
-        )
-    return extended_attention_mask

 class EncoderFinalMixin(BaseMixin):
     def final_forward(self, logits, **kwargs):
         logits = copy_to_model_parallel_region(logits)
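With get_extended_attention_mask removed, callers that previously passed 2-D padding masks are expected to hand the model an already-broadcastable mask. A minimal sketch of that construction, following the logic of the deleted helper (the function name build_causal_mask is illustrative, not part of the repo):

import torch

def build_causal_mask(padding_mask, dtype=torch.float32):
    # Hypothetical helper mirroring the deleted get_extended_attention_mask.
    # padding_mask: [batch_size, seq_length] with 1 = attend, 0 = ignore.
    batch_size, seq_length = padding_mask.shape
    seq_ids = torch.arange(seq_length, device=padding_mask.device)
    # causal[i, j] = 1 iff j <= i, i.e. no attention to future positions.
    causal = (seq_ids[None, None, :] <= seq_ids[None, :, None]).to(dtype)
    causal = causal.expand(batch_size, seq_length, seq_length)
    # Combine with the padding mask and add a head dimension -> [b, 1, s, s].
    return causal[:, None, :, :] * padding_mask[:, None, None, :].to(dtype)

mask = build_causal_mask(torch.ones(2, 5, dtype=torch.long))
print(mask.shape)  # torch.Size([2, 1, 5, 5])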
@@ -78,7 +24,7 @@ class EncoderFinalMixin(BaseMixin):
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None, parallel_output=False, **kwargs):
+    def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
@@ -86,6 +32,7 @@ class EncoderDecoderModel(torch.nn.Module):
         else:
             self.encoder = BaseModel(args, **kwargs)
         self.encoder.add_mixin("final", EncoderFinalMixin())
         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
@@ -100,6 +47,10 @@ class EncoderDecoderModel(torch.nn.Module):
                 setattr(dec_args, name, dec_attr)
             self.decoder = BaseModel(args, is_decoder=True, parallel_output=parallel_output, **kwargs)
+        self.tie_word_embeddings = tie_word_embeddings
+        if tie_word_embeddings:
+            self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings

     def reinit(self):
         self.encoder.reinit()
         self.decoder.reinit()
@@ -108,25 +59,19 @@ class EncoderDecoderModel(torch.nn.Module):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()

-    def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
-                decoder_position_ids=None, decoder_attention_mask=None, encoder_outputs=None,
-                **kw_args):
-        dtype = self.encoder.transformer.word_embeddings.weight.dtype
-        if encoder_outputs is None:
-            batch_size, encoder_seq_length = input_ids.size()[:2]
-        else:
-            batch_size, encoder_seq_length = encoder_outputs.size()[:2]
-        encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype)
-        decoder_seq_length = decoder_input_ids.size(1)
-        if encoder_outputs is None:
-            encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
-        decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype, is_decoder=True)
-        decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
-                                                      encoder_outputs=encoder_outputs,
-                                                      cross_attention_mask=encoder_attention_mask, **kw_args)
-        return encoder_outputs, decoder_outputs, *decoder_mems
+    def encode(self, input_ids, position_ids, attention_mask=None, **kw_args):
+        encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args)
+        return encoder_outputs
+
+    def decode(self, input_ids, position_ids, attention_mask, encoder_outputs, cross_attention_mask=None, **kw_args):
+        # If no context, please explicitly pass ``encoder_outputs=None''
+        return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, enc_attention_mask=None, cross_attention_mask=None, **kw_args):
+        # Please use self.decoder for auto-regressive generation.
+        encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *mems

     @classmethod
     def add_model_specific_args(cls, parser):
...
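For orientation, a hedged usage sketch of the refactored interface; the tensor contents, sizes, and the pre-built 4-D decoder mask are illustrative, and `model` stands for an already-constructed EncoderDecoderModel (e.g. a T5Model):

import torch

batch, enc_len, dec_len = 2, 16, 8
enc_input_ids = torch.randint(0, 1000, (batch, enc_len))
enc_position_ids = torch.arange(enc_len).expand(batch, enc_len)
dec_input_ids = torch.randint(0, 1000, (batch, dec_len))
dec_position_ids = torch.arange(dec_len).expand(batch, dec_len)
# Masks are now the caller's job: a [b, 1, s, s] lower-triangular mask for the decoder.
dec_attention_mask = torch.tril(torch.ones(batch, 1, dec_len, dec_len))

encoder_outputs, decoder_logits, *mems = model(
    enc_input_ids, enc_position_ids,
    dec_input_ids, dec_position_ids, dec_attention_mask,
    enc_attention_mask=None,   # None now means full attention on the encoder side
    cross_attention_mask=None,
)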
@@ -180,8 +180,8 @@ class T5DecoderFinalMixin(BaseMixin):
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm,
-                         activation_func=torch.nn.functional.relu)
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -200,7 +200,6 @@ class T5Model(EncoderDecoderModel):
             "t5-final", T5DecoderFinalMixin(args.hidden_size)
         )
         del self.decoder.transformer.position_embeddings
-        self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings

     @classmethod
     def add_model_specific_args(cls, parser):
...
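The explicit tying line is dropped here because EncoderDecoderModel now performs it whenever tie_word_embeddings=True. A quick sanity check one could run (purely illustrative; `model` stands for any constructed T5Model):

# Tied embeddings share one Parameter: updating the encoder's vocabulary
# matrix also updates the decoder's, and the checkpoint stores it only once.
assert model.decoder.transformer.word_embeddings.weight is \
       model.encoder.transformer.word_embeddings.weight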
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)

-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )

-    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,13 +373,16 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

-        if self.is_decoder and encoder_outputs is not None:
-            # Cross attention
-            attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
-            # Residual connection.
-            layernorm_input = layernorm_input + attention_output
-            # Layer norm post the cross attention
-            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+        if self.is_decoder:
+            encoder_outputs = kw_args['encoder_outputs']
+            if encoder_outputs is not None:
+                cross_attention_mask = kw_args['cross_attention_mask']
+                # Cross attention
+                attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
+                # Residual connection.
+                layernorm_input = layernorm_input + attention_output
+                # Layer norm post the cross attention
+                layernorm_output = self.post_cross_attention_layernorm(layernorm_input)

         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
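The layer now pulls encoder_outputs and cross_attention_mask out of **kw_args instead of receiving them as named parameters, so whatever drives the layers simply forwards them as keyword arguments. A hedged sketch of that calling convention (run_decoder_layers is illustrative, not a repo function):

# Sketch of the implied calling convention: cross-attention inputs travel
# through **kw_args; encoder layers just ignore the extra keys.
def run_decoder_layers(layers, hidden_states, mask, encoder_outputs, cross_attention_mask):
    for layer in layers:
        hidden_states, *_ = layer(hidden_states, mask,
                                  encoder_outputs=encoder_outputs,
                                  cross_attention_mask=cross_attention_mask)
    return hidden_states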
@@ -467,12 +470,15 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None,
                 output_hidden_states=False, **kw_args):
+        breakpoint()
         # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
+        if attention_mask is None:
+            attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
+                next(self.parameters())
+            )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
         assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
@@ -507,34 +513,37 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
+                    x_, mask = inputs[0], inputs[1]
+                    if len(inputs) > 2:  # have branch_input
+                        branch_ = inputs[2]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if branch_input is not None:
-                            x_, encoder_outputs_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=encoder_outputs_, **kw_args
+                        if len(inputs) > 2:
+                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
                             )
                         elif 'layer_forward' in self.hooks:
                             x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
+                                x_, mask, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, encoder_outputs_, **kw_args)
+                            x_, output_this_layer = layer(x_, mask, **kw_args)
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
                 return custom_forward

-            l, num_layers = 0, len(self.layers)
-            chunk_length = self.checkpoint_num_layers
+            # prevent to lose requires_grad in checkpointing.
+            # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
+            l, num_layers = 0, len(self.layers)
+            chunk_length = self.checkpoint_num_layers
             while l < num_layers:
+                args = [hidden_states, attention_mask]
                 if branch_input is not None:
-                    args = [hidden_states, attention_mask, branch_input]
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
+                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
                 else:
-                    args = [hidden_states, attention_mask, encoder_outputs]
                     hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
@@ -542,16 +551,11 @@ class BaseTransformer(torch.nn.Module):
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask, encoder_outputs]
+                args = [hidden_states, attention_mask]
                 if branch_input is not None:  # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args,
-                                                                                                 layer_id=torch.tensor(i),
-                                                                                                 branch_input=branch_input,
-                                                                                                 **kw_args)
+                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
                 elif 'layer_forward' in self.hooks:  # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
-                                                                                   **kw_args)
+                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
                 else:
                     hidden_states, output_this_layer = layer(*args, **kw_args)
                 if output_hidden_states:
...
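The chunked activation-checkpointing pattern is easier to read outside the diff. A self-contained, hedged sketch of the same technique with toy modules (TinyBlock and run_chunked are illustrative stand-ins, not repo code):

import torch
from torch.utils.checkpoint import checkpoint

class TinyBlock(torch.nn.Module):
    """Stand-in for one transformer layer."""
    def __init__(self, hidden):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)

    def forward(self, x, mask):
        return torch.relu(self.linear(x)) * mask

def run_chunked(layers, hidden_states, mask, chunk_length=2, training=True):
    # Mirror of the pattern in BaseTransformer.forward: wrap `chunk_length`
    # layers in one closure so their activations are recomputed during backward.
    def custom(start, end):
        def custom_forward(x_, mask_):
            for layer in layers[start:end]:
                x_ = layer(x_, mask_)
            return x_
        return custom_forward

    if training:
        # checkpoint() needs an input that requires grad, otherwise the
        # recomputed graph is detached (hence the requires_grad_(True) above).
        hidden_states.requires_grad_(True)
    l, num_layers = 0, len(layers)
    while l < num_layers:
        hidden_states = checkpoint(custom(l, l + chunk_length), hidden_states, mask)
        l += chunk_length
    return hidden_states

layers = torch.nn.ModuleList(TinyBlock(8) for _ in range(4))
out = run_chunked(layers, torch.randn(2, 5, 8), torch.ones(2, 5, 1))
out.sum().backward()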