From cc852e00817460b110ff5b3639e5c56b0b72ba68 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Thu, 2 Dec 2021 21:16:24 +0800
Subject: [PATCH] Implement T5 model

---
 SwissArmyTransformer/arguments.py             |   2 +
 SwissArmyTransformer/model/__init__.py        |   3 +-
 SwissArmyTransformer/model/base_model.py      |  31 +--
 .../model/encoder_decoder_model.py            | 192 ++++++++--------
 SwissArmyTransformer/model/t5_model.py        | 190 ++++++++++++++++
 SwissArmyTransformer/mpu/transformer.py       | 212 ++++++++++++++----
 6 files changed, 474 insertions(+), 156 deletions(-)
 create mode 100644 SwissArmyTransformer/model/t5_model.py

diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index dfa75c9..97e95ee 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -33,6 +33,8 @@ def add_model_config_args(parser):
                        help='num of transformer attention heads')
     group.add_argument('--hidden-size', type=int, default=1024,
                        help='tansformer hidden size')
+    group.add_argument('--inner-hidden-size', type=int, default=None)
+    group.add_argument('--hidden-size-per-attention-head', type=int, default=None)
     group.add_argument('--num-layers', type=int, default=24,
                        help='num decoder layers')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
diff --git a/SwissArmyTransformer/model/__init__.py b/SwissArmyTransformer/model/__init__.py
index 32f46e4..4fbcd53 100755
--- a/SwissArmyTransformer/model/__init__.py
+++ b/SwissArmyTransformer/model/__init__.py
@@ -2,4 +2,5 @@ from .base_model import BaseModel
 from .cached_autoregressive_model import CachedAutoregressiveModel
 from .cuda2d_model import Cuda2dModel
 from .glm_model import GLMModel
-from .encoder_decoder_model import EncoderDecoderModel
\ No newline at end of file
+from .encoder_decoder_model import EncoderDecoderModel
+from .t5_model import T5Model
diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index c9e1c90..5e5793e 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -13,20 +13,23 @@ import math
 import random
 import torch
 
-from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu import BaseTransformer, LayerNorm
+
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
         super(BaseMixin, self).__init__()
         # define new params
+
     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
     # can also define hook-functions here
     # ...
 
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None, parallel_output=True):
+    def __init__(self, args, transformer=None, **kwargs):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -42,14 +45,16 @@ class BaseModel(torch.nn.Module):
                 embedding_dropout_prob=args.hidden_dropout,
                 attention_dropout_prob=args.attention_dropout,
                 output_dropout_prob=args.hidden_dropout,
+                inner_hidden_size=args.inner_hidden_size,
+                hidden_size_per_attention_head=args.hidden_size_per_attention_head,
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=parallel_output,
-                hooks=self.hooks
+                hooks=self.hooks,
+                **kwargs
             )
 
-    def reinit(self): # will be called when loading model
+    def reinit(self):  # will be called when loading model
         # if some mixins are loaded, overrides this function
         for m in self.mixins.values():
             m.reinit(self.transformer)
@@ -58,11 +63,11 @@ class BaseModel(torch.nn.Module):
         assert name not in self.mixins
         assert isinstance(new_mixin, BaseMixin)
 
-        self.mixins[name] = new_mixin # will auto-register parameters
-        object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
+        self.mixins[name] = new_mixin  # will auto-register parameters
+        object.__setattr__(new_mixin, 'transformer', self.transformer)  # cannot use pytorch set_attr
 
         if reinit:
-            new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
+            new_mixin.reinit(self.transformer, **self.mixins)  # also pass current mixins
         self.collect_hooks_()
 
     def del_mixin(self, name):
@@ -82,15 +87,15 @@ class BaseModel(torch.nn.Module):
 
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
-                'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                'branch_embedding_forward', 'branch_final_forward'
-                ]
+                 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
+                 'branch_embedding_forward', 'branch_final_forward'
+                 ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
-                    if name in hooks: # conflict
+                    if name in hooks:  # conflict
                         raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                     hooks[name] = getattr(m, name)
                     hook_origins[name] = mixin_name
@@ -104,4 +109,4 @@ class BaseModel(torch.nn.Module):
         return hooks
 
     def disable_untrainable_params(self):
