# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math
import random
import argparse

import torch
import torch.nn.init as init
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm

from .initialize import get_model_parallel_world_size
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .mappings import gather_from_model_parallel_region

import deepspeed

from .random import checkpoint
from .random import get_cuda_rng_tracker

from .utils import divide, sqrt
from .utils import split_tensor_along_last_dim
import torch.distributed as dist

class LayerNorm(FusedLayerNorm):
    """FusedLayerNorm with optional PB-relax: rescale the input by its maximum
    absolute value before normalization to keep the fused kernel numerically
    stable with very large activations."""

    def __init__(self, *args, pb_relax=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.pb_relax = pb_relax

    def forward(self, x):
        if not self.pb_relax:
            return super().forward(x)
        # LayerNorm output is invariant to rescaling the input, so dividing by
        # max|x|/8 (detached) changes nothing mathematically but avoids overflow.
        return super().forward(x / (x.abs().max().detach() / 8))
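
# Illustrative sketch (not executed; the hidden size 2560 below is an arbitrary
# assumption): with pb_relax=True the input is divided by max|x|/8 before the
# fused kernel runs, so the intermediate statistics stay in a safe range:
#     ln = LayerNorm(2560, eps=1e-5, pb_relax=True)
#     y = ln(torch.randn(4, 16, 2560) * 1e4)   # large activations, finite output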

class GPT2ParallelSelfAttention(torch.nn.Module):
    """Parallel self-attention layer for GPT2.

    Self-attention layer takes input with size [b, s, h] where b is
    the batch size, s is the sequence length, and h is the hidden size,
    and creates output of the same size.
    Arguments:
        hidden_size: total hidden size of the layer (h).
        num_attention_heads: number of attention heads (n). Note that we
                             require n to be divisible by the number of
                             GPUs used to parallelize the model. Also, we
                             require the hidden size to be divisible by n.
        attention_dropout_prob: dropout probability for the attention scores.
        output_dropout_prob: dropout probability for the output.
        init_method: weight initialization.
        output_layer_init_method: output layer initialization. If None, use
                                  `init_method`.
    We use the following notation:
        h: hidden_size
        n: num_attention_heads
        p: number of partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
    """
    def __init__(self, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, output_layer_init_method=None):
        super(GPT2ParallelSelfAttention, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        # Per attention head and per partition values.
        world_size = get_model_parallel_world_size()
        self.hidden_size_per_partition = divide(hidden_size, world_size)
        self.hidden_size_per_attention_head = divide(hidden_size,
                                                     num_attention_heads)
        self.num_attention_heads_per_partition = divide(num_attention_heads,
                                                        world_size)

        # Strided linear layer.
        self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
                                                    stride=3,
                                                    gather_output=False,
                                                    init_method=init_method)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of model-parallel partitions,
        # but on average the result should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)

        # Output.
        self.dense = RowParallelLinear(hidden_size,
                                       hidden_size,
                                       input_is_parallel=True,
                                       init_method=output_layer_init_method)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
        """
        new_tensor_shape = tensor.size()[:-1] + \
                           (self.num_attention_heads_per_partition,
                            self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)
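
    # Shape sketch (numbers assumed for illustration): with h=1024, n=16 heads and
    # p=4 model-parallel partitions, hp=256, np=4, hn=64, so a [b, s, 256] partition
    # is viewed as [b, s, 4, 64] and permuted to [b, 4, s, 64] for the attention matmuls.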


    def forward(self, hidden_states, mask, sparse_config, mem=None):
        # hidden_states: [b, s, h]
        # mask: [1, 1, s, s] (left-to-right attention mask)

        # Attention heads. [b, s, hp]
        query_length = hidden_states.size(1)

        # if mem is None:
        mixed_raw_layer = self.query_key_value(hidden_states)
        if mem is None:
            mixed_x_layer = mixed_raw_layer
        else:
            mixed_x_layer = torch.cat((mem, mixed_raw_layer), dim=1)
        (mixed_query_layer,
            mixed_key_layer,
            mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        if sparse_config.sparse_type in ['standard', 'torch_1d']:
            # Reshape and transpose [b, np, s, hn]
            query_layer = self._transpose_for_scores(mixed_query_layer)
            key_layer = self._transpose_for_scores(mixed_key_layer)
            value_layer = self._transpose_for_scores(mixed_value_layer)
            if sparse_config.sparse_type == 'standard':
                context_layer = standard_attention(query_layer, key_layer, value_layer, mask, self.attention_dropout)
            else:
                context_layer = sparse_attention_1d(query_layer, key_layer, value_layer, sparse_config.pivot_idx,
                    mask, sparse_config.query_window, sparse_config.key_window_times, self.attention_dropout)
                # inference: context_layer = sparse_attention_inference_1d(query_layer, key_layer, value_layer, pivot_idx)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + \
                                    (self.hidden_size_per_partition,)
            # [b, s, hp]
            context_layer = context_layer.view(*new_context_layer_shape)

        elif sparse_config.sparse_type == 'cuda_2d':
            context_layer = sparse_attention_2d(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition,
                 sparse_config.layout, mask, sparse_config.kernel_size, sparse_config.kernel_size2, attention_dropout=self.attention_dropout)

        # Output. [b, s, h]
        output = self.dense(context_layer)
        output = self.output_dropout(output)

        return output, mixed_raw_layer.detach()


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation (tanh approximation);
    0.7978845608028654 is sqrt(2 / pi)."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)


class GPT2ParallelMLP(torch.nn.Module):
    """MLP for GPT2.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform gelu transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.

    Arguments:
        hidden_size: The hidden size of the self attention.
        output_dropout_prob: dropout probability applied to the MLP output.
        init_method: initialization method used for the weights. Note
                     that all biases are initialized to zero and
                     layernorm weights are initialized to one.
        output_layer_init_method: output layer initialization. If None,
                                  use `init_method`.
    """

    def __init__(self, hidden_size, output_dropout_prob, init_method,
                 output_layer_init_method=None):
        super(GPT2ParallelMLP, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        # Project to 4h.
        self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4*hidden_size,
                                                  gather_output=False,
                                                  init_method=init_method)
        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            4*hidden_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method)
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        # [b, s, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = gelu(intermediate_parallel)

        # [b, s, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        output = self.dropout(output)
        return output


class GPT2ParallelTransformerLayer(torch.nn.Module):
    """A single layer transformer for GPT2.

    We use the following notation:
        h: hidden size
        n: number of attention heads
        b: batch size
        s: sequence length
    The transformer layer takes input with size [b, s, h] and returns an
    output of the same size.

    Arguments:
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
        init_method: initialization method used for the weights. Note
                     that all biases are initialized to zero and
                     layernorm weights are initialized to one.
        output_layer_init_method: output layers (attention output and
                                  mlp output) initialization. If None,
                                  use `init_method`.
    """
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 output_dropout_prob,
                 layernorm_epsilon,
                 init_method,
                 output_layer_init_method=None,
                 sandwich_ln=True,
                 sparse_config=argparse.Namespace(sparse_type='standard')
                 ):
        super(GPT2ParallelTransformerLayer, self).__init__()
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

        # Self attention.
        self.attention = GPT2ParallelSelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            init_method,
            output_layer_init_method=output_layer_init_method)

        # Layernorm after the self attention.
        self.post_attention_layernorm = LayerNorm(hidden_size,
                                                  eps=layernorm_epsilon)
        self.sandwich_ln = sandwich_ln
        if sandwich_ln:
            self.third_layernorm = LayerNorm(hidden_size,
                                                    eps=layernorm_epsilon)
            self.fourth_layernorm = LayerNorm(hidden_size,
                                                    eps=layernorm_epsilon)

        # MLP
        self.mlp = GPT2ParallelMLP(
            hidden_size,
            output_dropout_prob,
            init_method,
            output_layer_init_method=output_layer_init_method)

        self.sparse_config = sparse_config

    def forward(self, hidden_states, ltor_mask, mem=None):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output1 = self.input_layernorm(hidden_states)
        mem = self.input_layernorm(mem) if mem is not None else None
        # Self attention.
        attention_output, qkv = self.attention(layernorm_output1, ltor_mask, self.sparse_config, mem)

        # Third LayerNorm
        if self.sandwich_ln:
            attention_output = self.third_layernorm(attention_output)

        # Residual connection.
        layernorm_input = hidden_states + attention_output
        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Fourth LayerNorm
        if self.sandwich_ln:
            mlp_output = self.fourth_layernorm(mlp_output)

        # Second residual connection.
        output = layernorm_input + mlp_output

        return output, qkv
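
    # With sandwich_ln=True the per-layer order is:
    #   LN -> attention -> LN (third)  -> residual add ->
    #   LN -> MLP       -> LN (fourth) -> residual add
    # i.e. each sub-layer output is re-normalized before being added back to the stream.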

def unscaled_init_method(sigma):
    """Init method based on N(0, sigma)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


class GPT2ParallelTransformer(torch.nn.Module):
    """GPT-2 transformer.

    This module takes input from the embedding layer and its output can
    be used directly by a logit layer. It consists of L (num-layers)
    blocks of:
        layer norm
        self attention
        residual connection
        layer norm
        mlp
        residual connection
    followed by a final layer norm.

    Arguments:
        num_layers: Number of transformer layers.
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention heads in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        checkpoint_activations: if True, checkpoint activations.
        checkpoint_num_layers: number of layers to checkpoint. This
                               is basically the chunk size in checkpointing.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
        init_method_std: standard deviation of the init method which has
                         the form N(0, std).
        use_scaled_init_for_output_weights: If True, use 1/sqrt(2*num_layers)
                                            scaling for the output weights
                                            (output of self attention and mlp).
    """
    def __init__(self,
                 num_layers,
                 hidden_size,
                 num_attention_heads,
                 max_sequence_length,
                 max_memory_length,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 use_scaled_init_for_output_weights=True,
                 sandwich_ln=True,
                 sparse_config=argparse.Namespace(sparse_type='standard')
                 ):
        super(GPT2ParallelTransformer, self).__init__()
        # Store activation checkpointing flag.
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers
        self.max_memory_length = max_memory_length
        self.max_sequence_length = max_sequence_length

        output_layer_init_method = None
        if use_scaled_init_for_output_weights:
            output_layer_init_method = scaled_init_method(init_method_std,
                                                      num_layers)
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length,
                                                        hidden_size)
        # Initialize the position embeddings.
        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

        # TODO: after testing, this is not useful.
        # self.img_type_embeddings = torch.nn.Parameter(torch.Tensor(64, hidden_size)) 
        # torch.nn.init.normal_(self.img_type_embeddings, mean=0.0, std=init_method_std)
        # self.txt_type_embeddings = torch.nn.Parameter(torch.Tensor(hidden_size)) 
        # torch.nn.init.normal_(self.txt_type_embeddings, mean=0.0, std=init_method_std)


        def get_layer(layer_id):
            return GPT2ParallelTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                unscaled_init_method(init_method_std),
                output_layer_init_method=output_layer_init_method,
                sandwich_ln=sandwich_ln,
                sparse_config=sparse_config
                )

        # Transformer layers.
        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint
        self.sparse_config = sparse_config

    def forward(self, hidden_states, position_ids, attention_mask, *mems):

        batch_size, query_length = hidden_states.size()[:2]
        memory_length = mems[0].size(1) if mems else 0
        key_length = query_length + memory_length

        # legacy
        if isinstance(attention_mask, int) or attention_mask.numel() == 1:
            # If given an int "sep", it marks the separation between the full-attention
            # part and the single-direction part: sep is the beginning position of the
            # B region, in [0, query_len).
            sep = attention_mask
            # conventional transformer
            def build_mask_matrix(query_length, key_length, sep):
                m = torch.ones((1, query_length, key_length), device=hidden_states.device, dtype=hidden_states.dtype)
                assert query_length <= key_length
                m[0, :, -query_length:] = torch.tril(m[0, :, -query_length:])
                m[0, :, :sep + (key_length - query_length)] = 1
                m = m.unsqueeze(1)
                return m
            attention_mask = build_mask_matrix(query_length, key_length, sep)
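
        # Worked example (assumed sizes): query_length=3, key_length=5, sep=2 gives
        #     [[1, 1, 1, 1, 0],
        #      [1, 1, 1, 1, 0],
        #      [1, 1, 1, 1, 1]]
        # i.e. the first sep + memory_length = 4 key positions are fully visible and
        # the remaining positions are causal (lower-triangular).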

        # =====================   Image & Text Type Embedding   ======================== #
        # TODO: after testing, this is not useful.
        # extend_len = (key_length + 63) // 64
        # hidden_states = hidden_states + txt_indices_bool.unsqueeze(-1) * self.txt_type_embeddings.view(1, 1, -1) + \
        #     img_indices_bool.unsqueeze(-1) * self.img_type_embeddings.expand(extend_len, 64, -1).reshape(extend_len * 64, -1)[memory_length: key_length]
        # ===================== END OF BLOCK ======================= #

        position_embeddings = self.position_embeddings(position_ids)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.embedding_dropout(hidden_states)

        mem_layers = []
        def custom(start, end):
            def custom_forward(*inputs):
                layers_ = self.layers[start:end]
                # inputs = (hidden_states, attention_mask, *mems for this chunk)
                x_, mask, mems_ = inputs[0], inputs[1], inputs[2:]

                for i, layer in enumerate(layers_):
                    mem_i_ = mems_[i] if mems_ else None
                    x_, qkv = layer(x_, mask, mem=mem_i_)
                    if self.max_memory_length > 0:
                        mem_layers.append(qkv)
                return x_
            return custom_forward

        attention_mask_saved = attention_mask
        
        if self.checkpoint_activations:
            l = 0
            num_layers = len(self.layers)
            chunk_length = self.checkpoint_num_layers
            while l < num_layers:
                args = [hidden_states, attention_mask_saved]

                if mems:
                    args += mems[l: l + chunk_length]

                hidden_states = checkpoint(custom(l, l + chunk_length), *args)
                l += chunk_length
        else:
            assert self.sparse_config.sparse_type == 'standard'
            for i, layer in enumerate(self.layers):
                args = [hidden_states, attention_mask_saved]

                mem_i = mems[i] if mems else None
                hidden_states, qkv = layer(*args, mem=mem_i)
                if self.max_memory_length > 0:
                    mem_layers.append(qkv) 

        # Final layer norm.
        output = self.final_layernorm(hidden_states)
        if self.max_memory_length > 0: # TODO cache
            mem_layers = self.update_mems(mem_layers, mems)

        return (output, *mem_layers)

    def update_mems(self, hiddens, mems):
        memory_length = mems[0].size(1) if mems else 0
        query_length = hiddens[0].size(1)
        new_memory_length = min(self.max_memory_length, memory_length + query_length)
        new_mems = []
        with torch.no_grad():
            for i in range(len(hiddens)):
                if new_memory_length <= query_length:
                    new_mems.append(hiddens[i][:, -new_memory_length:])
                else:
                    new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
        return new_mems
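
    # Memory-window sketch (assumed numbers): with max_memory_length=512,
    # memory_length=300 and query_length=256, new_memory_length=512, so each new
    # mem is cat(old_mem[:, -256:], hiddens), keeping exactly the last 512 positions.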
        

def _chunk(x, w, times):
    '''Convert into overlapping chunks. Chunk size = times * w, overlap size = w.
    Args:
        x: [b, np, s, hn]
        ...
    '''
    s = x.size(2)
    assert s % w == 0
    # left-pad the sequence dim with (times-1)*w zeros: [b, np, s + (times-1)*w, hn]
    npad = (times-1) * w
    x = torch.nn.functional.pad(x, (0, 0, npad, 0), value=0)

    x = x.view(x.size(0), x.size(1),  x.size(2) // w, w, x.size(3))

    chunk_size = list(x.size())
    chunk_stride = list(x.stride())

    chunk_size[2] = chunk_size[2] - times + 1

    chunk_size[3] = w * times

    return x.as_strided(size=chunk_size, stride=chunk_stride)
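
# Shape sketch for _chunk (assumed numbers): with s=256, w=128, times=6 the input is
# left-padded to 896 = 7*128, viewed as [b, np, 7, 128, hn], and strided into
# [b, np, 2, 768, hn]: each of the s//w = 2 query windows sees its own 128 keys plus
# the 5 preceding (possibly zero-padded) key windows.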

def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None):
    # We disable PB-relax-Attention and only change the order of computation here,
    # because that is enough for most training. The implementation in the paper can
    # be added easily if you really need to train very deep transformers.

    if len(attention_mask.shape) == 3:
        attention_mask = attention_mask.unsqueeze(1)
    # Raw attention scores. [b, np, s, s]
    attention_scores = torch.matmul(query_layer / math.sqrt(query_layer.shape[-1]), key_layer.transpose(-1, -2))

    # Apply the left to right attention mask.
    attention_scores = torch.mul(attention_scores, attention_mask) - \
                    10000.0 * (1.0 - attention_mask)
    # Attention probabilities. [b, np, s, s]
    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            attention_probs = attention_dropout(attention_probs)
    # Context layer.
    # [b, np, s, hn]
    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer
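
# Usage sketch (shapes only, values assumed): q, k, v of shape [b, np, s, hn] and a
# {0,1} mask broadcastable to [b, np, s, s] yield a context of shape [b, np, s, hn];
# masked positions receive a -10000 bias before the softmax rather than -inf.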

def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
    ''' Sparse Attention (pivot + sliding-window, 1D)
    Args:
        q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window
        pivot_idx: [b, num_pivots]
        pivot_attention_mask: [b, s, num_pivots]
        query_window: window size (w) of the sliding-window attention.
        key_window_times: key_window = query_window * key_window_times
    '''

    b, n_head, s, hn = q.shape
    b, n_piv = pivot_idx.shape
    w = query_window

    pivot_idx_dummy = pivot_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
    # =====================   Pivot Attention   ======================== #
    pivot_k, pivot_v = torch.gather(k, 2, pivot_idx_dummy), torch.gather(v, 2, pivot_idx_dummy)
    attention_scores = torch.matmul(q, pivot_k.transpose(-1, -2))
    pivot_attention_mask = pivot_attention_mask.unsqueeze(1)

    attention_scores_pivot = torch.mul(attention_scores, pivot_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - pivot_attention_mask)

    attention_scores_pivot = attention_scores_pivot + math.log(s // n_piv)
    # =====================   Window Attention   ======================= #
    window_k = _chunk(k, query_window, key_window_times)
    window_v = _chunk(v, query_window, key_window_times)
    # window_k: [b, n_head, s // w, w*times, hn]

    if s % w == 0: # training # TODO args check
        assert k.shape[2] == s
        assert window_k.shape[2] == s // w
        window_q = q.view(b, n_head, s // w, w, hn)        
        attention_scores = torch.matmul(window_q, window_k.transpose(-1, -2))
        window_attention_mask = torch.ones((w, w * key_window_times), dtype=attention_scores.dtype, device=q.device).tril_(diagonal=w * (key_window_times - 1))
        attention_scores_window = torch.mul(attention_scores, window_attention_mask / math.sqrt(hn)) - 10000.0 * (1.0 - window_attention_mask)
        for t in range(1, key_window_times):
            attention_scores_window[:, :, t - 1, :, :w * key_window_times - w * t] -= 10000.0
    else:
        raise ValueError('The seq_len must be exactly divisible by the window size.')
    # =====================   Joint Softmax   ======================= #
    attention_scores_window = attention_scores_window.view(b, n_head, s, w * key_window_times)
    attention_scores = torch.cat((attention_scores_pivot, attention_scores_window), dim=-1)
    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            attention_probs = attention_dropout(attention_probs)

    context_layer = torch.matmul(attention_probs[..., :-w * key_window_times], pivot_v) + torch.einsum('bcgwk,bcgkh->bcgwh', attention_probs[..., -w * key_window_times:].view(b, n_head, s // w, w, w * key_window_times), window_v).view(b, n_head, s, hn)

    return context_layer
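
# Score-layout sketch for sparse_attention_1d (assumed numbers): with s=4096,
# query_window w=128, key_window_times=6 and n_piv=768 pivots, each query row
# softmaxes over n_piv + w*times = 768 + 768 = 1536 candidate keys instead of 4096:
# the pivot scores occupy the first 768 columns and the sliding-window scores the rest.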

# def sparse_attention_inference_1d(q, k, v, pivot_and_window_idx, **kwargs):
#     '''the inference process of sparse attention.
#     The Qs are in the same block, but seq_len mod window size might != 0.

#     The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs.

#     '''
#     b, n_head, sq, hn = q.shape
#     sk = k.shape[2]
#     _b, n_piv = pivot_and_window_idx.shape

#     pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
#     pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy)
#     attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2))
#     if sq > 1:
#         query_part_scores = attention_scores[:, :, -sq:, -sq:]
#         m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000.
#         m.triu_(diagonal=1)
#         query_part_scores += m

#     attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

#     context_layer = torch.matmul(attention_probs, pivot_v) 
#     return context_layer

def transpose_and_split(x, layout, n_head):
    x = x.transpose(1, 2)
    x = x.reshape(x.shape[0] * n_head, x.shape[1] // n_head, x.shape[2])
    x_text = x[..., :layout[0]]
    x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), -1).contiguous()
    x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), -1).contiguous()
    return x, x_text, x0, x1

def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
    '''
    q, k, v: [batch_size, 64+1024+4096, hidden_size]
    n_head: int
    layout: [endoftext/startofpad, startof0, startof1, endofall]
    attention_mask_text2d: [batch_size, sq_len, endoftext]
    '''
    from .local_attention_function import f_similar, f_weighting
    b, sq_len, hn = q.shape
    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1]))

    q = q / math.sqrt(hn // n_head) # normalization

    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)

    # import pdb; pdb.set_trace()
    # all to text
    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
    # 0 to 0
    scores_0_to_0 = f_similar(q0, k0, kernel_size, kernel_size, True)
    # 1 to 1
    scores_1_to_1 = f_similar(q1, k1, kernel_size, kernel_size, True)    
    # 1 to 0
    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
    # softmax
    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]

    scores_0 = torch.cat(
        (scores_all_to_text[:, layout[1]:layout[2]], 
        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
        dim=-1)
    probs_0 = F.softmax(scores_0, dim=-1) # 
    scores_1 = torch.cat(
        (scores_all_to_text[:, layout[2]:layout[3]],
         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
         dim=-1)
    probs_1 = F.softmax(scores_1, dim=-1)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            probs_0 = attention_dropout(probs_0)
            probs_1 = attention_dropout(probs_1)
    # weighting
    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
    probs_all_to_text = torch.cat((
        probs_text,
        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
        probs_0[:, :, :layout[0]],
        probs_1[:, :, :layout[0]]
    ), dim=1)

    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
    
    context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size, kernel_size, True)

    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)

    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size, kernel_size, True)
    
    context_all_to_01 = torch.cat(
        (
            pad.expand(b*n_head, hn//n_head, layout[1]),
            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
        ), dim=-1).view(b, hn, -1).transpose(1, 2)
    return context_all_to_text + context_all_to_01
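
# Layout sketch for sparse_attention_2d (from the docstring's 64+1024+4096 split,
# other details assumed): layout = [64, 64, 1088, 5184], so text tokens occupy
# [0, 64), the 32x32 low-resolution image occupies [64, 1088) and the 64x64
# high-resolution image occupies [1088, 5184); kernel_size / kernel_size2 are the
# local window sizes used by f_similar / f_weighting on those two grids.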