diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index cfe97ca6c120316d44c79d988f348ef78450f26e..ddcc8fdeb8f9368d55a4df59f4dd589b4ab0e024 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -6,7 +6,8 @@ from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
 
 
 class T5PositionEmbeddingMixin(BaseMixin):
@@ -142,7 +143,7 @@ class T5AttentionMixin(BaseMixin):
 
         kw_args['output_cross_layer']['position_bias'] = position_bias
 
-        return output
+        return output
 
     def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
                                 **kw_args):
@@ -173,17 +174,67 @@ class T5AttentionMixin(BaseMixin):
 
 
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
 
 
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -205,10 +256,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
-
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
 
     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -246,6 +306,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 3b4ae73cc6609178af67564198be9b15038ab9d2..d9d3181abbdd5ab288d1e3c81f7d8be31fd59969 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -158,7 +158,7 @@ class SelfAttention(torch.nn.Module):
         if self.training:
             output = self.output_dropout(output)
 
-        return output
+        return output, None
 
 
 class CrossAttention(torch.nn.Module):
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
index a91956cbb69f31c079a06fc97e35ca492a40605c..41b9d2ca84320ab87a6e0eaa3c3b8912b755e014 100644
--- a/examples/t5/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -1,7 +1,12 @@
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
-tokenizer = T5Tokenizer.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-breakpoint()
\ No newline at end of file
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1
\ No newline at end of file
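
Note for reviewers: the feed-forward enabled by --gated-gelu-mlp is the T5 v1.1 gated-GELU variant (used by checkpoints such as t5-xl-lm-adapt in the test script), where the single projection-plus-activation is replaced by an element-wise product of a GELU branch and a linear gate; T5GatedGeluMLPMixin reuses the existing dense_h_to_4h/dense_4h_to_h weights and only adds the gate projection. A minimal single-device sketch of the same computation, without the ColumnParallelLinear model-parallel plumbing (class and argument names here are illustrative, not part of the patch):

import math
import torch

def t5_gelu(x):
    # tanh approximation of GELU, same formula as in the patch above
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class GatedGeluMLP(torch.nn.Module):
    # out = W_out(gelu(W_in x) * (W_gate x)); T5 v1.1 uses no biases, matching bias=False in the patch
    def __init__(self, hidden_size, inner_hidden_size, dropout_p=0.1):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, inner_hidden_size, bias=False)  # W_in (existing weight)
        self.gated_h_to_4h = torch.nn.Linear(hidden_size, inner_hidden_size, bias=False)  # W_gate (the new weight)
        self.dense_4h_to_h = torch.nn.Linear(inner_hidden_size, hidden_size, bias=False)  # W_out (existing weight)
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, hidden_states):
        hidden_gelu = t5_gelu(self.dense_h_to_4h(hidden_states))
        hidden_linear = self.gated_h_to_4h(hidden_states)
        output = self.dense_4h_to_h(hidden_gelu * hidden_linear)
        # mirror the patch's explicit training-only dropout guard
        return self.dropout(output) if self.training else output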
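
Likewise, the tie_word_embeddings switch in T5DecoderFinalMixin separates two output-projection conventions: original T5 reuses the input embedding matrix and rescales the decoder output by hidden_size ** -0.5 before the projection, while v1.1-style checkpoints (selected here via --no-share-embeddings) use a separate, unscaled lm_head. A sketch of the two paths with the model-parallel copy stripped out (the function name is illustrative):

import torch.nn.functional as F

def final_logits(hidden_states, word_embedding_weight, lm_head_weight=None, tie_word_embeddings=True):
    # hidden_states: [batch, seq, hidden]; both weight matrices: [vocab, hidden]
    if tie_word_embeddings:
        # tied path: rescale, then project against the shared input embeddings
        hidden_states = hidden_states * (hidden_states.shape[-1] ** -0.5)
        return F.linear(hidden_states, word_embedding_weight)
    # untied path: independent lm_head, no rescaling
    return F.linear(hidden_states, lm_head_weight)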