-        pass
\ No newline at end of file
+        pass
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 7e868f2..1beace8 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -14,108 +14,90 @@ import random
 import torch
 import argparse
 from .base_model import BaseModel, BaseMixin
-from .common_layers import CrossAttention, LayerNorm
-
-
-class CrossAttentionMixin(BaseMixin):
-    def __init__(self, num_layers, hidden_size, num_attention_heads,
-                 attention_dropout_prob, output_dropout_prob,
-                 init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
-        super().__init__()
-
-        self.cross_attentions = torch.nn.ModuleList(
-            [CrossAttention(
-                hidden_size, num_attention_heads,
-                attention_dropout_prob, output_dropout_prob,
-                init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
-                output_layer_init_method=output_layer_init_method
-            ) for layer_id in range(num_layers)]
-        )  # Just copy args
-        self.cross_lns = torch.nn.ModuleList(
-            [LayerNorm(hidden_size, 1e-5)
-             for layer_id in range(num_layers)]
-        )
-
-    def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
-        layer = self.transformer.layers[layer_id]
-        encoder_outputs = kw_args['encoder_outputs']
-        '''
-            hidden_states: [batch, seq_len, hidden_size]
-            mask: [(1, 1), seq_len, seq_len]
-            encoder_outputs: [batch, enc_seq_len, enc_hidden_size]
-        '''
-        # Layer norm at the begining of the transformer layer.
-        layernorm_output = layer.input_layernorm(hidden_states)
-        attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args)
-        # Third LayerNorm
-        if layer.sandwich_ln:
-            attention_output = layer.third_layernorm(attention_output)
-        # Residual connection.
-        hidden_states = hidden_states + attention_output
-
-        # Cross attention.
-        layernorm_output = self.cross_lns[layer_id](hidden_states)
-        cross_attn_output = self.cross_attentions[layer_id](
-            layernorm_output,
-            torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
-            encoder_outputs
+from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
+from .common_layers import LayerNorm
+
+
+def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32, is_decoder=False):
+    """
+    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+    Arguments:
+        attention_mask (:obj:`torch.Tensor`):
+            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+        input_shape (:obj:`Tuple[int]`):
+            The shape of the input to the model.
+        device: (:obj:`torch.device`):
+            The device of the input to the model.
+        dtype:
+        is_decoder:
+
+    Returns:
+        :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+    """
+    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+    # ourselves in which case we just need to make it broadcastable to all heads.
+    if attention_mask is None or attention_mask.dim() == 2:
+        batch_size, seq_length = input_shape
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if is_decoder:
+            seq_ids = torch.arange(seq_length, device=device)
+            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+            # causal and attention masks must have same type with pytorch version < 1.3
+            causal_mask = causal_mask.to(dtype)
+            if attention_mask is not None:
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    causal_mask = torch.cat(
+                        [torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                         causal_mask], axis=-1)
+
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        else:
+            if attention_mask is None:
+                extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+    elif attention_mask.dim() == 3:
+        extended_attention_mask = attention_mask[:, None, :, :]
+    else:
+        raise ValueError(
+            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
         )
-        hidden_states = hidden_states + cross_attn_output
-
-        # Layer norm post the layer attention.
-        layernorm_output = layer.post_attention_layernorm(hidden_states)
-        # MLP.
-        mlp_output = layer.mlp(layernorm_output, **kw_args)
-
-        # Fourth LayerNorm
-        if layer.sandwich_ln:
-            mlp_output = layer.fourth_layernorm(mlp_output)
-        output = hidden_states + mlp_output
+    return extended_attention_mask
 
-        return output, output_this_layer
 
