From 308e87b02bbe27c83883272de2ec9199880d0fd3 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sun, 12 Dec 2021 22:11:28 +0800
Subject: [PATCH] Add support for T5 v1.1

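T5 v1.1 differs from the original T5 in its feed-forward block, which
uses a gated-GELU (GEGLU) projection instead of the standard ReLU MLP,
and in its output projection, which is no longer tied to the input word
embeddings. Both are exposed as opt-in arguments, --gated-gelu-mlp and
--no-share-embeddings, implemented via a new T5GatedGeluMLPMixin and an
extended T5DecoderFinalMixin, so existing T5 configurations keep their
current behaviour.

Supporting changes: SelfAttention now returns a 2-tuple instead of a
bare output tensor, so filling_sequence unpacks the first element
before updating mems and CachedAutoregressiveMixin.attention_forward
accepts extra positional arguments; the t5-large tokenizer entry now
points at the shared huggingface_models directory like the other sizes.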
---
 .../generation/autoregressive_sampling.py     |  1 +
 .../model/cached_autoregressive_model.py      |  2 +-
 SwissArmyTransformer/model/t5_model.py        | 73 +++++++++++++++++--
 SwissArmyTransformer/mpu/transformer.py       |  2 +-
 .../tokenization/hf_tokenizer.py              |  2 +-
 5 files changed, 72 insertions(+), 8 deletions(-)
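
Note: below is a minimal standalone sketch of the gated-GELU
feed-forward that T5GatedGeluMLPMixin implements, with illustrative
weight names (w_in, w_gate and w_out stand in for dense_h_to_4h,
gated_h_to_4h and dense_4h_to_h; model-parallel layers, bias handling
and dropout are omitted):

    import math
    import torch

    def t5_gelu(x):
        # tanh approximation of GELU, same formula as in this patch
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                           * (x + 0.044715 * torch.pow(x, 3.0))))

    def gated_gelu_ffn(x, w_in, w_gate, w_out):
        # GEGLU: gelu(x @ W_in) multiplied elementwise with (x @ W_gate),
        # then projected back with W_out
        return (t5_gelu(x @ w_in) * (x @ w_gate)) @ w_out

    # Example shapes: hidden size 8, inner hidden size 16
    x = torch.randn(2, 8)
    w_in, w_gate, w_out = torch.randn(8, 16), torch.randn(8, 16), torch.randn(16, 8)
    out = gated_gelu_ffn(x, w_in, w_gate, w_out)  # shape (2, 8)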

diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py
index 4eada99..8fe6909 100644
--- a/SwissArmyTransformer/generation/autoregressive_sampling.py
+++ b/SwissArmyTransformer/generation/autoregressive_sampling.py
@@ -108,6 +108,7 @@ def filling_sequence(
             log_attention_weights=log_attention_weights_part,
             **kw_args
         )
+        mem_kv = [item[0] for item in mem_kv]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index a0e1699..ea7cbfa 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -20,7 +20,7 @@ class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self):
         super().__init__()
         
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
+    def attention_forward(self, hidden_states, mask, *args, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
         attn_module = self.transformer.layers[layer_id].attention
         mem = mems[layer_id] if mems is not None else None
         
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 7143950..03c3412 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -6,7 +6,8 @@ from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
 
 
 class T5PositionEmbeddingMixin(BaseMixin):
@@ -171,17 +172,67 @@ class T5AttentionMixin(BaseMixin):
 
 
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
 
 
+def t5_gelu(x):
+    """
+    Tanh approximation of the GELU activation (the "gelu_new" variant from the Google BERT repo, identical to
+    OpenAI GPT). Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -203,9 +254,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
 
     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -243,6 +304,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index cab65a4..9d1b3a9 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -153,7 +153,7 @@ class SelfAttention(torch.nn.Module):
             if self.training:
                 output = self.output_dropout(output)
 
-            return output
+            return output, None
 
 
 class CrossAttention(torch.nn.Module):
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
index e40adb4..790b60d 100644
--- a/SwissArmyTransformer/tokenization/hf_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -5,7 +5,7 @@ from .glm.tokenization import Tokenization, CommandToken
 PRETRAINED_VOCAB_FILES_MAP = {
     "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small",
     "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base",
-    "t5-large": "/mnt/t5",
+    "t5-large": "/dataset/fd5061f6/yanan/huggingface_models/t5-large",
     "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b",
     "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b"
 }
-- 
GitLab