From 34990eea7b857973a143cdebd31e48470cacf79b Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Wed, 18 Aug 2021 18:21:39 +0000
Subject: [PATCH] Add 2D sparse attention and QKV caching

---
 mpu/local_attention_function.py | 149 ++++++++++++
 mpu/sparse_transformer.py       | 401 +++++++++++++-------------------
 mpu/utils.py                    |   4 +
 test_sparse_attention.py        | 169 ++++++++++++++
 4 files changed, 486 insertions(+), 237 deletions(-)
 create mode 100644 mpu/local_attention_function.py
 create mode 100644 test_sparse_attention.py

diff --git a/mpu/local_attention_function.py b/mpu/local_attention_function.py
new file mode 100644
index 0000000..5ba073c
--- /dev/null
+++ b/mpu/local_attention_function.py
@@ -0,0 +1,149 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from localAttention import (similar_forward,
+                            similar_backward,
+                            weighting_forward,
+                            weighting_backward_ori,
+                            weighting_backward_weight)
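+# `localAttention` is an external CUDA extension that provides the 2D local
+# attention kernels; it has to be built and installed separately.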
+
+__all__ = ['f_similar', 'f_weighting', 'LocalAttention', 'TorchLocalAttention']
+
+
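+# autograd wrapper around the CUDA "similar" kernel: for every query position it
+# computes dot-product scores against the keys inside its kH x kW neighbourhood.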
+class similarFunction(Function):
+    @staticmethod
+    def forward(ctx, x_ori, x_loc, kH, kW, casual_mask=False):
+        ctx.save_for_backward(x_ori, x_loc)
+        ctx.kHW = (kH, kW)
+        ctx.casual_mask = casual_mask
+        output = similar_forward(x_ori, x_loc, kH, kW, casual_mask)
+
+        return output
+
+    @staticmethod
+    #@once_differentiable
+    def backward(ctx, grad_outputs):
+        x_ori, x_loc = ctx.saved_tensors
+        kH, kW = ctx.kHW
+        casual_mask = ctx.casual_mask
+        grad_ori = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, True, casual_mask)
+        grad_loc = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, False, casual_mask)
+
+        return grad_ori, grad_loc, None, None, None
+
+
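+# autograd wrapper around the CUDA "weighting" kernel: aggregates the value
+# features with the per-window attention weights produced by f_similar.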
+class weightingFunction(Function):
+    @staticmethod
+    def forward(ctx, x_ori, x_weight, kH, kW, casual_mask=False):
+        ctx.save_for_backward(x_ori, x_weight)
+        ctx.kHW = (kH, kW)
+        ctx.casual_mask = casual_mask
+        output = weighting_forward(x_ori, x_weight, kH, kW, casual_mask)
+
+        return output
+
+    @staticmethod
+    #@once_differentiable
+    def backward(ctx, grad_outputs):
+        x_ori, x_weight = ctx.saved_tensors
+        kH, kW = ctx.kHW
+        casual_mask = ctx.casual_mask
+        grad_ori = weighting_backward_ori(x_ori, x_weight, grad_outputs, kH, kW, casual_mask)
+        grad_weight = weighting_backward_weight(x_ori, x_weight, grad_outputs, kH, kW, casual_mask)
+
+        return grad_ori, grad_weight, None, None, None
+
+
+f_similar = similarFunction.apply
+f_weighting = weightingFunction.apply
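+# Minimal usage sketch (shapes follow the PyTorch reference implementation below;
+# with casual_mask=True only the first kH*kW//2 + 1 entries of each window are kept):
+#   scores = f_similar(q, k, kH, kW)                   # [N, C, H, W] in -> [N, H, W, kH*kW] scores
+#   probs = F.softmax(scores, dim=-1)
+#   out = f_weighting(v, probs.contiguous(), kH, kW)   # -> [N, C, H, W]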
+
+
+class LocalAttention(nn.Module):
+    def __init__(self, inp_channels, out_channels, kH, kW):
+        super(LocalAttention, self).__init__()
+        self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
+        self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
+        self.kH = kH
+        self.kW = kW
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(x)
+        x3 = self.conv3(x)
+
+        weight = f_similar(x1, x2, self.kH, self.kW)
+        weight = F.softmax(weight, -1)
+        out = f_weighting(x3, weight, self.kH, self.kW)
+
+        return out
+
+
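+# Pure-PyTorch equivalent of LocalAttention built on F.unfold; useful as a
+# reference implementation for checking the CUDA kernels.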
+class TorchLocalAttention(nn.Module):
+    def __init__(self, inp_channels, out_channels, kH, kW):
+        super(TorchLocalAttention, self).__init__()
+        self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
+        self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
+        self.kH = kH
+        self.kW = kW
+
+    @staticmethod
+    def f_similar(x_theta, x_phi, kh, kw, casual_mask=False):
+        n, c, h, w = x_theta.size()  # (N, inter_channels, H, W)
+        pad = (kh // 2, kw // 2)
+        x_theta = x_theta.permute(0, 2, 3, 1).contiguous()
+        x_theta = x_theta.view(n * h * w, 1, c)
+
+        x_phi = F.unfold(x_phi, kernel_size=(kh, kw), stride=1, padding=pad)
+        x_phi = x_phi.contiguous().view(n, c, kh * kw, h * w)
+        x_phi = x_phi.permute(0, 3, 1, 2).contiguous()
+        x_phi = x_phi.view(n * h * w, c, kh * kw)
+        out = x_theta @ x_phi
+        out = out.view(n, h, w, kh * kw)
+        if casual_mask:
+            out = out[..., :kh * kw // 2 + 1]
+        return out
+
+    @staticmethod
+    def f_weighting(x_theta, x_phi, kh, kw, casual_mask=False):
+        n, c, h, w = x_theta.size()  # (N, inter_channels, H, W)
+        pad = (kh // 2, kw // 2)
+        x_theta = F.unfold(x_theta, kernel_size=(kh, kw), stride=1, padding=pad)
+        x_theta = x_theta.permute(0, 2, 1).contiguous()
+        x_theta = x_theta.view(n * h * w, c, kh * kw)
+
+        if casual_mask:
+            x_theta = x_theta[..., :kh * kw // 2 + 1]
+            x_phi = x_phi.view(n * h * w, kh * kw // 2 + 1, 1)
+        else:   
+            x_phi = x_phi.view(n * h * w, kh * kw, 1)
+
+        out = torch.matmul(x_theta, x_phi)
+        out = out.squeeze(-1)
+        out = out.view(n, h, w, c)
+        out = out.permute(0, 3, 1, 2).contiguous()
+
+        return out
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(x)
+        x3 = self.conv3(x)
+
+        weight = self.f_similar(x1, x2, self.kH, self.kW)
+        weight = F.softmax(weight, -1)
+        out = self.f_weighting(x3, weight, self.kH, self.kW)
+
+        return out
+    
+    
+if __name__ == '__main__':
+    b, c, h, w = 8, 3, 32, 32
+    kH, kW = 5, 5
+    x = torch.rand(b, c, h, w).cuda()
+    m = LocalAttention(c, c, kH, kW)
+    m.cuda()
+    y = m(x)
\ No newline at end of file
diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py
index 7fb3e48..5635d66 100755
--- a/mpu/sparse_transformer.py
+++ b/mpu/sparse_transformer.py
@@ -17,10 +17,12 @@
 
 import math
 import random
+import argparse
 
 import torch
 import torch.nn.init as init
-from apex.normalization.fused_layer_norm import FusedLayerNorm #as LayerNorm
+import torch.nn.functional as F
+from apex.normalization.fused_layer_norm import FusedLayerNorm
 
 from .initialize import get_model_parallel_world_size
 from .layers import ColumnParallelLinear
@@ -32,15 +34,17 @@ import deepspeed
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
 
-from .utils import divide
+from .utils import divide, sqrt
 from .utils import split_tensor_along_last_dim
 import torch.distributed as dist
 
-
 class LayerNorm(FusedLayerNorm):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, pb_relax=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.pb_relax = pb_relax
     def forward(self, x):
+        if not self.pb_relax:
+            return super().forward(x)
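+        # PB-Relax: rescale by max|x|/8 before normalizing to avoid fp16 overflow on
+        # large activations; LayerNorm itself is invariant to this rescaling.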
         return super().forward(x / (x.abs().max().detach()/8))
 
 class GPT2ParallelSelfAttention(torch.nn.Module):
@@ -71,7 +75,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
     """
     def __init__(self, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
-                 init_method, output_layer_init_method=None, query_window=128, key_window_times=6):
+                 init_method, output_layer_init_method=None):
         super(GPT2ParallelSelfAttention, self).__init__()
         # Set output layer initialization if not provided.
         if output_layer_init_method is None:
@@ -83,8 +87,6 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
                                                      num_attention_heads)
         self.num_attention_heads_per_partition = divide(num_attention_heads,
                                                         world_size)
-        self.query_window = query_window
-        self.key_window_times = key_window_times
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
@@ -120,53 +122,49 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
         return tensor.permute(0, 2, 1, 3)
 
 
-    def forward(self, hidden_states, ltor_mask, pivot_idx=None, is_sparse=0, mem=None):
+    def forward(self, hidden_states, mask, sparse_config, mem=None):
         # hidden_states: [b, s, h]
         # ltor_mask: [1, 1, s, s]
 
         # Attention heads. [b, s, hp]
         query_length = hidden_states.size(1)
 
+        # compute the QKV projection for the current tokens; cached QKV from
+        # previous chunks (`mem`) is concatenated along the sequence dim below
+        mixed_raw_layer = self.query_key_value(hidden_states)
         if mem is None:
-            mixed_x_layer = self.query_key_value(hidden_states)
-            (mixed_query_layer,
-             mixed_key_layer,
-             mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-        else:
-            cat = torch.cat((mem, hidden_states), 1)
-            mixed_x_layer = self.query_key_value(cat)
-            (mixed_query_layer,
-             mixed_key_layer,
-             mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-            mixed_query_layer = mixed_query_layer[:, -query_length:]
-
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = self._transpose_for_scores(mixed_query_layer)
-        key_layer = self._transpose_for_scores(mixed_key_layer)
-        value_layer = self._transpose_for_scores(mixed_value_layer)
-
-        # =====================   Core Attention Code  ======================== #
-        if is_sparse == 1:
-            context_layer = sparse_attention(query_layer, key_layer, value_layer, pivot_idx, ltor_mask, self.query_window, self.key_window_times, self.attention_dropout)
-        elif is_sparse == 2:
-            context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx)
+            mixed_x_layer = mixed_raw_layer
         else:
-            context_layer = standard_attention(query_layer, key_layer, value_layer, ltor_mask, self.attention_dropout)
-        
-        # ===================== END OF BLOCK ======================= #
-
-        # [b, s, np, hn]
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + \
-                                  (self.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
+            mixed_x_layer = torch.cat((mem, mixed_raw_layer), dim=1)
+        (mixed_query_layer,
+            mixed_key_layer,
+            mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        if sparse_config.sparse_type in ['standard', 'torch_1d']:
+            # Reshape and transpose [b, np, s, hn]
+            query_layer = self._transpose_for_scores(mixed_query_layer)
+            key_layer = self._transpose_for_scores(mixed_key_layer)
+            value_layer = self._transpose_for_scores(mixed_value_layer)
+            if sparse_config.sparse_type == 'standard':
+                context_layer = standard_attention(query_layer, key_layer, value_layer, mask, self.attention_dropout)
+            else:
+                context_layer = sparse_attention_1d(query_layer, key_layer, value_layer, sparse_config.pivot_idx,
+                    mask, sparse_config.query_window, sparse_config.key_window_times, self.attention_dropout)
+                # inference: context_layer = sparse_attention_inference_1d(query_layer, key_layer, value_layer, pivot_idx)
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            new_context_layer_shape = context_layer.size()[:-2] + \
+                                    (self.hidden_size_per_partition,)
+            # [b, s, hp]
+            context_layer = context_layer.view(*new_context_layer_shape)
+
+        elif sparse_config.sparse_type == 'cuda_2d':
+            context_layer = sparse_attention_2d(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition,
+                 sparse_config.layout, mask, sparse_config.kernel_size, sparse_config.kernel_size2, attention_dropout=self.attention_dropout)
 
         # Output. [b, s, h]
         output = self.dense(context_layer)
         output = self.output_dropout(output)
 
-        return output
+        return output, mixed_raw_layer.detach()
 
 
 @torch.jit.script
@@ -178,13 +176,6 @@ def gelu_impl(x):
 def gelu(x): 
     return gelu_impl(x)
 
-@torch.jit.script
-def elu1_impl(x):
-     """OpenAI's gelu implementation."""
-     return torch.nn.functional.elu(x) + 1.
-
-def elu1(x):
-    return elu1_impl(x)
 
 class GPT2ParallelMLP(torch.nn.Module):
     """MLP for GPT2.
@@ -270,9 +261,8 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
                  layernorm_epsilon,
                  init_method,
                  output_layer_init_method=None,
-                 query_window=128,
-                 key_window_times=6,
-                 scale_normalization=True
+                 sandwich_ln=True,
+                 sparse_config=argparse.Namespace(sparse_type='standard')
                  ):
         super(GPT2ParallelTransformerLayer, self).__init__()
         # Set output layer initialization if not provided.
@@ -282,7 +272,6 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
 
-
         # Self attention.
         self.attention = GPT2ParallelSelfAttention(
             hidden_size,
@@ -290,15 +279,13 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
             attention_dropout_prob,
             output_dropout_prob,
             init_method,
-            output_layer_init_method=output_layer_init_method,
-            query_window=query_window,
-            key_window_times=key_window_times)
+            output_layer_init_method=output_layer_init_method)
 
         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(hidden_size,
                                                   eps=layernorm_epsilon)
-        self.scale_normalization = scale_normalization
-        if scale_normalization:
+        self.sandwich_ln = sandwich_ln
+        if sandwich_ln:
             self.third_layernorm = LayerNorm(hidden_size,
                                                     eps=layernorm_epsilon)
             self.fourth_layernorm = LayerNorm(hidden_size,
@@ -311,7 +298,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
             init_method,
             output_layer_init_method=output_layer_init_method)
 
-    def forward(self, hidden_states, ltor_mask, pivot_idx=None, is_sparse=0, mem=None):
+        self.sparse_config = sparse_config
+
+    def forward(self, hidden_states, ltor_mask, mem=None):
         # hidden_states: [b, s, h]
         # ltor_mask: [1, 1, s, s]
 
@@ -319,10 +308,10 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
         layernorm_output1 = self.input_layernorm(hidden_states)
         mem = self.input_layernorm(mem) if mem is not None else None
         # Self attention.
-        attention_output = self.attention(layernorm_output1, ltor_mask, pivot_idx, is_sparse, mem)
+        attention_output, qkv = self.attention(layernorm_output1, ltor_mask, self.sparse_config, mem)
 
         # Third LayerNorm
-        if self.scale_normalization:
+        if self.sandwich_ln:
             attention_output = self.third_layernorm(attention_output)
 
         # Residual connection.
@@ -333,13 +322,13 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
         mlp_output = self.mlp(layernorm_output)
 
         # Fourth LayerNorm
-        if self.scale_normalization:
+        if self.sandwich_ln:
             mlp_output = self.fourth_layernorm(mlp_output)
 
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output
+        return output, qkv
 
 def unscaled_init_method(sigma):
     """Init method based on N(0, sigma)."""
@@ -406,9 +395,8 @@ class GPT2ParallelTransformer(torch.nn.Module):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  use_scaled_init_for_output_weights=True,
-                 query_window=128,
-                 key_window_times=6,
-                 num_pivot=768
+                 sandwich_ln=True,
+                 sparse_config=argparse.Namespace(sparse_type='standard')
                  ):
         super(GPT2ParallelTransformer, self).__init__()
         # Store activation checkpoiting flag.
@@ -446,15 +434,10 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 layernorm_epsilon,
                 unscaled_init_method(init_method_std),
                 output_layer_init_method=output_layer_init_method,
-                query_window=query_window,
-                key_window_times=key_window_times,
-                scale_normalization=True
+                sandwich_ln=sandwich_ln,
+                sparse_config=sparse_config
                 )
 
-        self.query_window = query_window
-        self.key_window_times = key_window_times
-        self.num_pivot = num_pivot
-
         # Transformer layers.
         self.layers = torch.nn.ModuleList(
             [get_layer(layer_id) for layer_id in range(num_layers)])
@@ -466,14 +449,15 @@ class GPT2ParallelTransformer(torch.nn.Module):
             global get_cuda_rng_tracker, checkpoint
             get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
             checkpoint = deepspeed.checkpointing.checkpoint
-        self.rmask = None
+        self.sparse_config = sparse_config
 
-    def forward(self, hidden_states, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse=0, *mems):
+    def forward(self, hidden_states, position_ids, attention_mask, *mems):
 
         batch_size, query_length = hidden_states.size()[:2]
         memory_length = mems[0].size(1) if mems else 0
         key_length = query_length + memory_length
 
+        # legacy
         if isinstance(attention_mask, int) or attention_mask.numel() == 1:
             # if given a int "sep", means the seperation of full attention part and single direction part
             # attention mask is the beginning postion of B region, \in [0, query_len)
@@ -488,20 +472,6 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 return m
             attention_mask = build_mask_matrix(query_length, key_length, sep)
 
-        if is_sparse == 1 and (self.rmask is None):
-            w, times = self.query_window, self.key_window_times
-            g = key_length // w
-            tmp = torch.ones((g-times+1, w , w), device=hidden_states.device, dtype=hidden_states.dtype)
-            tmp = torch.tril(1 - torch.block_diag(*tmp))
-            self.rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom)  
-
-        if is_sparse == 2:
-            left_boundary = max(0, key_length - self.key_window_times * self.query_window)
-            window_idx = torch.arange(left_boundary, key_length, device=hidden_states.device, dtype=torch.long).expand(batch_size, -1)
-        elif is_sparse == 1:
-            left_boundary = key_length
-            num_pivot = self.num_pivot
-                
         # =====================   Image & Text Type Embedding   ======================== #
         # TODO: after testing, this is not useful.
         # extend_len = (key_length + 63) // 64
@@ -509,39 +479,21 @@ class GPT2ParallelTransformer(torch.nn.Module):
         #     img_indices_bool.unsqueeze(-1) * self.img_type_embeddings.expand(extend_len, 64, -1).reshape(extend_len * 64, -1)[memory_length: key_length]
         # ===================== END OF BLOCK ======================= #
 
-        if is_sparse: # 1 or 2                
-            # select out the real indices for sampling
-            img_indices = [img_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)]
-            txt_indices = [txt_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)]
-        
-        if is_sparse == 2:
-            ratio = self.num_pivot / self.max_sequence_length 
-            max_text_num = max(len(text_idx) for text_idx in txt_indices)
-            num_pivot = max_text_num + int((left_boundary - max_text_num) * ratio)
-
         position_embeddings = self.position_embeddings(position_ids)
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
-        if self.max_memory_length > 0:
-            mem_layers = [hidden_states.detach()]
-        else:
-            mem_layers = []
+        mem_layers = []
         def custom(start, end):
             def custom_forward(*inputs):
                 layers_ = self.layers[start:end]
-                x_, inputs = inputs[0], inputs[1:]
-                    
-                if is_sparse > 0:
-                    inputs, mems_ = inputs[:3], inputs[3:]
-                else:
-                    inputs, mems_ = inputs[:1], inputs[1:]
-
+                x_, mask, mems_ = inputs[0], inputs[1], inputs[2:]
+            
                 for i, layer in enumerate(layers_):
                     mem_i_ = mems_[i] if mems_ else None
-                    x_ = layer(x_, *inputs, mem=mem_i_)
+                    x_, qkv = layer(x_, mask, mem=mem_i_)
                     if self.max_memory_length > 0:
-                        mem_layers.append(x_.detach())
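+                        # cache this layer's raw QKV projection instead of the hidden states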
+                        mem_layers.append(qkv)
                 return x_
             return custom_forward
 
@@ -552,30 +504,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
             num_layers = len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
-                if is_sparse > 0:
-                    # =====================   Pivot Mask   ======================== #
-                    pivot_idx = torch.stack([
-                        torch.cat((
-                            text_idx,
-                            img_indices[i][
-                                torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
-                            ]
-                        ), dim=0)
-                        for i, text_idx in enumerate(txt_indices)
-                    ])
-                    if is_sparse == 1: # sparse training
-                        assert key_length == query_length
-                        b, s = batch_size, key_length
-                        pivot_attention_mask = self.rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, self.num_pivot))
-                        args = [hidden_states, pivot_attention_mask, pivot_idx, torch.tensor(is_sparse)]
-                    elif is_sparse == 2: # sparse inference
-                        pw_idx = torch.cat((pivot_idx, window_idx), dim=-1)
-                        args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)]
-                    else:
-                        raise NotImplementedError
-                    # ===================== END OF BLOCK ======================= #
-                else:
-                    args = [hidden_states, attention_mask_saved]
+                args = [hidden_states, attention_mask_saved]
 
                 if mems:
                     args += mems[l: l + chunk_length]
@@ -583,31 +512,18 @@ class GPT2ParallelTransformer(torch.nn.Module):
                 hidden_states = checkpoint(custom(l, l + chunk_length), *args)
                 l += chunk_length
         else:
-            assert is_sparse != 1, 'Please use checkpoint_activations for sparse attention training.'
+            assert self.sparse_config.sparse_type == 'standard'
             for i, layer in enumerate(self.layers):
-                if is_sparse == 0:
-                    args = [hidden_states, attention_mask_saved]
-                elif is_sparse == 2:
-                    pivot_idx = torch.stack([
-                        torch.cat((
-                            text_idx,
-                            img_indices[i][
-                                torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
-                            ]
-                        ), dim=0)
-                        for i, text_idx in enumerate(txt_indices)
-                    ])
-                    pw_idx = torch.cat((pivot_idx, window_idx), dim=-1)
-                    args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)]
+                args = [hidden_states, attention_mask_saved]
 
                 mem_i = mems[i] if mems else None
-                hidden_states = layer(*args, mem=mem_i)
+                hidden_states, qkv = layer(*args, mem=mem_i)
                 if self.max_memory_length > 0:
-                    mem_layers.append(hidden_states.detach())
+                    mem_layers.append(qkv) 
 
         # Final layer norm.
         output = self.final_layernorm(hidden_states)
-        if self.max_memory_length > 0:
+        if self.max_memory_length > 0: # TODO cache
             mem_layers = self.update_mems(mem_layers, mems)
 
         return (output, *mem_layers)
@@ -672,7 +588,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte
     context_layer = torch.matmul(attention_probs, value_layer)
     return context_layer
 
-def sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
+def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None):
     ''' Sparse Attention
     Args:
         q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window
@@ -724,98 +640,109 @@ def sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, query_window=128,
 
     return context_layer
 
-def sparse_attention_inference(q, k, v, pivot_and_window_idx, **kwargs):
-    '''the inference process of sparse attention.
-    The Qs are in the same block, but seq_len mod window size might != 0.
-
-    The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs.
-
-    '''
-    b, n_head, sq, hn = q.shape
-    sk = k.shape[2]
-    _b, n_piv = pivot_and_window_idx.shape
-
-    pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
-    pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy)
-    attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2))
-    if sq > 1:
-        query_part_scores = attention_scores[:, :, -sq:, -sq:]
-        m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000.
-        m.triu_(diagonal=1)
-        query_part_scores += m
-
-    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
-
-    context_layer = torch.matmul(attention_probs, pivot_v) 
-    return context_layer
-
-
-def test_sparse_attention():       
+# def sparse_attention_inference_1d(q, k, v, pivot_and_window_idx, **kwargs):
+#     '''the inference process of sparse attention.
+#     The Qs are in the same block, but seq_len mod window size might != 0.
 
-    s, w, times = 4096 + 128, 128, 2
-    num_pivot = 768
-    b = 2
-    g = s // w
+#     The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs.
 
-    q, k, v = raw = torch.rand(3, b, 16, s, 64, dtype=torch.float, device='cuda', requires_grad=True)
-    q1, k1, v1 = raw1 = torch.tensor(raw.cpu().detach().numpy(), dtype=torch.float, device='cuda', requires_grad=True)
-    txt_indices = [torch.arange(0, 128, dtype=torch.long, device='cuda'), torch.arange(0, 22, dtype=torch.long, device='cuda')]
-    img_indices = [torch.arange(128, s, dtype=torch.long, device='cuda'), torch.arange(22, s, dtype=torch.long, device='cuda')]
+#     '''
+#     b, n_head, sq, hn = q.shape
+#     sk = k.shape[2]
+#     _b, n_piv = pivot_and_window_idx.shape
 
-    pivot_idx = torch.stack([
-        torch.cat((
-            text_idx,
-            img_indices[i][
-                torch.tensor(random.sample(range(len(img_indices[i]) - times*w),  k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
-            ]
-        ), dim=0)
-        for i, text_idx in enumerate(txt_indices)
-    ]) # -times * w to verify inference
+#     pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn)
+#     pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy)
+#     attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2))
+#     if sq > 1:
+#         query_part_scores = attention_scores[:, :, -sq:, -sq:]
+#         m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000.
+#         m.triu_(diagonal=1)
+#         query_part_scores += m
 
-    tmp = torch.ones((g-times+1, w , w), device='cuda', dtype=torch.long)
-    tmp = torch.tril(1 - torch.block_diag(*tmp))
-    rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom)
+#     attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
 
-    pivot_attention_mask = rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, num_pivot))
+#     context_layer = torch.matmul(attention_probs, pivot_v) 
+#     return context_layer
 
-    real_mask = torch.ones((b, s, s), device='cuda', dtype=torch.long) - rmask
-    for i in range(b):
-        real_mask[i][:, pivot_idx[i]] = 1
-        real_mask[i].tril_()
+def transpose_and_split(x, layout, n_head):
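+    '''Split a projected [b, s, hn] tensor into per-head text and image parts.
+    Returns:
+        x:      [b*n_head, hn_per_head, s]           (full sequence)
+        x_text: [b*n_head, hn_per_head, layout[0]]   (text tokens)
+        x0:     [b*n_head, hn_per_head, h0, w0]      (level-0 image grid)
+        x1:     [b*n_head, hn_per_head, h1, w1]      (level-1 image grid)
+    '''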
+    x = x.transpose(1, 2)
+    x = x.reshape(x.shape[0] * n_head, x.shape[1] // n_head, x.shape[2])
+    x_text = x[..., :layout[0]]
+    x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), -1).contiguous()
+    x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), -1).contiguous()
+    return x, x_text, x0, x1
 
-    # test inference
-
-    # q_part = q[..., -1:, :]
-    # r0 = standard_attention(q, k, v, real_mask)
-    # r0 = r0[..., -1:, :]
-    # pw_idx = torch.cat((pivot_idx, torch.arange(s-times*w, s, device='cuda', dtype=torch.long).expand(b, -1)), dim=-1)
-
-    # r1 = sparse_attention_inference(q_part, k, v, pw_idx)
-    # print(( (r1-r0).abs() / (r1.abs()+r0.abs())).max())
-
-    import time
+def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs):
+    '''
+    q, k, v: [batch_size, 64+1024+4096, hidden_size]
+    n_head: int
+    layout: [endoftext/startofpad, startof0, startof1, endofall]
+    attention_mask_text2d: [batch_size, sq_len, endoftext]
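+        e.g. [64, 64, 64 + 32*32, 64 + 32*32 + 64*64] for a 64-token text
+        prefix (no padding), a 32x32 level-0 grid and a 64x64 level-1 grid
+
+    Pattern: every token attends to the text keys (subject to
+    attention_mask_text2d); level-0 tokens also attend causally within a
+    kernel_size x kernel_size window of level 0; level-1 tokens attend causally
+    within a kernel_size window of level 1 and to a kernel_size2 x kernel_size2
+    window around their downscaled position in level 0.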
+    '''
+    from .local_attention_function import f_similar, f_weighting
+    b, sq_len, hn = q.shape
+    alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) # side-length ratio of level 1 to level 0 (currently unused)
 
-    r0 = standard_attention(q1, k1, v1, real_mask)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    r1 = standard_attention(q1, k1, v1, real_mask)
-    torch.cuda.synchronize()
-    t1 = time.time()
-    r2 = sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, w, times)
-    torch.cuda.synchronize()
-    t2 = time.time()
-    print('times: standard ', t1-t0, ' sparse ', t2-t1)
+    q = q / math.sqrt(hn // n_head) # normalization
 
-    print(( (r1-r2).abs() / (r1.abs()+r2.abs())).max())
+    q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext]
+    k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head)
+    v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head)
 
-    raw.retain_grad()
-    l2 = r2.mean()
-    l1 = r1.mean()
-    l2.backward()
-    l1.backward()
+    # all to text
+    scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d)
+    scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0])
+    # 0 to 0
+    scores_0_to_0 = f_similar(q0, k0, kernel_size, kernel_size, True)
+    # 1 to 1
+    scores_1_to_1 = f_similar(q1, k1, kernel_size, kernel_size, True)    
+    # 1 to 0
+    scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2]
+    # softmax
+    probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text]
+
+    scores_0 = torch.cat(
+        (scores_all_to_text[:, layout[1]:layout[2]], 
+        scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), 
+        dim=-1)
+    probs_0 = F.softmax(scores_0, dim=-1) # 
+    scores_1 = torch.cat(
+        (scores_all_to_text[:, layout[2]:layout[3]],
+         scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]),
+         scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])),
+         dim=-1)
+    probs_1 = F.softmax(scores_1, dim=-1)
 
-    g1 = raw1.grad
-    g2 = raw.grad
-    print( (g1-g2).abs().max())
+    if attention_dropout is not None:
+        with get_cuda_rng_tracker().fork():
+            probs_0 = attention_dropout(probs_0)
+            probs_1 = attention_dropout(probs_1)
+    # weighting
+    pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype)
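+    # text-key attention probabilities for every query row: real rows for the
+    # text segment, zero rows for the pad segment, then the text-key slices of
+    # the level-0 and level-1 probabilities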
+    probs_all_to_text = torch.cat((
+        probs_text,
+        pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]),
+        probs_0[:, :, :layout[0]],
+        probs_1[:, :, :layout[0]]
+    ), dim=1)
+
+    context_all_to_text = torch.einsum('bhij,bhcj->bihc', 
+        probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), 
+        v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn)
+    
+    context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size, kernel_size, True)
+
+    context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False)
+
+    context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size, kernel_size, True)
+    
+    context_all_to_01 = torch.cat(
+        (
+            pad.expand(b*n_head, hn//n_head, layout[1]),
+            context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]),
+            (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2])
+        ), dim=-1).view(b, hn, -1).transpose(1, 2)
+    return context_all_to_text + context_all_to_01 
 
-    # import pdb; pdb.set_trace()
diff --git a/mpu/utils.py b/mpu/utils.py
index adee41c..d9b1a8d 100755
--- a/mpu/utils.py
+++ b/mpu/utils.py
@@ -15,6 +15,7 @@
 
 
 import torch
+import math
 
 
 def ensure_divisibility(numerator, denominator):
@@ -78,3 +79,6 @@ def split_out_sums(x, BLOCK_SIZE=32, all_ret=False):
         return oris.reshape(b, -1, *rs), sums.reshape(b, -1, *rs)
     else: 
         return sums.reshape(b, -1, *rs)
+
+def sqrt(x):
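+    # integer square root for perfect squares (e.g. the side of a square token grid)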
+    return int(math.sqrt(x) + 1e-4)
\ No newline at end of file
diff --git a/test_sparse_attention.py b/test_sparse_attention.py
new file mode 100644
index 0000000..a8c22f8
--- /dev/null
+++ b/test_sparse_attention.py
@@ -0,0 +1,169 @@
+import math
+import random
+from tqdm import tqdm
+
+import torch
+import numpy as np
+from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d
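+
+# Consistency tests: compare the sparse attention implementations against masked
+# standard_attention on random inputs (both forward outputs and input gradients).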
+
+def test_sparse_attention_1d():       
+    s, w, times = 4096 + 128, 128, 2
+    num_pivot = 768
+    b = 2
+    g = s // w
+
+    q, k, v = raw = torch.rand(3, b, 16, s, 64, dtype=torch.float, device='cuda', requires_grad=True)
+    q1, k1, v1 = raw1 = torch.tensor(raw.cpu().detach().numpy(), dtype=torch.float, device='cuda', requires_grad=True)
+    txt_indices = [torch.arange(0, 128, dtype=torch.long, device='cuda'), torch.arange(0, 22, dtype=torch.long, device='cuda')]
+    img_indices = [torch.arange(128, s, dtype=torch.long, device='cuda'), torch.arange(22, s, dtype=torch.long, device='cuda')]
+
+    pivot_idx = torch.stack([
+        torch.cat((
+            text_idx,
+            img_indices[i][
+                torch.tensor(random.sample(range(len(img_indices[i]) - times*w),  k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
+            ]
+        ), dim=0)
+        for i, text_idx in enumerate(txt_indices)
+    ]) # -times * w to verify inference
+
+    tmp = torch.ones((g-times+1, w , w), device='cuda', dtype=torch.long)
+    tmp = torch.tril(1 - torch.block_diag(*tmp))
+    rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom)
+
+    pivot_attention_mask = rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, num_pivot))
+
+    real_mask = torch.ones((b, s, s), device='cuda', dtype=torch.long) - rmask
+    for i in range(b):
+        real_mask[i][:, pivot_idx[i]] = 1
+        real_mask[i].tril_()
+
+    # test inference
+
+    # q_part = q[..., -1:, :]
+    # r0 = standard_attention(q, k, v, real_mask)
+    # r0 = r0[..., -1:, :]
+    # pw_idx = torch.cat((pivot_idx, torch.arange(s-times*w, s, device='cuda', dtype=torch.long).expand(b, -1)), dim=-1)
+
+    # r1 = sparse_attention_inference(q_part, k, v, pw_idx)
+    # print(( (r1-r0).abs() / (r1.abs()+r0.abs())).max())
+
+    import time
+
+    r0 = standard_attention(q1, k1, v1, real_mask)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    r1 = standard_attention(q1, k1, v1, real_mask)
+    torch.cuda.synchronize()
+    t1 = time.time()
+    r2 = sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, w, times)
+    torch.cuda.synchronize()
+    t2 = time.time()
+    print('times: standard ', t1-t0, ' sparse ', t2-t1)
+
+    print(( (r1-r2).abs() / (r1.abs()+r2.abs())).max())
+
+    raw.retain_grad()
+    l2 = r2.mean()
+    l1 = r1.mean()
+    l2.backward()
+    l1.backward()
+
+    g1 = raw1.grad
+    g2 = raw.grad
+    print( (g1-g2).abs().max())
+
+    # import pdb; pdb.set_trace()
+
+def test_sparse_attention_2d():
+    dtype = torch.float
+    device = 'cuda'
+    b, n_head, hn = 1, 40, 2560
+    h = w = 32
+    layout = [10, 10, 10+h*w, 10+h*w*5]
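+    # 10 text tokens, no padding, a 32x32 level-0 grid and a 64x64 level-1 grid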
+    k1 = 9
+    k2 = 7
+
+    qkv = torch.rand(3, b, layout[-1], hn, dtype=dtype, device=device)
+    qkv2 = qkv.clone()
+    qkv.requires_grad_()
+    qkv2.requires_grad_()
+    mask = torch.zeros(b, layout[-1], layout[-1], dtype=dtype, device=device)
+    
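+    # build a dense mask that reproduces the cuda_2d sparse pattern, so the
+    # result can be compared against masked standard_attention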
+    m = mask[0]
+    for i in range(layout[1]):
+        m[i, :i+1] = 1
+    m[layout[1]:, :layout[0]] = 1
+    for i in tqdm(range(layout[1], layout[2])):
+        x = (i - layout[1]) // w
+        y = (i - layout[1]) % w
+        lx = max(0, x - k1 // 2)
+        ly = max(0, y - k1 // 2)
+        rx = min(h-1, x + k1 // 2)
+        ry = min(w-1, y + k1 // 2)
+        m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1
+        m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1
+    for i in tqdm(range(layout[2], layout[3])):
+        x = (i - layout[2]) // (2*w)
+        y = (i - layout[2]) % (2*w)
+        lx = max(0, x - k1 // 2)
+        ly = max(0, y - k1 // 2)
+        rx = min(2*h-1, x + k1 // 2)
+        ry = min(2*w-1, y + k1 // 2)
+        m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1
+        m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1
+
+        x = x // 2
+        y = y // 2
+        lx = max(0, x - k2 // 2)
+        ly = max(0, y - k2 // 2)
+        rx = min(h-1, x + k2 // 2)
+        ry = min(w-1, y + k2 // 2)
+        m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1
+    
+    # mask[1:] = mask[0]
+    # mask[1][layout[1]:, layout[0]-1] = 0
+
+    print('finish making mask...')
+
+    import time
+    torch.cuda.synchronize()
+    t0 = time.time()
+    qkv_tmp = qkv.view(3, b, layout[-1], n_head, hn//n_head).permute(0, 1, 3, 2, 4).contiguous()
+    r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[3], hn)
+    
+    torch.cuda.synchronize()
+    t1 = time.time()
+    r2 = sparse_attention_2d(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2)
+    torch.cuda.synchronize()
+    t2 = time.time()
+    print('times: standard ', t1-t0, ' sparse ', t2-t1)
+    print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max())
+    print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max())
+    qkv.retain_grad()
+    l2 = r2[:,layout[1]:].mean()
+    l1 = r1[:,layout[1]:].mean()
+    l2.backward()
+    l1.backward()
+
+    g1 = qkv.grad
+    g2 = qkv2.grad
+    print( (g1-g2).abs().max())
+    # import pdb;pdb.set_trace()
+    
+
+def seed_torch(seed=1029):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.enabled = False
+
+if __name__ == '__main__':
+    seed_torch()
+    torch.backends.cuda.matmul.allow_tf32 = False
+    test_sparse_attention_2d()
+    
\ No newline at end of file
-- 
GitLab