diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index ed25fe48aaa3ef83b6c784429b229d40a22b7020..8caed663b27e4b6c1b4382090ddd754a1c875d3e 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -21,20 +21,22 @@ class CachedAutoregressiveMixin(BaseMixin):
         super().__init__()
 
     @non_conflict
-    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
-        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
-        b, nh, seq_len, hidden_size = k.shape
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+                     **kw_args):
+        if not cross_attention:
+            mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+            b, nh, seq_len, hidden_size = k.shape
 
-        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
-        kw_args['output_this_layer']['mem_kv'] = cache_kv
-
-        if mem is not None: # the first time, mem is None
-            # might change batch_size
-            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
-            memk, memv = mem[0], mem[1]
-            k = torch.cat((memk, k), dim=2)
-            v = torch.cat((memv, v), dim=2)
-        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
+            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+            kw_args['output_this_layer']['mem_kv'] = cache_kv
+
+            if mem is not None: # the first time, mem is None
+                # might change batch_size
+                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+                memk, memv = mem[0], mem[1]
+                k = torch.cat((memk, k), dim=2)
+                v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
 
 
 class CachedAutoregressiveModel(BaseModel):
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index a5f13760351ce27da585f96960493421b4378dc5..d1a50951ea352babcd763a0178ed2de2013695dc 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -3,10 +3,12 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
 
 
 class T5PositionEmbeddingMixin(BaseMixin):
@@ -94,7 +96,7 @@ class T5AttentionMixin(BaseMixin):
         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, cross_attention=False):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
@@ -106,84 +108,88 @@
         )
         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         # shape (query_length, key_length, num_heads)
-        if cross_attention:
-            values = self.cross_relative_attention_bias(relative_position_bucket)
-        else:
-            values = self.relative_attention_bias(relative_position_bucket)
+        values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
-        attn_module = self.transformer.layers[layer_id].attention
-        seq_length = hidden_states.size(1)
-        memory_length = mems[layer_id].size(1) if mems else 0
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        if position_bias is None:
-            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
-                                           log_attention_weights=position_bias, scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        kw_args['output_cross_layer']['position_bias'] = position_bias
-
-        return output
-
-    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
-                                **kw_args):
-        attn_module = self.transformer.layers[layer_id].cross_attention
-        mixed_query_layer = attn_module.query(hidden_states)
-        mixed_x_layer = attn_module.key_value(encoder_outputs)
-        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
-                                           scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [b, s, h]
-        output = attn_module.dense(context_layer)
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        return output
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
+                     cross_attention=False, **kw_args):
+        log_attention_weights = None
+        if not cross_attention:
+            if position_bias is None:
+                seq_length = q.size(2)
+                key_length = k.size(2)
+                position_bias = self.compute_bias(key_length, key_length)
+                position_bias = position_bias[:, :, -seq_length:, :]
+            kw_args['output_cross_layer']['position_bias'] = position_bias
+            log_attention_weights = position_bias
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
+                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
 
 
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
 
 
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -205,10 +211,19 @@
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
-
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
 
     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -246,6 +261,8 @@
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
@@ -254,7 +271,7 @@
         return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
                               cross_attention_mask=cross_attention_mask, **kw_args)
 
-    def forward(self, enc_input_ids, dec_input_ids, dec_attention_mask, *, enc_attention_mask=None,
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
                 cross_attention_mask=None, **kw_args):
         batch_size, seq_length = enc_input_ids.size()[:2]
         if enc_attention_mask is None:
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 98e42a5ac7e4a9dc3050fc677e90dc6d8f11de6b..289056948feff75964127fd5a2e4b0f4e8ed3944 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -218,6 +218,10 @@ class CrossAttention(torch.nn.Module):
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_query_layer = self.query(hidden_states)
             mixed_x_layer = self.key_value(encoder_outputs)
             (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
@@ -228,7 +232,8 @@
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
+                                         cross_attention=True, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             # [b, s, hp]
@@ -394,7 +399,7 @@ class BaseTransformerLayer(torch.nn.Module):
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask = kw_args['cross_attention_mask']
+                assert 'cross_attention_mask' in kw_args
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, **kw_args)
                 # Residual connection.
@@ -504,7 +509,7 @@ class BaseTransformer(torch.nn.Module):
             ) # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-        
+
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -526,7 +531,7 @@
             output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
         else:
             output_cross_layer = {}
-        
+
         output_per_layers = []
         if self.checkpoint_activations:
             # define custom_forward for checkpointing
@@ -534,7 +539,7 @@
             def custom_forward(*inputs):
                 layers_ = self.layers[start:end]
                 x_, mask = inputs[0], inputs[1]
-                
+
                 # recover kw_args and output_cross_layer
                 flat_inputs = inputs[2:]
                 kw_args, output_cross_layer = {}, {}
@@ -543,19 +548,19 @@
                 for k, idx in cross_layer_index.items():
                     output_cross_layer[k] = flat_inputs[idx]
                 # -----------------
-                
+
                 output_per_layers_part = []
                 for i, layer in enumerate(layers_):
                     if 'layer_forward' in self.hooks:
                         layer_ret = self.hooks['layer_forward'](
-                            x_, mask, layer_id=layer.layer_id, 
-                            **kw_args, **output_cross_layer, 
+                            x_, mask, layer_id=layer.layer_id,
+                            **kw_args, **output_cross_layer,
                             output_this_layer={}, output_cross_layer={}
                         )
                     else:
                         layer_ret = layer(
-                            x_, mask, layer_id=layer.layer_id, 
-                            **kw_args, **output_cross_layer, 
+                            x_, mask, layer_id=layer.layer_id,
+                            **kw_args, **output_cross_layer,
                             output_this_layer={}, output_cross_layer={}
                         )
                     if torch.is_tensor(layer_ret): # only hidden_states
@@ -563,7 +568,7 @@
                     elif len(layer_ret) == 2: # hidden_states & output_this_layer
                         x_, output_this_layer = layer_ret
                         output_cross_layer = {}
-                    elif len(layer_ret) == 3: 
+                    elif len(layer_ret) == 3:
                         x_, output_this_layer, output_cross_layer = layer_ret
                     assert isinstance(output_this_layer, dict)
                     assert isinstance(output_cross_layer, dict)
@@ -590,7 +595,7 @@
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
-            
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             output_this_layer = []
@@ -625,20 +630,20 @@
                 args = [hidden_states, attention_mask]
                 if 'layer_forward' in self.hooks: # customized layer_forward
-                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), 
-                                                            **kw_args, 
-                                                            **output_cross_layer, 
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                                                            **kw_args,
+                                                            **output_cross_layer,
                                                             output_this_layer={}, output_cross_layer={}
                     )
                 else:
-                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, 
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
                                       output_this_layer={}, output_cross_layer={})
                 if torch.is_tensor(layer_ret): # only hidden_states
                     hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
                 elif len(layer_ret) == 2: # hidden_states & output_this_layer
                     hidden_states, output_this_layer = layer_ret
                     output_cross_layer = {}
-                elif len(layer_ret) == 3: 
+                elif len(layer_ret) == 3:
                     hidden_states, output_this_layer, output_cross_layer = layer_ret
 
                 if output_hidden_states:
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
index a91956cbb69f31c079a06fc97e35ca492a40605c..41b9d2ca84320ab87a6e0eaa3c3b8912b755e014 100644
--- a/examples/t5/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -1,7 +1,12 @@
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
-tokenizer = T5Tokenizer.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-breakpoint()
\ No newline at end of file
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1
\ No newline at end of file