From 34990eea7b857973a143cdebd31e48470cacf79b Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Wed, 18 Aug 2021 18:21:39 +0000 Subject: [PATCH] sparse 2d and cache qkv --- mpu/local_attention_function.py | 149 ++++++++++++ mpu/sparse_transformer.py | 401 +++++++++++++------------------- mpu/utils.py | 4 + test_sparse_attention.py | 169 ++++++++++++++ 4 files changed, 486 insertions(+), 237 deletions(-) create mode 100644 mpu/local_attention_function.py create mode 100644 test_sparse_attention.py diff --git a/mpu/local_attention_function.py b/mpu/local_attention_function.py new file mode 100644 index 0000000..5ba073c --- /dev/null +++ b/mpu/local_attention_function.py @@ -0,0 +1,149 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from localAttention import (similar_forward, + similar_backward, + weighting_forward, + weighting_backward_ori, + weighting_backward_weight) + +__all__ = ['f_similar', 'f_weighting', 'LocalAttention', 'TorchLocalAttention'] + + +class similarFunction(Function): + @staticmethod + def forward(ctx, x_ori, x_loc, kH, kW, casual_mask=False): + ctx.save_for_backward(x_ori, x_loc) + ctx.kHW = (kH, kW) + ctx.casual_mask = casual_mask + output = similar_forward(x_ori, x_loc, kH, kW, casual_mask) + + return output + + @staticmethod + #@once_differentiable + def backward(ctx, grad_outputs): + x_ori, x_loc = ctx.saved_tensors + kH, kW = ctx.kHW + casual_mask = ctx.casual_mask + grad_ori = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, True, casual_mask) + grad_loc = similar_backward(x_ori, x_loc, grad_outputs, kH, kW, False, casual_mask) + + return grad_ori, grad_loc, None, None, None + + +class weightingFunction(Function): + @staticmethod + def forward(ctx, x_ori, x_weight, kH, kW, casual_mask=False): + ctx.save_for_backward(x_ori, x_weight) + ctx.kHW = (kH, kW) + ctx.casual_mask = casual_mask + output = weighting_forward(x_ori, x_weight, kH, kW, casual_mask) + + return output + + @staticmethod + #@once_differentiable + def backward(ctx, grad_outputs): + x_ori, x_weight = ctx.saved_tensors + kH, kW = ctx.kHW + casual_mask = ctx.casual_mask + grad_ori = weighting_backward_ori(x_ori, x_weight, grad_outputs, kH, kW, casual_mask) + grad_weight = weighting_backward_weight(x_ori, x_weight, grad_outputs, kH, kW, casual_mask) + + return grad_ori, grad_weight, None, None, None + + +f_similar = similarFunction.apply +f_weighting = weightingFunction.apply + + +class LocalAttention(nn.Module): + def __init__(self, inp_channels, out_channels, kH, kW): + super(LocalAttention, self).__init__() + self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False) + self.kH = kH + self.kW = kW + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + x3 = self.conv3(x) + + weight = f_similar(x1, x2, self.kH, self.kW) + weight = F.softmax(weight, -1) + out = f_weighting(x3, weight, self.kH, self.kW) + + return out + + +class TorchLocalAttention(nn.Module): + def __init__(self, inp_channels, out_channels, kH, kW): + super(TorchLocalAttention, self).__init__() + self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(inp_channels, out_channels, 
kernel_size=1, bias=False) + self.kH = kH + self.kW = kW + + @staticmethod + def f_similar(x_theta, x_phi, kh, kw, casual_mask=False): + n, c, h, w = x_theta.size() # (N, inter_channels, H, W) + pad = (kh // 2, kw // 2) + x_theta = x_theta.permute(0, 2, 3, 1).contiguous() + x_theta = x_theta.view(n * h * w, 1, c) + + x_phi = F.unfold(x_phi, kernel_size=(kh, kw), stride=1, padding=pad) + x_phi = x_phi.contiguous().view(n, c, kh * kw, h * w) + x_phi = x_phi.permute(0, 3, 1, 2).contiguous() + x_phi = x_phi.view(n * h * w, c, kh * kw) + out = x_theta @ x_phi + out = out.view(n, h, w, kh * kw) + if casual_mask: + out = out[..., :kh * kw // 2 + 1] + return out + + @staticmethod + def f_weighting(x_theta, x_phi, kh, kw, casual_mask=False): + n, c, h, w = x_theta.size() # (N, inter_channels, H, W) + pad = (kh // 2, kw // 2) + x_theta = F.unfold(x_theta, kernel_size=(kh, kw), stride=1, padding=pad) + x_theta = x_theta.permute(0, 2, 1).contiguous() + x_theta = x_theta.view(n * h * w, c, kh * kw) + + if casual_mask: + x_theta = x_theta[..., :kh * kw // 2 + 1] + x_phi = x_phi.view(n * h * w, kh * kw // 2 + 1, 1) + else: + x_phi = x_phi.view(n * h * w, kh * kw, 1) + + out = torch.matmul(x_theta, x_phi) + out = out.squeeze(-1) + out = out.view(n, h, w, c) + out = out.permute(0, 3, 1, 2).contiguous() + + return out + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + x3 = self.conv3(x) + + weight = self.f_similar(x1, x2, self.kH, self.kW) + weight = F.softmax(weight, -1) + out = self.f_weighting(x3, weight, self.kH, self.kW) + + return out + + +if __name__ == '__main__': + b, c, h, w = 8, 3, 32, 32 + kH, kW = 5, 5 + x = torch.rand(b, c, h, w).cuda() + m = LocalAttention(c, c, kH, kW) + m.cuda() + y = m(x) \ No newline at end of file diff --git a/mpu/sparse_transformer.py b/mpu/sparse_transformer.py index 7fb3e48..5635d66 100755 --- a/mpu/sparse_transformer.py +++ b/mpu/sparse_transformer.py @@ -17,10 +17,12 @@ import math import random +import argparse import torch import torch.nn.init as init -from apex.normalization.fused_layer_norm import FusedLayerNorm #as LayerNorm +import torch.nn.functional as F +from apex.normalization.fused_layer_norm import FusedLayerNorm from .initialize import get_model_parallel_world_size from .layers import ColumnParallelLinear @@ -32,15 +34,17 @@ import deepspeed from .random import checkpoint from .random import get_cuda_rng_tracker -from .utils import divide +from .utils import divide, sqrt from .utils import split_tensor_along_last_dim import torch.distributed as dist - class LayerNorm(FusedLayerNorm): - def __init__(self, *args, **kwargs): + def __init__(self, pb_relax=False, *args, **kwargs): super().__init__(*args, **kwargs) + self.pb_relax = pb_relax def forward(self, x): + if not self.pb_relax: + return super().forward(x) return super().forward(x / (x.abs().max().detach()/8)) class GPT2ParallelSelfAttention(torch.nn.Module): @@ -71,7 +75,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module): """ def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, - init_method, output_layer_init_method=None, query_window=128, key_window_times=6): + init_method, output_layer_init_method=None): super(GPT2ParallelSelfAttention, self).__init__() # Set output layer initialization if not provided. 
if output_layer_init_method is None: @@ -83,8 +87,6 @@ class GPT2ParallelSelfAttention(torch.nn.Module): num_attention_heads) self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) - self.query_window = query_window - self.key_window_times = key_window_times # Strided linear layer. self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size, @@ -120,53 +122,49 @@ class GPT2ParallelSelfAttention(torch.nn.Module): return tensor.permute(0, 2, 1, 3) - def forward(self, hidden_states, ltor_mask, pivot_idx=None, is_sparse=0, mem=None): + def forward(self, hidden_states, mask, sparse_config, mem=None): # hidden_states: [b, s, h] # ltor_mask: [1, 1, s, s] # Attention heads. [b, s, hp] query_length = hidden_states.size(1) + # if mem is None: + mixed_raw_layer = self.query_key_value(hidden_states) if mem is None: - mixed_x_layer = self.query_key_value(hidden_states) - (mixed_query_layer, - mixed_key_layer, - mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - else: - cat = torch.cat((mem, hidden_states), 1) - mixed_x_layer = self.query_key_value(cat) - (mixed_query_layer, - mixed_key_layer, - mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - mixed_query_layer = mixed_query_layer[:, -query_length:] - - # Reshape and transpose [b, np, s, hn] - query_layer = self._transpose_for_scores(mixed_query_layer) - key_layer = self._transpose_for_scores(mixed_key_layer) - value_layer = self._transpose_for_scores(mixed_value_layer) - - # ===================== Core Attention Code ======================== # - if is_sparse == 1: - context_layer = sparse_attention(query_layer, key_layer, value_layer, pivot_idx, ltor_mask, self.query_window, self.key_window_times, self.attention_dropout) - elif is_sparse == 2: - context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx) + mixed_x_layer = mixed_raw_layer else: - context_layer = standard_attention(query_layer, key_layer, value_layer, ltor_mask, self.attention_dropout) - - # ===================== END OF BLOCK ======================= # - - # [b, s, np, hn] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) - # [b, s, hp] - context_layer = context_layer.view(*new_context_layer_shape) + mixed_x_layer = torch.cat((mem, mixed_raw_layer), dim=1) + (mixed_query_layer, + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + if sparse_config.sparse_type in ['standard', 'torch_1d']: + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + if sparse_config.sparse_type == 'standard': + context_layer = standard_attention(query_layer, key_layer, value_layer, mask, self.attention_dropout) + else: + context_layer = sparse_attention(query_layer, key_layer, value_layer, sparse_config.pivot_idx, + mask, sparse_config.query_window, sparse_config.key_window_times, self.attention_dropout) + # inference: context_layer = sparse_attention_inference(query_layer, key_layer, value_layer, pivot_idx) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + elif sparse_config.sparse_type == 'cuda_2d': + 
context_layer = sparse_attention_2d(mixed_query_layer, mixed_key_layer, mixed_value_layer, self.num_attention_heads_per_partition, + sparse_config.layout, mask, sparse_config.kernel_size, sparse_config.kernel_size2, attention_dropout=self.attention_dropout) # Output. [b, s, h] output = self.dense(context_layer) output = self.output_dropout(output) - return output + return output, mixed_raw_layer.detach() @torch.jit.script @@ -178,13 +176,6 @@ def gelu_impl(x): def gelu(x): return gelu_impl(x) -@torch.jit.script -def elu1_impl(x): - """OpenAI's gelu implementation.""" - return torch.nn.functional.elu(x) + 1. - -def elu1(x): - return elu1_impl(x) class GPT2ParallelMLP(torch.nn.Module): """MLP for GPT2. @@ -270,9 +261,8 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): layernorm_epsilon, init_method, output_layer_init_method=None, - query_window=128, - key_window_times=6, - scale_normalization=True + sandwich_ln=True, + sparse_config=argparse.Namespace(sparse_type='standard') ): super(GPT2ParallelTransformerLayer, self).__init__() # Set output layer initialization if not provided. @@ -282,7 +272,6 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): # Layernorm on the input data. self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) - # Self attention. self.attention = GPT2ParallelSelfAttention( hidden_size, @@ -290,15 +279,13 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): attention_dropout_prob, output_dropout_prob, init_method, - output_layer_init_method=output_layer_init_method, - query_window=query_window, - key_window_times=key_window_times) + output_layer_init_method=output_layer_init_method) # Layernorm on the input data. self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) - self.scale_normalization = scale_normalization - if scale_normalization: + self.sandwich_ln = sandwich_ln + if sandwich_ln: self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.fourth_layernorm = LayerNorm(hidden_size, @@ -311,7 +298,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): init_method, output_layer_init_method=output_layer_init_method) - def forward(self, hidden_states, ltor_mask, pivot_idx=None, is_sparse=0, mem=None): + self.sparse_config = sparse_config + + def forward(self, hidden_states, ltor_mask, mem=None): # hidden_states: [b, s, h] # ltor_mask: [1, 1, s, s] @@ -319,10 +308,10 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): layernorm_output1 = self.input_layernorm(hidden_states) mem = self.input_layernorm(mem) if mem is not None else None # Self attention. - attention_output = self.attention(layernorm_output1, ltor_mask, pivot_idx, is_sparse, mem) + attention_output, qkv = self.attention(layernorm_output1, ltor_mask, self.sparse_config, mem) # Third LayerNorm - if self.scale_normalization: + if self.sandwich_ln: attention_output = self.third_layernorm(attention_output) # Residual connection. @@ -333,13 +322,13 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): mlp_output = self.mlp(layernorm_output) # Fourth LayerNorm - if self.scale_normalization: + if self.sandwich_ln: mlp_output = self.fourth_layernorm(mlp_output) # Second residual connection. 
output = layernorm_input + mlp_output - return output + return output, qkv def unscaled_init_method(sigma): """Init method based on N(0, sigma).""" @@ -406,9 +395,8 @@ class GPT2ParallelTransformer(torch.nn.Module): layernorm_epsilon=1.0e-5, init_method_std=0.02, use_scaled_init_for_output_weights=True, - query_window=128, - key_window_times=6, - num_pivot=768 + sandwich_ln=True, + sparse_config=argparse.Namespace(sparse_type='standard') ): super(GPT2ParallelTransformer, self).__init__() # Store activation checkpoiting flag. @@ -446,15 +434,10 @@ class GPT2ParallelTransformer(torch.nn.Module): layernorm_epsilon, unscaled_init_method(init_method_std), output_layer_init_method=output_layer_init_method, - query_window=query_window, - key_window_times=key_window_times, - scale_normalization=True + sandwich_ln=sandwich_ln, + sparse_config=sparse_config ) - self.query_window = query_window - self.key_window_times = key_window_times - self.num_pivot = num_pivot - # Transformer layers. self.layers = torch.nn.ModuleList( [get_layer(layer_id) for layer_id in range(num_layers)]) @@ -466,14 +449,15 @@ class GPT2ParallelTransformer(torch.nn.Module): global get_cuda_rng_tracker, checkpoint get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint - self.rmask = None + self.sparse_config = sparse_config - def forward(self, hidden_states, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse=0, *mems): + def forward(self, hidden_states, position_ids, attention_mask, *mems): batch_size, query_length = hidden_states.size()[:2] memory_length = mems[0].size(1) if mems else 0 key_length = query_length + memory_length + # legacy if isinstance(attention_mask, int) or attention_mask.numel() == 1: # if given a int "sep", means the seperation of full attention part and single direction part # attention mask is the beginning postion of B region, \in [0, query_len) @@ -488,20 +472,6 @@ class GPT2ParallelTransformer(torch.nn.Module): return m attention_mask = build_mask_matrix(query_length, key_length, sep) - if is_sparse == 1 and (self.rmask is None): - w, times = self.query_window, self.key_window_times - g = key_length // w - tmp = torch.ones((g-times+1, w , w), device=hidden_states.device, dtype=hidden_states.dtype) - tmp = torch.tril(1 - torch.block_diag(*tmp)) - self.rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom) - - if is_sparse == 2: - left_boundary = max(0, key_length - self.key_window_times * self.query_window) - window_idx = torch.arange(left_boundary, key_length, device=hidden_states.device, dtype=torch.long).expand(batch_size, -1) - elif is_sparse == 1: - left_boundary = key_length - num_pivot = self.num_pivot - # ===================== Image & Text Type Embedding ======================== # # TODO: after testing, this is not useful. 
# extend_len = (key_length + 63) // 64 @@ -509,39 +479,21 @@ class GPT2ParallelTransformer(torch.nn.Module): # img_indices_bool.unsqueeze(-1) * self.img_type_embeddings.expand(extend_len, 64, -1).reshape(extend_len * 64, -1)[memory_length: key_length] # ===================== END OF BLOCK ======================= # - if is_sparse: # 1 or 2 - # select out the real indices for sampling - img_indices = [img_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)] - txt_indices = [txt_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)] - - if is_sparse == 2: - ratio = self.num_pivot / self.max_sequence_length - max_text_num = max(len(text_idx) for text_idx in txt_indices) - num_pivot = max_text_num + int((left_boundary - max_text_num) * ratio) - position_embeddings = self.position_embeddings(position_ids) hidden_states = hidden_states + position_embeddings hidden_states = self.embedding_dropout(hidden_states) - if self.max_memory_length > 0: - mem_layers = [hidden_states.detach()] - else: - mem_layers = [] + mem_layers = [] def custom(start, end): def custom_forward(*inputs): layers_ = self.layers[start:end] - x_, inputs = inputs[0], inputs[1:] - - if is_sparse > 0: - inputs, mems_ = inputs[:3], inputs[3:] - else: - inputs, mems_ = inputs[:1], inputs[1:] - + x_, mask, mems_ = inputs[0], inputs[1], inputs[1:] + for i, layer in enumerate(layers_): mem_i_ = mems_[i] if mems_ else None - x_ = layer(x_, *inputs, mem=mem_i_) + x_, qkv = layer(x_, mask, mem=mem_i_) if self.max_memory_length > 0: - mem_layers.append(x_.detach()) + mem_layers.append(qkv) return x_ return custom_forward @@ -552,30 +504,7 @@ class GPT2ParallelTransformer(torch.nn.Module): num_layers = len(self.layers) chunk_length = self.checkpoint_num_layers while l < num_layers: - if is_sparse > 0: - # ===================== Pivot Mask ======================== # - pivot_idx = torch.stack([ - torch.cat(( - text_idx, - img_indices[i][ - torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device) - ] - ), dim=0) - for i, text_idx in enumerate(txt_indices) - ]) - if is_sparse == 1: # sparse training - assert key_length == query_length - b, s = batch_size, key_length - pivot_attention_mask = self.rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, self.num_pivot)) - args = [hidden_states, pivot_attention_mask, pivot_idx, torch.tensor(is_sparse)] - elif is_sparse == 2: # sparse inference - pw_idx = torch.cat((pivot_idx, window_idx), dim=-1) - args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)] - else: - raise NotImplementedError - # ===================== END OF BLOCK ======================= # - else: - args = [hidden_states, attention_mask_saved] + args = [hidden_states, attention_mask_saved] if mems: args += mems[l: l + chunk_length] @@ -583,31 +512,18 @@ class GPT2ParallelTransformer(torch.nn.Module): hidden_states = checkpoint(custom(l, l + chunk_length), *args) l += chunk_length else: - assert is_sparse != 1, 'Please use checkpoint_activations for sparse attention training.' 
+ assert self.sparse_config.sparse_type == 'standard' for i, layer in enumerate(self.layers): - if is_sparse == 0: - args = [hidden_states, attention_mask_saved] - elif is_sparse == 2: - pivot_idx = torch.stack([ - torch.cat(( - text_idx, - img_indices[i][ - torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device) - ] - ), dim=0) - for i, text_idx in enumerate(txt_indices) - ]) - pw_idx = torch.cat((pivot_idx, window_idx), dim=-1) - args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)] + args = [hidden_states, attention_mask_saved] mem_i = mems[i] if mems else None - hidden_states = layer(*args, mem=mem_i) + hidden_states, qkv = layer(*args, mem=mem_i) if self.max_memory_length > 0: - mem_layers.append(hidden_states.detach()) + mem_layers.append(qkv) # Final layer norm. output = self.final_layernorm(hidden_states) - if self.max_memory_length > 0: + if self.max_memory_length > 0: # TODO cache mem_layers = self.update_mems(mem_layers, mems) return (output, *mem_layers) @@ -672,7 +588,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, atte context_layer = torch.matmul(attention_probs, value_layer) return context_layer -def sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None): +def sparse_attention_1d(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, key_window_times=6, attention_dropout=None): ''' Sparse Attention Args: q, k, v: inputs, [b, num_heads, s, hn], k is padded to n * query_window @@ -724,98 +640,109 @@ def sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, query_window=128, return context_layer -def sparse_attention_inference(q, k, v, pivot_and_window_idx, **kwargs): - '''the inference process of sparse attention. - The Qs are in the same block, but seq_len mod window size might != 0. - - The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs. - - ''' - b, n_head, sq, hn = q.shape - sk = k.shape[2] - _b, n_piv = pivot_and_window_idx.shape - - pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn) - pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy) - attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2)) - if sq > 1: - query_part_scores = attention_scores[:, :, -sq:, -sq:] - m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000. - m.triu_(diagonal=1) - query_part_scores += m - - attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) - - context_layer = torch.matmul(attention_probs, pivot_v) - return context_layer - - -def test_sparse_attention(): +# def sparse_attention_inference_1d(q, k, v, pivot_and_window_idx, **kwargs): +# '''the inference process of sparse attention. +# The Qs are in the same block, but seq_len mod window size might != 0. - s, w, times = 4096 + 128, 128, 2 - num_pivot = 768 - b = 2 - g = s // w +# The Qs are the final tokens of Ks. the pivot_and_window_idx[-query_len] are Qs. 
- q, k, v = raw = torch.rand(3, b, 16, s, 64, dtype=torch.float, device='cuda', requires_grad=True) - q1, k1, v1 = raw1 = torch.tensor(raw.cpu().detach().numpy(), dtype=torch.float, device='cuda', requires_grad=True) - txt_indices = [torch.arange(0, 128, dtype=torch.long, device='cuda'), torch.arange(0, 22, dtype=torch.long, device='cuda')] - img_indices = [torch.arange(128, s, dtype=torch.long, device='cuda'), torch.arange(22, s, dtype=torch.long, device='cuda')] +# ''' +# b, n_head, sq, hn = q.shape +# sk = k.shape[2] +# _b, n_piv = pivot_and_window_idx.shape - pivot_idx = torch.stack([ - torch.cat(( - text_idx, - img_indices[i][ - torch.tensor(random.sample(range(len(img_indices[i]) - times*w), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device) - ] - ), dim=0) - for i, text_idx in enumerate(txt_indices) - ]) # -times * w to verify inference +# pivot_and_window_idx_dummy = pivot_and_window_idx.view(b, 1, n_piv, 1).expand(b, n_head, n_piv, hn) +# pivot_k, pivot_v = torch.gather(k, 2, pivot_and_window_idx_dummy), torch.gather(v, 2, pivot_and_window_idx_dummy) +# attention_scores = torch.matmul(q / math.sqrt(hn), pivot_k.transpose(-1, -2)) +# if sq > 1: +# query_part_scores = attention_scores[:, :, -sq:, -sq:] +# m = torch.ones((sq, sq), device=q.device, dtype=q.dtype) * -10000. +# m.triu_(diagonal=1) +# query_part_scores += m - tmp = torch.ones((g-times+1, w , w), device='cuda', dtype=torch.long) - tmp = torch.tril(1 - torch.block_diag(*tmp)) - rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom) +# attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) - pivot_attention_mask = rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, num_pivot)) +# context_layer = torch.matmul(attention_probs, pivot_v) +# return context_layer - real_mask = torch.ones((b, s, s), device='cuda', dtype=torch.long) - rmask - for i in range(b): - real_mask[i][:, pivot_idx[i]] = 1 - real_mask[i].tril_() +def transpose_and_split(x, layout, n_head): + x = x.transpose(1, 2) + x = x.reshape(x.shape[0] * n_head, x.shape[1] // n_head, x.shape[2]) + x_text = x[..., :layout[0]] + x0 = x[...,layout[1]:layout[2]].view(x.shape[0], x.shape[1], sqrt(layout[2] - layout[1]), -1).contiguous() + x1 = x[...,layout[2]:layout[3]].view(x.shape[0], x.shape[1], sqrt(layout[3] - layout[2]), -1).contiguous() + return x, x_text, x0, x1 - # test inference - - # q_part = q[..., -1:, :] - # r0 = standard_attention(q, k, v, real_mask) - # r0 = r0[..., -1:, :] - # pw_idx = torch.cat((pivot_idx, torch.arange(s-times*w, s, device='cuda', dtype=torch.long).expand(b, -1)), dim=-1) - - # r1 = sparse_attention_inference(q_part, k, v, pw_idx) - # print(( (r1-r0).abs() / (r1.abs()+r0.abs())).max()) - - import time +def sparse_attention_2d(q, k, v, n_head, layout, attention_mask_text2d, kernel_size=9, kernel_size2=7, attention_dropout=None, **kwargs): + ''' + q, k, v: [batch_size, 64+1024+4096, hidden_size] + n_head: int + layout: [endoftext/startofpad, startof0, startof1, endofall] + attention_mask_text2d: [batch_size, sq_len, endoftext] + ''' + from .local_attention_function import f_similar, f_weighting + b, sq_len, hn = q.shape + alpha = sqrt((layout[3] - layout[2]) // (layout[2] - layout[1])) - r0 = standard_attention(q1, k1, v1, real_mask) - torch.cuda.synchronize() - t0 = time.time() - r1 = standard_attention(q1, k1, v1, real_mask) - torch.cuda.synchronize() - t1 = time.time() - r2 = sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, w, 
times) - torch.cuda.synchronize() - t2 = time.time() - print('times: standard ', t1-t0, ' sparse ', t2-t1) + q = q / math.sqrt(hn // n_head) # normalization - print(( (r1-r2).abs() / (r1.abs()+r2.abs())).max()) + q_all, q_text, q0, q1 = transpose_and_split(q, layout, n_head) # 0, 1 [batch * n_head, hn_per_head, h, w] text [batch * n_head, hn_per_head, endoftext] + k_all, k_text, k0, k1 = transpose_and_split(k, layout, n_head) + v_all, v_text, v0, v1 = transpose_and_split(v, layout, n_head) - raw.retain_grad() - l2 = r2.mean() - l1 = r1.mean() - l2.backward() - l1.backward() + # import pdb; pdb.set_trace() + # all to text + scores_all_to_text = torch.einsum('bhi,bhj->bij', q_all, k_text).view(b, n_head, layout[3], layout[0]) * attention_mask_text2d - 10000.0 * (1.0 - attention_mask_text2d) + scores_all_to_text = scores_all_to_text.view(b*n_head, layout[3], layout[0]) + # 0 to 0 + scores_0_to_0 = f_similar(q0, k0, kernel_size, kernel_size, True) + # 1 to 1 + scores_1_to_1 = f_similar(q1, k1, kernel_size, kernel_size, True) + # 1 to 0 + scores_1_to_0 = f_similar(q1, k0, kernel_size2, kernel_size2, False) # [batch * n_head, 2h, 2w, kernel_size2**2] + # softmax + probs_text = F.softmax(scores_all_to_text[:, :layout[0]], dim=-1) # [batch * n_head, seq_text, seq_text] + + scores_0 = torch.cat( + (scores_all_to_text[:, layout[1]:layout[2]], + scores_0_to_0.view(b * n_head, layout[2]-layout[1], scores_0_to_0.shape[-1])), + dim=-1) + probs_0 = F.softmax(scores_0, dim=-1) # + scores_1 = torch.cat( + (scores_all_to_text[:, layout[2]:layout[3]], + scores_1_to_0.view(scores_1_to_0.shape[0], -1, scores_1_to_0.shape[3]), + scores_1_to_1.view(scores_1_to_1.shape[0], -1, scores_1_to_1.shape[3])), + dim=-1) + probs_1 = F.softmax(scores_1, dim=-1) - g1 = raw1.grad - g2 = raw.grad - print( (g1-g2).abs().max()) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + probs_0 = attention_dropout(probs_0) + probs_1 = attention_dropout(probs_1) + # weighting + pad = torch.zeros(layout[1], device=q.device, dtype=q.dtype) + probs_all_to_text = torch.cat(( + probs_text, + pad[-layout[0]:].expand(b*n_head, layout[1]-layout[0], layout[0]), + probs_0[:, :, :layout[0]], + probs_1[:, :, :layout[0]] + ), dim=1) + + context_all_to_text = torch.einsum('bhij,bhcj->bihc', + probs_all_to_text.view(b, n_head, probs_all_to_text.shape[1], probs_all_to_text.shape[2]), + v_text.view(b, n_head, v_text.shape[1], v_text.shape[2])).reshape(b, -1, hn) + + context_0_to_0 = f_weighting(v0, probs_0[..., layout[0]:].view_as(scores_0_to_0).contiguous(), kernel_size, kernel_size, True) + + context_1_to_0 = f_weighting(v0, probs_1[:, :, layout[0]:layout[0]+scores_1_to_0.shape[-1]].view_as(scores_1_to_0).contiguous(), kernel_size2, kernel_size2, False) + + context_1_to_1 = f_weighting(v1, probs_1[:, :, -scores_1_to_1.shape[-1]:].view_as(scores_1_to_1).contiguous(), kernel_size, kernel_size, True) + + context_all_to_01 =torch.cat( + ( + pad.expand(b*n_head, hn//n_head, layout[1]), + context_0_to_0.view(b*n_head, hn//n_head, layout[2]-layout[1]), + (context_1_to_0 + context_1_to_1).view(b*n_head, hn//n_head, layout[3]-layout[2]) + ), dim=-1).view(b, hn, -1).transpose(1, 2) + return context_all_to_text + context_all_to_01 - # import pdb; pdb.set_trace() diff --git a/mpu/utils.py b/mpu/utils.py index adee41c..d9b1a8d 100755 --- a/mpu/utils.py +++ b/mpu/utils.py @@ -15,6 +15,7 @@ import torch +import math def ensure_divisibility(numerator, denominator): @@ -78,3 +79,6 @@ def split_out_sums(x, BLOCK_SIZE=32, all_ret=False): return 
oris.reshape(b, -1, *rs), sums.reshape(b, -1, *rs) else: return sums.reshape(b, -1, *rs) + +def sqrt(x): + return int(math.sqrt(x) + 1e-4) \ No newline at end of file diff --git a/test_sparse_attention.py b/test_sparse_attention.py new file mode 100644 index 0000000..a8c22f8 --- /dev/null +++ b/test_sparse_attention.py @@ -0,0 +1,169 @@ +import math +import random +from tqdm import tqdm + +import torch +import numpy as np +from mpu.sparse_transformer import standard_attention, sparse_attention_1d, sparse_attention_2d + +def test_sparse_attention_1d(): + s, w, times = 4096 + 128, 128, 2 + num_pivot = 768 + b = 2 + g = s // w + + q, k, v = raw = torch.rand(3, b, 16, s, 64, dtype=torch.float, device='cuda', requires_grad=True) + q1, k1, v1 = raw1 = torch.tensor(raw.cpu().detach().numpy(), dtype=torch.float, device='cuda', requires_grad=True) + txt_indices = [torch.arange(0, 128, dtype=torch.long, device='cuda'), torch.arange(0, 22, dtype=torch.long, device='cuda')] + img_indices = [torch.arange(128, s, dtype=torch.long, device='cuda'), torch.arange(22, s, dtype=torch.long, device='cuda')] + + pivot_idx = torch.stack([ + torch.cat(( + text_idx, + img_indices[i][ + torch.tensor(random.sample(range(len(img_indices[i]) - times*w), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device) + ] + ), dim=0) + for i, text_idx in enumerate(txt_indices) + ]) # -times * w to verify inference + + tmp = torch.ones((g-times+1, w , w), device='cuda', dtype=torch.long) + tmp = torch.tril(1 - torch.block_diag(*tmp)) + rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom) + + pivot_attention_mask = rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, num_pivot)) + + real_mask = torch.ones((b, s, s), device='cuda', dtype=torch.long) - rmask + for i in range(b): + real_mask[i][:, pivot_idx[i]] = 1 + real_mask[i].tril_() + + # test inference + + # q_part = q[..., -1:, :] + # r0 = standard_attention(q, k, v, real_mask) + # r0 = r0[..., -1:, :] + # pw_idx = torch.cat((pivot_idx, torch.arange(s-times*w, s, device='cuda', dtype=torch.long).expand(b, -1)), dim=-1) + + # r1 = sparse_attention_inference(q_part, k, v, pw_idx) + # print(( (r1-r0).abs() / (r1.abs()+r0.abs())).max()) + + import time + + r0 = standard_attention(q1, k1, v1, real_mask) + torch.cuda.synchronize() + t0 = time.time() + r1 = standard_attention(q1, k1, v1, real_mask) + torch.cuda.synchronize() + t1 = time.time() + r2 = sparse_attention(q, k, v, pivot_idx, pivot_attention_mask, w, times) + torch.cuda.synchronize() + t2 = time.time() + print('times: standard ', t1-t0, ' sparse ', t2-t1) + + print(( (r1-r2).abs() / (r1.abs()+r2.abs())).max()) + + raw.retain_grad() + l2 = r2.mean() + l1 = r1.mean() + l2.backward() + l1.backward() + + g1 = raw1.grad + g2 = raw.grad + print( (g1-g2).abs().max()) + + # import pdb; pdb.set_trace() + +def test_sparse_attention_2d(): + dtype = torch.float + device = 'cuda' + b, n_head, hn = 1, 40, 2560 + h = w = 32 + layout = [10, 10, 10+h*w, 10+h*w*5] + k1 = 9 + k2 = 7 + + qkv = torch.rand(3, b, layout[-1], hn, dtype=dtype, device=device) + qkv2 = qkv.clone() + qkv.requires_grad_() + qkv2.requires_grad_() + mask = torch.zeros(b, layout[-1], layout[-1], dtype=dtype, device=device) + + m = mask[0] + for i in range(layout[1]): + m[i, :i+1] = 1 + m[layout[1]:, :layout[0]] = 1 + for i in tqdm(range(layout[1], layout[2])): + x = (i - layout[1]) // w + y = (i - layout[1]) % w + lx = max(0, x - k1 // 2) + ly = max(0, y - k1 // 2) + rx = 
min(h-1, x + k1 // 2) + ry = min(w-1, y + k1 // 2) + m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = 1 + m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = 1 + for i in tqdm(range(layout[2], layout[3])): + x = (i - layout[2]) // (2*w) + y = (i - layout[2]) % (2*w) + lx = max(0, x - k1 // 2) + ly = max(0, y - k1 // 2) + rx = min(2*h-1, x + k1 // 2) + ry = min(2*w-1, y + k1 // 2) + m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = 1 + m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = 1 + + x = x // 2 + y = y // 2 + lx = max(0, x - k2 // 2) + ly = max(0, y - k2 // 2) + rx = min(h-1, x + k2 // 2) + ry = min(w-1, y + k2 // 2) + m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = 1 + + # mask[1:] = mask[0] + # mask[1][layout[1]:, layout[0]-1] = 0 + + print('finish making mask...') + + import time + torch.cuda.synchronize() + t0 = time.time() + qkv_tmp = qkv.view(3, b, layout[-1], n_head, hn//n_head).permute(0, 1, 3, 2, 4).contiguous() + r1 = standard_attention(*qkv_tmp, mask.unsqueeze(1)).transpose(1, 2).reshape(b, layout[3], hn) + + torch.cuda.synchronize() + t1 = time.time() + r2 = sparse_attention_2d(*qkv2, n_head, layout, mask[...,:layout[0]].unsqueeze(1), kernel_size=k1, kernel_size2=k2) + torch.cuda.synchronize() + t2 = time.time() + print('times: standard ', t1-t0, ' sparse ', t2-t1) + print(( (r1[:,:layout[0]]-r2[:,:layout[0]]).abs() / (r1[:,:layout[0]].abs()+r2[:,:layout[0]].abs())).max()) + print(( (r1[:,layout[1]:]-r2[:,layout[1]:]).abs() / (r1[:,layout[1]:].abs()+r2[:,layout[1]:].abs())).max()) + qkv.retain_grad() + l2 = r2[:,layout[1]:].mean() + l1 = r1[:,layout[1]:].mean() + l2.backward() + l1.backward() + + g1 = qkv.grad + g2 = qkv2.grad + print( (g1-g2).abs().max()) + # import pdb;pdb.set_trace() + + +def seed_torch(seed=1029): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enabled = False + +if __name__ == '__main__': + seed_torch() + torch.backends.cuda.matmul.allow_tf32 = False + test_sparse_attention_2d() + \ No newline at end of file -- GitLab
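
Note on the causal local-attention window in mpu/local_attention_function.py: with casual_mask=True the pure-PyTorch reference (TorchLocalAttention.f_similar) keeps only the first kh*kw//2 + 1 of the kh*kw unfolded neighbours, i.e. the window positions at or before the query pixel in raster order. A minimal sketch of which relative offsets survive (the helper name is illustrative, not part of the patch):

```python
def causal_window_offsets(kh, kw):
    """Relative (dy, dx) offsets kept when the local window is causally masked.

    F.unfold enumerates the kh*kw neighbours row by row; truncating to the
    first kh*kw//2 + 1 entries keeps the rows above the centre plus the part
    of the centre row up to and including the query pixel itself.
    """
    offsets = [(dy, dx) for dy in range(-(kh // 2), kh // 2 + 1)
                        for dx in range(-(kw // 2), kw // 2 + 1)]
    return offsets[:kh * kw // 2 + 1]

print(causal_window_offsets(3, 3))
# [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 0)]
```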
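
Note on the LayerNorm change in mpu/sparse_transformer.py: when pb_relax is set, the input is divided by its detached absolute maximum over 8 before the fused LayerNorm runs; because LayerNorm is scale-invariant, the result is unchanged in exact arithmetic, and the division only keeps the intermediates in fp16 range. A minimal sketch of the same idea on plain torch.nn.LayerNorm (the patch subclasses apex's FusedLayerNorm; the class name below is illustrative):

```python
import torch
from torch import nn

class PBRelaxLayerNorm(nn.LayerNorm):
    """LayerNorm that rescales its input by max|x|/8 before normalizing.

    The detach() stops gradients from flowing through the scale factor,
    matching the patch; the affine parameters are applied after the
    normalization, so the output matches a plain LayerNorm.
    """
    def __init__(self, *args, pb_relax=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.pb_relax = pb_relax

    def forward(self, x):
        if not self.pb_relax:
            return super().forward(x)
        return super().forward(x / (x.abs().max().detach() / 8))

if __name__ == '__main__':
    ln = PBRelaxLayerNorm(16, pb_relax=True)
    x = torch.randn(2, 4, 16) * 1e4   # large activations that would overflow in fp16
    print(ln(x).shape)                # torch.Size([2, 4, 16])
```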
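
Note on the "cache qkv" part: GPT2ParallelSelfAttention now returns mixed_raw_layer.detach() along with its output, and when a cached projection is passed back in as mem it is concatenated in front of the fresh projection before the Q/K/V split, so query_key_value only runs on the new tokens. A minimal single-head sketch of that caching pattern (illustrative module, causal masking omitted for brevity; the patch itself works on the multi-head, model-parallel layer and accumulates the per-layer caches through update_mems):

```python
import torch
from torch import nn
import torch.nn.functional as F

class CachedQKVAttention(nn.Module):
    """Single-head attention that caches the fused QKV projection."""
    def __init__(self, hidden_size):
        super().__init__()
        self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
        self.hidden_size = hidden_size

    def forward(self, hidden_states, mem=None):
        # Project only the new tokens; the projection of old tokens is cached.
        mixed_raw_layer = self.query_key_value(hidden_states)
        mixed_x_layer = mixed_raw_layer if mem is None \
            else torch.cat((mem, mixed_raw_layer), dim=1)
        q, k, v = mixed_x_layer.chunk(3, dim=-1)
        # Attend from the new positions only; keys/values cover cached + new.
        q = q[:, -hidden_states.size(1):]
        scores = q @ k.transpose(-1, -2) / self.hidden_size ** 0.5
        out = F.softmax(scores, dim=-1) @ v
        # Return the detached projection so the caller can store it as memory.
        return out, mixed_raw_layer.detach()

if __name__ == '__main__':
    layer = CachedQKVAttention(16)
    prompt = torch.randn(1, 5, 16)
    _, cache = layer(prompt)                 # process the prompt, keep its QKV
    new_tok = torch.randn(1, 1, 16)
    out, cache2 = layer(new_tok, mem=cache)  # reuse the cached QKV for the prompt
    print(out.shape, cache2.shape)           # [1, 1, 16], [1, 1, 48]
    # successive caches would be concatenated along dim=1 by the caller
```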
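
Note on the cuda_2d path: sparse_attention_2d reads the flattened sequence through layout = [end_of_text, start_of_level0, start_of_level1, end_of_all], i.e. a text prefix followed by two square image grids (the second with 4x the tokens of the first), and transpose_and_split recovers the grids with the integer sqrt helper added to mpu/utils.py. A small sketch of that split on plain tensors, using the sizes from test_sparse_attention_2d and omitting the per-head reshape:

```python
import math
import torch

def int_sqrt(x):
    # same as the sqrt helper added to mpu/utils.py
    return int(math.sqrt(x) + 1e-4)

# layout from test_sparse_attention_2d: 10 text tokens, then 32*32 and 64*64 image tokens
h = w = 32
layout = [10, 10, 10 + h * w, 10 + h * w * 5]

x = torch.randn(2, layout[-1], 64).transpose(1, 2).contiguous()  # [batch, hidden, seq]

x_text = x[..., :layout[0]]                                      # [2, 64, 10]
side0 = int_sqrt(layout[2] - layout[1])                          # 32
side1 = int_sqrt(layout[3] - layout[2])                          # 64
x0 = x[..., layout[1]:layout[2]].view(2, 64, side0, side0)       # [2, 64, 32, 32]
x1 = x[..., layout[2]:layout[3]].view(2, 64, side1, side1)       # [2, 64, 64, 64]
print(x_text.shape, x0.shape, x1.shape)
```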