# coding=utf-8
# rewritten, Copyright (c) 2021, Ming Ding.  All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math
import copy
import torch
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm

from .initialize import get_model_parallel_world_size
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region

import deepspeed

from .random import checkpoint
from .random import get_cuda_rng_tracker

from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
from .utils import split_tensor_along_last_dim

class LayerNorm(FusedLayerNorm):
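    """FusedLayerNorm with an optional PB-relax rescaling.

    With pb_relax=True the input is divided by max(|x|)/8 before normalization.
    LayerNorm is invariant to rescaling its input, so the result is unchanged
    (up to eps) while intermediate values stay in a safer fp16 range.
    """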
    def __init__(self, *args, pb_relax=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.pb_relax = pb_relax
    def forward(self, x):
        if not self.pb_relax:
            return super().forward(x)
        return super().forward(x / (x.abs().max().detach()/8))
        
def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                    attention_dropout=None, log_attention_weights=None):
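    """Scaled dot-product attention.

    Expected shapes (as produced by SelfAttention._transpose_for_scores):
        query/key/value layers: [b, num_heads, seq_len, head_dim]
        attention_mask: broadcastable against the attention scores
                        [b, num_heads, seq_len, seq_len]; 1 = visible, 0 = masked.
    Returns the context layer with shape [b, num_heads, seq_len, head_dim].
    """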
    # We disable PB-relax attention here and only change the order of computation,
    # which is enough for most training runs. The full version from the paper is
    # easy to add if you really need to train very deep transformers.

    attention_scores = torch.matmul(
        query_layer / math.sqrt(query_layer.shape[-1]),
        key_layer.transpose(-1, -2)
    )
    
    if attention_mask.shape[-2] > 1:  # apply the mask unless it is the trivial single-row mask used in auto-regressive decoding
        attention_scores = torch.mul(attention_scores, attention_mask) - \
                    10000.0 * (1.0 - attention_mask)
    if log_attention_weights is not None:
        attention_scores += log_attention_weights
    
    attention_probs = F.softmax(attention_scores, dim=-1)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            attention_probs = attention_dropout(attention_probs)

    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer

class SelfAttention(torch.nn.Module):
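    """Model-parallel self-attention.

    Takes hidden states of size [b, s, h] and returns an output of the same
    size; attention heads are split across model-parallel partitions.
    """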
    def __init__(self, hidden_size, num_attention_heads,
                attention_dropout_prob, output_dropout_prob,
                init_method, layer_id, output_layer_init_method=None,
                hooks={}):
        super(SelfAttention, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        self.hooks = hooks
        self.layer_id = layer_id
        # Per attention head and per partition values.
        world_size = get_model_parallel_world_size()
        self.hidden_size_per_partition = divide(hidden_size, world_size)
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = ColumnParallelLinear(
            hidden_size, 
            3*hidden_size,
            stride=3,
            gather_output=False,
            init_method=init_method
        )
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)

        self.dense = RowParallelLinear(
            hidden_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method
        )
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)


    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
        """
        new_tensor_shape = tensor.size()[:-1] + \
                            (self.num_attention_heads_per_partition,
                            self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def forward(self, hidden_states, mask, *other_tensors):
        if 'attention_forward' in self.hooks:
            return self.hooks['attention_forward'](hidden_states, mask, *other_tensors, layer_id=self.layer_id)
        else:
            mixed_raw_layer = self.query_key_value(hidden_states)
            (mixed_query_layer,
                mixed_key_layer,
                mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)

            dropout_fn = self.attention_dropout if self.training else None

            query_layer = self._transpose_for_scores(mixed_query_layer)
            key_layer = self._transpose_for_scores(mixed_key_layer)
            value_layer = self._transpose_for_scores(mixed_value_layer)
            
            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)
            output = self.dense(context_layer)
            
            if self.training:
                output = self.output_dropout(output)
            
            return output, None


class MLP(torch.nn.Module):
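    """Model-parallel MLP.

    Projects the hidden state from h to 4*h, applies gelu, and projects back
    to h, with dropout on the output during training.
    """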
    def __init__(self, hidden_size, output_dropout_prob, init_method,
                output_layer_init_method=None, layer_id=None, hooks={}):
        super(MLP, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        self.hooks = hooks
        self.layer_id = layer_id
        # Project to 4h.
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4*hidden_size,
            gather_output=False,
            init_method=init_method
        )
        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            4*hidden_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method
        )
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states, *other_tensors):
        if 'mlp_forward' in self.hooks:
            output = self.hooks['mlp_forward'](hidden_states, *other_tensors, layer_id=self.layer_id)
        else:
            intermediate_parallel = self.dense_h_to_4h(hidden_states)
            intermediate_parallel = gelu(intermediate_parallel)
            output = self.dense_4h_to_h(intermediate_parallel)
            
        if self.training:
            output = self.dropout(output)
        return output


class BaseTransformerLayer(torch.nn.Module):
    """A single layer transformer for GPT2.

    We use the following notation:
        h: hidden size
        n: number of attention heads
        b: batch size
        s: sequence length
    The transformer layer takes input of size [b, s, h] and returns an
    output of the same size.

    Arguments:
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
        init_method: initialization method used for the weights. Note
                     that all biases are initialized to zero and
                     layernorm weights are initialized to one.
        output_layer_init_method: output layers (attention output and
                                  mlp output) initialization. If None,
                                  use `init_method`.
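        layer_id: index of this layer in the transformer stack.
        sandwich_ln: if True, apply an extra LayerNorm to the attention and
                     MLP outputs (Sandwich-LN) before the residual additions.
        hooks: optional dict mapping hook names ('attention_forward',
               'mlp_forward', ...) to functions that replace the default
               forward of the corresponding sub-module.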
    """
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_dropout_prob,
        output_dropout_prob,
        layernorm_epsilon,
        init_method,
        layer_id,
        output_layer_init_method=None,
        sandwich_ln=True,
        hooks={}
    ):
        super(BaseTransformerLayer, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        self.layer_id = layer_id
        self.hooks = hooks

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

        # Self attention.
        self.attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            init_method,
            layer_id,
            output_layer_init_method=output_layer_init_method,
            hooks=hooks
        )

        # Layernorm after the self attention.
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.sandwich_ln = sandwich_ln
        if sandwich_ln:
            self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
            self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

        # MLP
        self.mlp = MLP(
            hidden_size,
            output_dropout_prob,
            init_method,
            output_layer_init_method=output_layer_init_method,
            layer_id=layer_id,
            hooks=hooks
        )
    
    def forward(self, hidden_states, mask, *other_tensors):
        '''
            hidden_states: [batch, seq_len, hidden_size]
            mask: attention mask, [seq_len, seq_len] or [batch, 1, seq_len, seq_len]
        '''

        # Layer norm at the beginning of the transformer layer.
        layernorm_output1 = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, output_this_layer = self.attention(layernorm_output1, mask, *other_tensors)

        # Third LayerNorm
        if self.sandwich_ln:
            attention_output = self.third_layernorm(attention_output)

        # Residual connection.
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output, *other_tensors)

        # Fourth LayerNorm
        if self.sandwich_ln:
            mlp_output = self.fourth_layernorm(mlp_output)

        # Second residual connection.
        output = layernorm_input + mlp_output

        return output, output_this_layer  # for now, output_this_layer is only produced by the attention sub-layer

class BaseTransformer(torch.nn.Module):
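    """GPT-style transformer.

    Word + position embeddings, a stack of BaseTransformerLayer blocks,
    a final LayerNorm, and an output projection tied to the word embeddings.
    """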
    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 max_sequence_length,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 sandwich_ln=True,
                 parallel_output=True,
                 hooks={}
                 ):
        super(BaseTransformer, self).__init__()
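        # When DeepSpeed activation checkpointing is configured, use its
        # checkpoint function and CUDA RNG tracker instead of the local ones.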
        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint
        
        # recording parameters
        self.parallel_output = parallel_output
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers
        self.max_sequence_length = max_sequence_length
        self.hooks = copy.copy(hooks) # hooks will be updated each forward
        
        # create embedding parameters
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
        
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
        
        self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

        # create all layers
        self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
        self.init_method = unscaled_init_method(init_method_std)
        def get_layer(layer_id):
            return BaseTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                self.init_method,
                layer_id,
                output_layer_init_method=self.output_layer_init_method,
                sandwich_ln=sandwich_ln,
                hooks=hooks
                )
        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, input_ids, position_ids, attention_mask, *other_tensors):
        # sanity check 
        assert len(input_ids.shape) == 2 
        batch_size, query_length = input_ids.shape
        assert len(position_ids.shape) <= 2
        assert position_ids.shape[-1] == query_length
        assert len(attention_mask.shape) == 2 or \
            len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1

        # embedding part
        if 'word_embedding_forward' in self.hooks:
            hidden_states = self.hooks['word_embedding_forward'](input_ids, *other_tensors)
        else: # default
            hidden_states = self.word_embeddings(input_ids)
            
        if 'position_embedding_forward' in self.hooks:
            position_embeddings = self.hooks['position_embedding_forward'](position_ids, *other_tensors)
        else:
            position_embeddings = self.position_embeddings(position_ids)    
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.embedding_dropout(hidden_states)

        # define custom_forward for checkpointing
        output_per_layers = []
        if self.checkpoint_activations:
            def custom(start, end):
                def custom_forward(*inputs):
                    layers_ = self.layers[start:end]
                    x_, mask, *other_tensors = inputs
                    for i, layer in enumerate(layers_):
                        x_, output_this_layer = layer(x_, mask, *other_tensors)
                        output_per_layers.append(output_this_layer)
                    return x_
                return custom_forward
        
            l, num_layers = 0, len(self.layers)
            chunk_length = self.checkpoint_num_layers
            while l < num_layers:
                args = [hidden_states, attention_mask, *other_tensors]
                hidden_states = checkpoint(custom(l, l + chunk_length), *args)
                l += chunk_length
        else:
            for i, layer in enumerate(self.layers):
                args = [hidden_states, attention_mask, *other_tensors]
                hidden_states, output_this_layer = layer(*args)
                output_per_layers.append(output_this_layer) 

        # Final layer norm.
        logits = self.final_layernorm(hidden_states)
        
        if 'final_forward' in self.hooks:
            logits_parallel = self.hooks['final_forward'](logits, *other_tensors)
        else:
            logits_parallel = copy_to_model_parallel_region(logits)
            logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)

        if self.parallel_output:
            return (logits_parallel, *output_per_layers)
        return (gather_from_model_parallel_region(logits_parallel), *output_per_layers)
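
# Minimal usage sketch (illustrative only, with made-up hyperparameters):
# assumes the model-parallel groups from `.initialize` are already set up
# and that apex / deepspeed are available. All names mirror the constructor
# and forward signatures above.
#
#   model = BaseTransformer(
#       num_layers=12, vocab_size=50304, hidden_size=768,
#       num_attention_heads=12, max_sequence_length=1024,
#       embedding_dropout_prob=0.1, attention_dropout_prob=0.1,
#       output_dropout_prob=0.1, checkpoint_activations=False,
#   ).half().cuda()
#   logits, *per_layer_outputs = model(input_ids, position_ids, attention_mask)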