-
-class DecoderModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        dec_args = argparse.Namespace(**vars(args))
-        dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
-        override_attrs = ['num_layers', 'vocab_size',
-                          'hidden_size', 'num_attention_heads',
-                          'max_sequence_length', 'sandwich_ln'  # TODO
-                          ]
-        for name in override_attrs:
-            dec_attr = getattr(dec_args, 'dec_' + name, None)
-            if dec_attr is not None:  # else use encoder-config
-                setattr(dec_args, name, dec_attr)
-
-        super().__init__(dec_args, transformer=transformer)
-        self.add_mixin('cross_attention',
-                       CrossAttentionMixin(
-                           dec_args.num_layers,
-                           dec_args.hidden_size, dec_args.num_attention_heads,
-                           dec_args.attention_dropout, dec_args.hidden_dropout,
-                           self.transformer.init_method,
-                           enc_hidden_size=dec_args.enc_hidden_size,
-                           inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None),
-                           output_layer_init_method=self.transformer.output_layer_init_method
-                       )
-                       )
+class EncoderFinalMixin(BaseMixin):
+    def final_forward(self, logits, **kwargs):
+        logits = copy_to_model_parallel_region(logits)
+        return logits
 
 
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None):
+    def __init__(self, args, encoder=None, decoder=None, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
             self.encoder = encoder
         else:
-            self.encoder = BaseModel(args)
-
+            self.encoder = BaseModel(args, **kwargs)
+        self.encoder.add_mixin("final", EncoderFinalMixin())
         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
         else:
-            self.decoder = DecoderModel(args)
+            dec_args = argparse.Namespace(**vars(args))
+            dec_args.enc_hidden_size = dec_args.hidden_size  # used for cross attn
+            override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads',
+                              'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
+            for name in override_attrs:
+                dec_attr = getattr(dec_args, 'dec_' + name, None)
+                if dec_attr is not None:  # else use encoder-config
+                    setattr(dec_args, name, dec_attr)
+            self.decoder = BaseModel(args, is_decoder=True, parallel_output=parallel_output, **kwargs)
 
     def reinit(self):
         self.encoder.reinit()
@@ -125,23 +107,29 @@ class EncoderDecoderModel(torch.nn.Module):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()
 
-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *,
-                branch_input=None, **kw_args):
-        mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype)
-        enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input,
-                                            **kw_args)
-        dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask,
-                                              encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args)
-        return enc_outputs, dec_outputs, *dec_mems
+    def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
+                decoder_position_ids=None, decoder_attention_mask=None,
+                **kw_args):
+        dtype = self.encoder.transformer.word_embeddings.weight.dtype
+        batch_size, encoder_seq_length = input_ids.size()[:2]
+        encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
+                                                             device=input_ids.device, dtype=dtype)
+        decoder_seq_length = decoder_input_ids.size(1)
+        encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
+        decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
+                                                             device=input_ids.device, dtype=dtype)
+        decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
+                                                      encoder_outputs=encoder_outputs,
+                                                      cross_attention_mask=encoder_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *decoder_mems
 
     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart')
