From 4e0e9674b16fc0c460d8b3b3fabf3d309d111484 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 13 Dec 2021 14:58:28 +0800
Subject: [PATCH] Implement T5 Attention with attention_fn

---
 .../model/cached_autoregressive_model.py | 28 ++++----
 SwissArmyTransformer/model/t5_model.py   | 69 ++++---------------
 SwissArmyTransformer/mpu/transformer.py  | 37 +++++-----
 3 files changed, 48 insertions(+), 86 deletions(-)

diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index ed25fe4..8caed66 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -21,20 +21,22 @@ class CachedAutoregressiveMixin(BaseMixin):
         super().__init__()
 
     @non_conflict
-    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
-        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
-        b, nh, seq_len, hidden_size = k.shape
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+                     **kw_args):
+        if not cross_attention:
+            mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+            b, nh, seq_len, hidden_size = k.shape
 
-        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
-        kw_args['output_this_layer']['mem_kv'] = cache_kv
-
-        if mem is not None: # the first time, mem is None
-            # might change batch_size
-            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
-            memk, memv = mem[0], mem[1]
-            k = torch.cat((memk, k), dim=2)
-            v = torch.cat((memv, v), dim=2)
-        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
+            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+            kw_args['output_this_layer']['mem_kv'] = cache_kv
+
+            if mem is not None: # the first time, mem is None
+                # might change batch_size
+                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+                memk, memv = mem[0], mem[1]
+                k = torch.cat((memk, k), dim=2)
+                v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
 
 
 class CachedAutoregressiveModel(BaseModel):
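
The hunk above relies on the @non_conflict chaining convention: each mixin's attention_fn receives the previously registered implementation as old_impl and may transform q/k/v before delegating, and with this patch it also sees the new cross_attention flag so key/value caching only happens for self-attention. The snippet below is a minimal, self-contained sketch of that chaining pattern only; toy_standard_attention and cache_hook are invented names, and the dict-based cache stands in for the mems/mem_kv plumbing, so none of it is SwissArmyTransformer API.

    import torch

    def toy_standard_attention(q, k, v, mask, dropout_fn, **kw_args):
        # Plain scaled dot-product attention; mask/dropout are ignored in this toy.
        scores = torch.matmul(q, k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
        probs = torch.softmax(scores, dim=-1)
        return torch.matmul(probs, v)

    def cache_hook(q, k, v, mask, dropout_fn, cache=None, cross_attention=False,
                   old_impl=toy_standard_attention, **kw_args):
        # Mirrors the pattern above: only self-attention keys/values are cached,
        # cross-attention is passed straight through to old_impl.
        if not cross_attention and cache is not None:
            k = torch.cat((cache['k'], k), dim=2)
            v = torch.cat((cache['v'], v), dim=2)
            cache['k'], cache['v'] = k, v
        return old_impl(q, k, v, mask, dropout_fn, **kw_args)

    q = k = v = torch.randn(1, 2, 4, 8)                  # [batch, heads, seq, head_dim]
    cache = {'k': torch.randn(1, 2, 3, 8), 'v': torch.randn(1, 2, 3, 8)}
    print(cache_hook(q, k, v, mask=None, dropout_fn=None, cache=cache).shape)  # torch.Size([1, 2, 4, 8])
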
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index fb7c7cb..e0d2dc6 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -114,63 +114,18 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
-        attn_module = self.transformer.layers[layer_id].attention
-        seq_length = hidden_states.size(1)
-        memory_length = mems[layer_id].size(1) if mems else 0
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        if position_bias is None:
-            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
-                                           log_attention_weights=position_bias, scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        kw_args['output_cross_layer']['position_bias'] = position_bias
-
-        return output
-
-    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
-                                **kw_args):
-        attn_module = self.transformer.layers[layer_id].cross_attention
-        mixed_query_layer = attn_module.query(hidden_states)
-        mixed_x_layer = attn_module.key_value(encoder_outputs)
-        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
-                                           scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [b, s, h]
-        output = attn_module.dense(context_layer)
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        return output
+    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
+                     cross_attention=False, **kw_args):
+        log_attention_weights = None
+        if not cross_attention:
+            if position_bias is None:
+                seq_length = q.size(2)
+                key_length = k.size(2)
+                position_bias = self.compute_bias(seq_length, key_length)
+                kw_args['output_cross_layer']['position_bias'] = position_bias
+            log_attention_weights = position_bias
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
+                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
 
 
 class T5DecoderFinalMixin(BaseMixin):
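
The replacement above hands the relative position bias to the wrapped attention as log_attention_weights and disables 1/sqrt(d) scaling, which is the T5 convention. As a rough numerical illustration only, assuming standard_attention adds log_attention_weights to the pre-softmax scores, the helper below (a stand-in, not the library function) shows the intended arithmetic:

    import torch

    def attention_with_bias(q, k, v, position_bias):
        scores = torch.matmul(q, k.transpose(-1, -2))    # no 1/sqrt(d) scaling, as in T5
        scores = scores + position_bias                  # bias added in log space, pre-softmax
        probs = torch.softmax(scores, dim=-1)
        return torch.matmul(probs, v)

    q = k = v = torch.randn(1, 2, 5, 8)                  # [batch, heads, seq, head_dim]
    bias = torch.randn(1, 2, 5, 5)                       # [1, heads, q_len, k_len], broadcast over batch
    print(attention_with_bias(q, k, v, bias).shape)      # torch.Size([1, 2, 5, 8])
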
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 3b4ae73..2890569 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -218,6 +218,10 @@ class CrossAttention(torch.nn.Module):
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_query_layer = self.query(hidden_states)
             mixed_x_layer = self.key_value(encoder_outputs)
             (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
@@ -228,7 +232,8 @@ class CrossAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
+                                         cross_attention=True, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             # [b, s, hp]
@@ -504,7 +509,7 @@ class BaseTransformer(torch.nn.Module):
                 )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-        
+
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -526,7 +531,7 @@ class BaseTransformer(torch.nn.Module):
             output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
         else:
             output_cross_layer = {}
-        
+
         output_per_layers = []
         if self.checkpoint_activations:
             # define custom_forward for checkpointing
@@ -534,7 +539,7 @@ class BaseTransformer(torch.nn.Module):
             def custom_forward(*inputs):
                 layers_ = self.layers[start:end]
                 x_, mask = inputs[0], inputs[1]
-                
+
                 # recover kw_args and output_cross_layer
                 flat_inputs = inputs[2:]
                 kw_args, output_cross_layer = {}, {}
@@ -543,19 +548,19 @@ class BaseTransformer(torch.nn.Module):
                 for k, idx in cross_layer_index.items():
                     output_cross_layer[k] = flat_inputs[idx]
                 # -----------------
-                
+
                 output_per_layers_part = []
                 for i, layer in enumerate(layers_):
                     if 'layer_forward' in self.hooks:
                         layer_ret = self.hooks['layer_forward'](
-                            x_, mask, layer_id=layer.layer_id, 
-                            **kw_args, **output_cross_layer, 
+                            x_, mask, layer_id=layer.layer_id,
+                            **kw_args, **output_cross_layer,
                             output_this_layer={}, output_cross_layer={}
                         )
                     else:
                         layer_ret = layer(
-                            x_, mask, layer_id=layer.layer_id, 
-                            **kw_args, **output_cross_layer, 
+                            x_, mask, layer_id=layer.layer_id,
+                            **kw_args, **output_cross_layer,
                             output_this_layer={}, output_cross_layer={}
                         )
                     if torch.is_tensor(layer_ret): # only hidden_states
@@ -563,7 +568,7 @@
                     elif len(layer_ret) == 2: # hidden_states & output_this_layer
                         x_, output_this_layer = layer_ret
                         output_cross_layer = {}
-                    elif len(layer_ret) == 3: 
+                    elif len(layer_ret) == 3:
                         x_, output_this_layer, output_cross_layer = layer_ret
                     assert isinstance(output_this_layer, dict)
                     assert isinstance(output_cross_layer, dict)
@@ -590,7 +595,7 @@
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
-            
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             output_this_layer = []
@@ -625,20 +630,20 @@
                 args = [hidden_states, attention_mask]
                 if 'layer_forward' in self.hooks: # customized layer_forward
-                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), 
-                                                            **kw_args, 
-                                                            **output_cross_layer, 
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                                                            **kw_args,
+                                                            **output_cross_layer,
                                                             output_this_layer={}, output_cross_layer={}
                                                             )
                 else:
-                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, 
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
                                       output_this_layer={}, output_cross_layer={})
                 if torch.is_tensor(layer_ret): # only hidden_states
                     hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
                 elif len(layer_ret) == 2: # hidden_states & output_this_layer
                     hidden_states, output_this_layer = layer_ret
                     output_cross_layer = {}
-                elif len(layer_ret) == 3: 
+                elif len(layer_ret) == 3:
                     hidden_states, output_this_layer, output_cross_layer = layer_ret
 
             if output_hidden_states:
-- 
GitLab
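
For reference, the CrossAttention change in mpu/transformer.py boils down to the dispatch sketched below: pick the attention_fn hook when one is registered, otherwise fall back to standard attention, and flag the call with cross_attention=True so hooks such as the T5 one can skip the relative bias. Everything in this sketch (the *_stub functions and the hooks dict) is an illustrative stand-in, not the library's objects.

    import torch

    def standard_attention_stub(q, k, v, mask, dropout_fn, **kw_args):
        probs = torch.softmax(torch.matmul(q, k.transpose(-1, -2)), dim=-1)
        return torch.matmul(probs, v)

    def t5_attention_fn_stub(q, k, v, mask, dropout_fn, cross_attention=False,
                             old_impl=standard_attention_stub, **kw_args):
        # As in T5AttentionMixin.attention_fn above: no relative bias for
        # cross-attention, just delegate to the wrapped implementation.
        return old_impl(q, k, v, mask, dropout_fn, **kw_args)

    hooks = {'attention_fn': t5_attention_fn_stub}

    attention_fn = standard_attention_stub               # default, as in the patched CrossAttention
    if 'attention_fn' in hooks:
        attention_fn = hooks['attention_fn']

    q = torch.randn(1, 2, 3, 8)                          # decoder queries
    k = v = torch.randn(1, 2, 6, 8)                      # encoder keys/values
    print(attention_fn(q, k, v, mask=None, dropout_fn=None, cross_attention=True).shape)  # torch.Size([1, 2, 3, 8])
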