From 308e87b02bbe27c83883272de2ec9199880d0fd3 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sun, 12 Dec 2021 22:11:28 +0800
Subject: [PATCH] Add support for T5 v1.1

---
 .../generation/autoregressive_sampling.py |  1 +
 .../model/cached_autoregressive_model.py  |  2 +-
 SwissArmyTransformer/model/t5_model.py    | 73 +++++++++++++++++--
 SwissArmyTransformer/mpu/transformer.py   |  2 +-
 .../tokenization/hf_tokenizer.py          |  2 +-
 5 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py
index 4eada99..8fe6909 100644
--- a/SwissArmyTransformer/generation/autoregressive_sampling.py
+++ b/SwissArmyTransformer/generation/autoregressive_sampling.py
@@ -108,6 +108,7 @@ def filling_sequence(
             log_attention_weights=log_attention_weights_part,
             **kw_args
         )
+        mem_kv = [item[0] for item in mem_kv]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index a0e1699..ea7cbfa 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -20,7 +20,7 @@ class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self):
         super().__init__()
 
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
+    def attention_forward(self, hidden_states, mask, *args, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
         attn_module = self.transformer.layers[layer_id].attention
         mem = mems[layer_id] if mems is not None else None
 
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 7143950..03c3412 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -6,7 +6,8 @@ from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
 
 
 class T5PositionEmbeddingMixin(BaseMixin):
@@ -171,17 +172,67 @@ class T5AttentionMixin(BaseMixin):
 
 
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
 
 
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -203,9 +254,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
 
     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -243,6 +304,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index cab65a4..9d1b3a9 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -153,7 +153,7 @@ class SelfAttention(torch.nn.Module):
 
         if self.training:
             output = self.output_dropout(output)
-        return output
+        return output, None
 
 
 class CrossAttention(torch.nn.Module):
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
index e40adb4..790b60d 100644
--- a/SwissArmyTransformer/tokenization/hf_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -5,7 +5,7 @@ from .glm.tokenization import Tokenization, CommandToken
 PRETRAINED_VOCAB_FILES_MAP = {
     "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small",
     "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base",
-    "t5-large": "/mnt/t5",
+    "t5-large": "/dataset/fd5061f6/yanan/huggingface_models/t5-large",
     "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b",
     "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b"
 }
-- 
GitLab
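
The sketch below is a minimal, standalone illustration of the two T5 v1.1 behaviours this patch wires in: the gated-GELU feed-forward enabled by --gated-gelu-mlp and the untied LM head enabled by --no-share-embeddings. It is not part of the patch; plain torch.nn.Linear stands in for the repo's model-parallel ColumnParallelLinear / VocabParallelEmbedding, and the class name GatedGeluMLP and the tensor sizes are illustrative assumptions.

    # Minimal sketch (assumptions noted above), not the repo's implementation.
    import math
    import torch
    import torch.nn as nn

    def t5_gelu(x):
        # Same tanh-approximation GELU used by t5_gelu in the patch.
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    class GatedGeluMLP(nn.Module):
        # Gated-GELU feed-forward: two parallel up-projections without bias,
        # combined elementwise, then projected back down (as in T5 v1.1).
        def __init__(self, hidden_size, inner_hidden_size):
            super().__init__()
            self.dense_h_to_4h = nn.Linear(hidden_size, inner_hidden_size, bias=False)  # GELU branch
            self.gated_h_to_4h = nn.Linear(hidden_size, inner_hidden_size, bias=False)  # linear gate branch
            self.dense_4h_to_h = nn.Linear(inner_hidden_size, hidden_size, bias=False)

        def forward(self, hidden_states):
            gated = t5_gelu(self.dense_h_to_4h(hidden_states)) * self.gated_h_to_4h(hidden_states)
            return self.dense_4h_to_h(gated)

    hidden, inner, vocab = 512, 1024, 32128  # illustrative sizes
    mlp = GatedGeluMLP(hidden, inner)
    lm_head = nn.Linear(hidden, vocab, bias=False)  # untied head: no 1/sqrt(d) scaling before the projection

    x = torch.randn(2, 8, hidden)
    logits = lm_head(mlp(x))
    print(logits.shape)  # torch.Size([2, 8, 32128])

With tied embeddings (the default, i.e. without --no-share-embeddings), T5DecoderFinalMixin instead scales the decoder output by hidden_size ** -0.5 and projects it through the shared word-embedding matrix, as in the original T5.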