-        group.add_argument("--dec_num_layers", type=int, default=None)
-        group.add_argument("--dec_vocab_size", type=int, default=None)
-        group.add_argument("--dec_hidden_size", type=int, default=None)
-        group.add_argument("--dec_num_attention_heads", type=int, default=None)
-        group.add_argument("--dec_max_sequence_length", type=int, default=None)
-        group.add_argument("--dec_sandwich_ln", action='store_true')
-        group.add_argument("--dec_inner_hidden_size", type=int, default=None)
+        group.add_argument("--dec-num-layers", type=int, default=None)
+        group.add_argument("--dec-hidden-size", type=int, default=None)
+        group.add_argument("--dec-num-attention-heads", type=int, default=None)
+        group.add_argument("--dec-max-sequence-length", type=int, default=None)
+        group.add_argument("--dec-inner-hidden-size", type=int, default=None)
+        group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None)
         return parser
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
new file mode 100644
index 0000000..bf26a3b
--- /dev/null
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -0,0 +1,190 @@
+import math
+import torch
+from .mixins import BaseMixin
+from .encoder_decoder_model import EncoderDecoderModel
+from SwissArmyTransformer.mpu import get_model_parallel_world_size
+from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+
+
+class T5PositionEmbeddingMixin(BaseMixin):
+    def position_embedding_forward(self, position_ids, **kw_args):
+        return None
+
+
+class T5LayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # layer norm should always be calculated in float32
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into float16 if necessary
+        if self.weight.dtype == torch.float16:
+            hidden_states = hidden_states.to(torch.float16)
+        return self.weight * hidden_states
+
+
+class T5AttentionMixin(BaseMixin):
+    def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False):
+        super().__init__()
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        world_size = get_model_parallel_world_size()
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets,
+                                                          self.num_attention_heads_per_partition)
+        self.is_decoder = is_decoder
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+        This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_postion_if_large = max_exact + (
+                torch.log(relative_position.float() / max_exact)
+                / math.log(max_distance / max_exact)
+                * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_postion_if_large = torch.min(
+            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length, cross_attention=False):
+        """Compute binned relative position bias"""
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_num_buckets,
+        )
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        # shape (query_length, key_length, num_heads)
+        if cross_attention:
+            values = self.cross_relative_attention_bias(relative_position_bucket)
+        else:
+            values = self.relative_attention_bias(relative_position_bucket)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
+    def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args):
+        attn_module = self.transformer.layers[layer_id].attention
+        seq_length = hidden_states.size(1)
+        memory_length = mems[layer_id].size(1) if mems else 0
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        (mixed_query_layer,
+         mixed_key_layer,
+         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+        dropout_fn = attn_module.attention_dropout if attn_module.training else None
+
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+
+        position_bias = self.compute_bias(seq_length, memory_length + seq_length)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
+                                           log_attention_weights=position_bias, scaling_attention_score=False)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        output = attn_module.dense(context_layer)
+
+        if attn_module.training:
+            output = attn_module.output_dropout(output)
+
+        return output, None
+
+    def cross_attention_forward(self, hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args):
+        attn_module = self.transformer.layers[layer_id].cross_attention
+        mixed_query_layer = attn_module.query(hidden_states)
+        mixed_x_layer = attn_module.key_value(encoder_outputs)
+        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
+
+        dropout_fn = attn_module.attention_dropout if attn_module.training else None
+        # Reshape and transpose [b, np, s, hn]
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+
+        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn,
+                                           scaling_attention_score=False)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        # [b, s, hp]
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # Output. [b, s, h]
+        output = attn_module.dense(context_layer)
+        if attn_module.training:
+            output = attn_module.output_dropout(output)
+
+        return output
+
+
+class T5Model(EncoderDecoderModel):
+    def __init__(self, args, **kwargs):
+        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm)
+        self.encoder.add_mixin(
+            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
+        )
+        self.encoder.add_mixin(
+            "t5-position", T5PositionEmbeddingMixin()
+        )
+        del self.encoder.transformer.position_embeddings
+        num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None else args.num_attention_heads
+        self.decoder.add_mixin(
+            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads, is_decoder=True)
+        )
+        self.decoder.add_mixin(
+            "t5-position", T5PositionEmbeddingMixin()
+        )
+        del self.decoder.transformer.position_embeddings
+        self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
+
+    @classmethod
+    def add_model_specific_args(cls, parser):
+        super().add_model_specific_args(parser)
+        parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 5dfac59..7a3234b 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -44,14 +44,13 @@ class LayerNorm(FusedLayerNorm):
 
 
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                       attention_dropout=None, log_attention_weights=None):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. 
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. 
 
-    attention_scores = torch.matmul(
-        query_layer / math.sqrt(query_layer.shape[-1]),
-        key_layer.transpose(-1, -2)
-    )
+    if scaling_attention_score:
+        query_layer / math.sqrt(query_layer.shape[-1])
+    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
 
@@ -73,7 +72,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
 class SelfAttention(torch.nn.Module):
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, layer_id, output_layer_init_method=None,
+                 init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True,
                  hooks={}):
         super(SelfAttention, self).__init__()
         # Set output layer initialization if not provided.
@@ -83,25 +82,31 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
-        self.hidden_size_per_partition = divide(hidden_size, world_size)
-        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3 * hidden_size,
+            3 * inner_hidden_size,
             stride=3,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
 
         self.dense = RowParallelLinear(
-            hidden_size,
+            inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -115,7 +120,7 @@ class SelfAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
             return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
@@ -142,9 +147,91 @@ class SelfAttention(torch.nn.Module):
             return output, None
 
 
+class CrossAttention(torch.nn.Module):
+    """Parallel cross-attention layer for Transformer"""
+
+    def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
+                 layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, hooks={}):
+        super().__init__()
+        # Set output layer initialization if not provided.
+        if output_layer_init_method is None:
+            output_layer_init_method = init_method
+        self.hooks = hooks
+        self.layer_id = layer_id
+        # Per attention head and per partition values.
+        world_size = get_model_parallel_world_size()
+        if hidden_size_per_attention_head is None:
+            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
+        else:
+            self.hidden_size_per_attention_head = hidden_size_per_attention_head
+        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
+        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
+        # Strided linear layer.
+        self.query = ColumnParallelLinear(hidden_size, inner_hidden_size,
+                                          gather_output=False,
+                                          init_method=init_method, bias=bias)
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size,
+                                              stride=2,
+                                              gather_output=False,
+                                              init_method=init_method, bias=bias)
+        # Dropout. Note that for a single iteration, this layer will generate
+        # different outputs on different number of parallel partitions but
+        # on average it should not be partition dependent.
+        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+
+        # Output.
+        self.dense = RowParallelLinear(
+            inner_hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            init_method=output_layer_init_method, bias=bias)
+        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
+
+    def _transpose_for_scores(self, tensor):
+        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
+        size [b, np, s, hn].
+        """
+        new_tensor_shape = tensor.size()[:-1] + \
+                           (self.num_attention_heads_per_partition,
+                            self.hidden_size_per_attention_head)
+        tensor = tensor.view(*new_tensor_shape)
+        return tensor.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, encoder_outputs, *args, cross_attention_mask=None, **kw_args):
+        # hidden_states: [b, s, h]
+        if 'cross_attention_forward' in self.hooks:
+            return self.hooks['cross_attention_forward'](hidden_states, encoder_outputs,
+                                                         cross_attention_mask=cross_attention_mask, **kw_args,
+                                                         layer_id=self.layer_id)
+        else:
+            mixed_query_layer = self.query(hidden_states)
+            mixed_x_layer = self.key_value(encoder_outputs)
+            (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
+
+            dropout_fn = self.attention_dropout if self.training else None
+            # Reshape and transpose [b, np, s, hn]
+            query_layer = self._transpose_for_scores(mixed_query_layer)
+            key_layer = self._transpose_for_scores(mixed_key_layer)
+            value_layer = self._transpose_for_scores(mixed_value_layer)
+
+            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+            # [b, s, hp]
+            context_layer = context_layer.view(*new_context_layer_shape)
+
+            # Output. [b, s, h]
+            output = self.dense(context_layer)
+            if self.training:
+                output = self.output_dropout(output)
+
+            return output
+
+
 class MLP(torch.nn.Module):
-    def __init__(self, hidden_size, output_dropout_prob, init_method,
-                 output_layer_init_method=None, layer_id=None, hooks={}):
+    def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
+                 output_layer_init_method=None, layer_id=None, hooks={}, bias=True):
         super(MLP, self).__init__()
         self.layer_id = layer_id
         # Set output layer initialization if not provided.
@@ -152,18 +239,22 @@ class MLP(torch.nn.Module):
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
             hidden_size,
-            4 * hidden_size,
+            inner_hidden_size,
             gather_output=False,
-            init_method=init_method
+            init_method=init_method,
+            bias=bias
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            4 * hidden_size,
+            inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method
+            init_method=output_layer_init_method,
+            bias=bias
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
 
@@ -190,8 +281,13 @@ class BaseTransformerLayer(torch.nn.Module):
             layernorm_epsilon,
             init_method,
             layer_id,
+            inner_hidden_size=None,
+            hidden_size_per_attention_head=None,
             output_layer_init_method=None,
             sandwich_ln=True,
+            layernorm=LayerNorm,
+            is_decoder=False,
+            use_bias=True,
             hooks={}
     ):
         super(BaseTransformerLayer, self).__init__()
@@ -199,10 +295,11 @@ class BaseTransformerLayer(torch.nn.Module):
         if output_layer_init_method is None:
             output_layer_init_method = init_method
         self.layer_id = layer_id
+        self.is_decoder = is_decoder
         self.hooks = hooks
 
         # Layernorm on the input data.
-        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
         # Self attention.
         self.attention = SelfAttention(
@@ -212,28 +309,48 @@ class BaseTransformerLayer(torch.nn.Module):
             output_dropout_prob,
             init_method,
             layer_id,
+            hidden_size_per_attention_head=hidden_size_per_attention_head,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             hooks=hooks
         )
 
         # Layernorm on the input data.
-        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
         self.sandwich_ln = sandwich_ln
         if sandwich_ln:
-            self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
-            self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+            self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+            self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+        # Cross attention.
+        if self.is_decoder:
+            self.cross_attention = CrossAttention(
+                hidden_size,
+                num_attention_heads,
+                attention_dropout_prob,
+                output_dropout_prob,
+                init_method,
+                layer_id,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
+                output_layer_init_method=output_layer_init_method,
+                bias=use_bias,
+                hooks=hooks
+            )
+            self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
         # MLP
         self.mlp = MLP(
             hidden_size,
             output_dropout_prob,
             init_method,
+            inner_hidden_size=inner_hidden_size,
             output_layer_init_method=output_layer_init_method,
+            bias=use_bias,
             layer_id=layer_id,
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -252,6 +369,15 @@ class BaseTransformerLayer(torch.nn.Module):
         layernorm_input = hidden_states + attention_output
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        if self.is_decoder and encoder_outputs is not None:
+            # Cross attention
+            attention_output = self.cross_attention(layernorm_output, encoder_outputs, **kw_args)
+            # Residual connection.
+            layernorm_input = layernorm_output + attention_output
+            # Layer norm post the cross attention
+            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
 
@@ -279,13 +405,19 @@ class BaseTransformer(torch.nn.Module):
                  checkpoint_num_layers=1,
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
+                 inner_hidden_size=None,
+                 hidden_size_per_attention_head=None,
                  sandwich_ln=True,
                  parallel_output=True,
+                 is_decoder=False,
+                 use_bias=True,
+                 layernorm=LayerNorm,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
 
         # recording parameters
+        self.is_decoder = is_decoder
         self.parallel_output = parallel_output
         self.checkpoint_activations = checkpoint_activations
         self.checkpoint_num_layers = checkpoint_num_layers
@@ -314,8 +446,13 @@ class BaseTransformer(torch.nn.Module):
                 layernorm_epsilon,
                 self.init_method,
                 layer_id,
+                inner_hidden_size=inner_hidden_size,
+                hidden_size_per_attention_head=hidden_size_per_attention_head,
                 output_layer_init_method=self.output_layer_init_method,
+                is_decoder=self.is_decoder,
                 sandwich_ln=sandwich_ln,
+                layernorm=layernorm,
+                use_bias=use_bias,
                 hooks=self.hooks
             )
 
@@ -323,10 +460,10 @@ class BaseTransformer(torch.nn.Module):
             [get_layer(layer_id) for layer_id in range(num_layers)])
 
         # Final layer norm before output.
-        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
-                **kw_args):
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
+                output_hidden_states=False, **kw_args):
         # sanity check 
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
@@ -349,7 +486,8 @@ class BaseTransformer(torch.nn.Module):
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
             position_embeddings = self.position_embeddings(position_ids)
-        hidden_states = hidden_states + position_embeddings
+        if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
         hidden_states_outputs = [hidden_states] if output_hidden_states else []
@@ -363,21 +501,15 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2:  # have branch_input
-                        branch_ = inputs[2]
+                    x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
+                        if 'layer_forward' in self.hooks:
                             x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                                x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            x_, output_this_layer = layer(x_, mask, encoder_outputs_, **kw_args)
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
 
@@ -386,7 +518,7 @@ class BaseTransformer(torch.nn.Module):
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
-                args = [hidden_states, attention_mask]
+                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:
                     hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args,
                                                                                      branch_input)
@@ -398,7 +530,7 @@ class BaseTransformer(torch.nn.Module):
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask]
+                args = [hidden_states, attention_mask, encoder_outputs]
                 if branch_input is not None:  # customized layer_forward with branch_input
                     hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args,
                                                                                                  layer_id=torch.tensor(
-- 
GitLab