From 4b45416a2e7857bcd00f5c4ab4e9895d0d04e04a Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Thu, 9 Dec 2021 16:00:49 +0000 Subject: [PATCH] finish package part, need modify t5 --- .../generation/autoregressive_sampling.py | 3 +- SwissArmyTransformer/model/base_model.py | 4 +- .../model/cached_autoregressive_model.py | 44 ++--- SwissArmyTransformer/model/cuda2d_model.py | 2 +- SwissArmyTransformer/mpu/transformer.py | 154 ++++++++++++------ 5 files changed, 122 insertions(+), 85 deletions(-) diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py index 4eada99..eb5acf4 100644 --- a/SwissArmyTransformer/generation/autoregressive_sampling.py +++ b/SwissArmyTransformer/generation/autoregressive_sampling.py @@ -100,7 +100,7 @@ def filling_sequence( else: log_attention_weights_part = None - logits, *mem_kv = model( + logits, *output_per_layers = model( tokens[:, index:], position_ids[..., index: counter+1], attention_mask[..., index: counter+1, :counter+1], # TODO memlen @@ -108,6 +108,7 @@ def filling_sequence( log_attention_weights=log_attention_weights_part, **kw_args ) + mem_kv = [o['mem_kv'] for o in output_per_layers] mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length) counter += 1 index = counter diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py index 1b0569c..61543f5 100644 --- a/SwissArmyTransformer/model/base_model.py +++ b/SwissArmyTransformer/model/base_model.py @@ -39,9 +39,9 @@ class BaseMixin(torch.nn.Module): # Eg., # # @non_conflict - # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs): + # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args): # new_q, new_k, new_v = pre_hack(q, k, v) - # attn_result = old_impl(q, k, v, mask, dropout_fn, **kwargs) + # attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args) # attn_result = post_hack(attn_result) # return attn_result diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py index 296a811..ed25fe4 100755 --- a/SwissArmyTransformer/model/cached_autoregressive_model.py +++ b/SwissArmyTransformer/model/cached_autoregressive_model.py @@ -18,38 +18,24 @@ from SwissArmyTransformer.mpu.transformer import standard_attention, split_tenso class CachedAutoregressiveMixin(BaseMixin): def __init__(self): - super().__init__() - - def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs): - attn_module = self.transformer.layers[layer_id].attention - mem = mems[layer_id] if mems is not None else None - - mixed_raw_layer = attn_module.query_key_value(hidden_states) - (mixed_query_layer, - mixed_key_layer, - mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) + super().__init__() + + @non_conflict + def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args): + mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size + b, nh, seq_len, hidden_size = k.shape + + cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2) + kw_args['output_this_layer']['mem_kv'] = cache_kv if mem is not None: # the first time, mem is None - b = mixed_key_layer.shape[0] # might change batch_size - memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2) - 
mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1) - mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1) + # might change batch_size + mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4) + memk, memv = mem[0], mem[1] + k = torch.cat((memk, k), dim=2) + v = torch.cat((memv, v), dim=2) + return old_impl(q, k, v, mask, dropout_fn, **kw_args) - # same as training - query_layer = attn_module._transpose_for_scores(mixed_query_layer) - key_layer = attn_module._transpose_for_scores(mixed_key_layer) - value_layer = attn_module._transpose_for_scores(mixed_value_layer) - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - output = attn_module.dense(context_layer) - - # new mem this layer - new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous() - - return output, new_mem class CachedAutoregressiveModel(BaseModel): def __init__(self, args, transformer=None): diff --git a/SwissArmyTransformer/model/cuda2d_model.py b/SwissArmyTransformer/model/cuda2d_model.py index cda027c..4fdc87d 100644 --- a/SwissArmyTransformer/model/cuda2d_model.py +++ b/SwissArmyTransformer/model/cuda2d_model.py @@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel): output_1 = dense_plus(context_layer1) output = torch.cat((output_0, output_1), dim=1) - return output, None + return output def disable_untrainable_params(self): self.transformer.requires_grad_(False) diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index 229e21a..c553cc2 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -122,7 +122,7 @@ class SelfAttention(torch.nn.Module): def forward(self, hidden_states, mask, *args, **kw_args): if 'attention_forward' in self.hooks: - return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id) + return self.hooks['attention_forward'](hidden_states, mask, **kw_args) else: attention_fn = standard_attention if 'attention_fn' in self.hooks: @@ -139,7 +139,7 @@ class SelfAttention(torch.nn.Module): key_layer = self._transpose_for_scores(mixed_key_layer) value_layer = self._transpose_for_scores(mixed_value_layer) - context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, **kw_args) + context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) @@ -149,7 +149,7 @@ class SelfAttention(torch.nn.Module): if self.training: output = self.output_dropout(output) - return output, None + return output class CrossAttention(torch.nn.Module): @@ -266,7 +266,7 @@ class MLP(torch.nn.Module): def forward(self, hidden_states, **kw_args): if 'mlp_forward' in self.hooks: - output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id) + output = self.hooks['mlp_forward'](hidden_states, **kw_args) else: intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = self.activation_func(intermediate_parallel) @@ -367,7 +367,7 @@ class 
BaseTransformerLayer(torch.nn.Module): # Layer norm at the begining of the transformer layer. layernorm_output1 = self.input_layernorm(hidden_states) # Self attention. - attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args) + attention_output = self.attention(layernorm_output1, mask, **kw_args) # Third LayerNorm if self.sandwich_ln: @@ -381,7 +381,7 @@ class BaseTransformerLayer(torch.nn.Module): if self.is_decoder: encoder_outputs = kw_args['encoder_outputs'] if encoder_outputs is not None: - cross_attention_mask=kw_args['cross_attention_mask'] + cross_attention_mask = kw_args['cross_attention_mask'] # Cross attention attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args) # Residual connection. @@ -399,7 +399,7 @@ class BaseTransformerLayer(torch.nn.Module): # Second residual connection. output = layernorm_input + mlp_output - return output, output_this_layer # temporally, output_this_layer is only from attention + return output, kw_args['output_this_layer'], kw_args['output_cross_layer'] class BaseTransformer(torch.nn.Module): @@ -475,7 +475,7 @@ class BaseTransformer(torch.nn.Module): # Final layer norm before output. self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, + def forward(self, input_ids, position_ids, attention_mask, *, output_hidden_states=False, **kw_args): # sanity check assert len(input_ids.shape) == 2 @@ -486,11 +486,7 @@ class BaseTransformer(torch.nn.Module): ) # None means full attention assert len(attention_mask.shape) == 2 or \ len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 - assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor) - # branch_input is a new part of input need layer-by-layer update, - # but with different hidden_dim and computational routine. - # In most cases, you can just ignore it. 
- + # embedding part if 'word_embedding_forward' in self.hooks: hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args) @@ -507,64 +503,126 @@ class BaseTransformer(torch.nn.Module): hidden_states = hidden_states + position_embeddings hidden_states = self.embedding_dropout(hidden_states) - hidden_states_outputs = [hidden_states] if output_hidden_states else [] - # branch related embedding - if branch_input is None and 'branch_embedding_forward' in self.hooks: - branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args) - - # define custom_forward for checkpointing + # initial output_cross_layer + if 'cross_layer_embedding_forward' in self.hooks: + output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args) + else: + output_cross_layer = {} + output_per_layers = [] if self.checkpoint_activations: - def custom(start, end): + # define custom_forward for checkpointing + def custom(start, end, kw_args_index, cross_layer_index): def custom_forward(*inputs): layers_ = self.layers[start:end] x_, mask = inputs[0], inputs[1] - if len(inputs) > 2: # have branch_input - branch_ = inputs[2] + + # recover kw_args and output_cross_layer + flat_inputs = inputs[2:] + kw_args, output_cross_layer = {}, {} + for k, idx in kw_args_index.items(): + kw_args[k] = flat_inputs[idx] + for k, idx in cross_layer_index.items(): + output_cross_layer[k] = flat_inputs[idx] + # ----------------- + output_per_layers_part = [] for i, layer in enumerate(layers_): - if len(inputs) > 2: - x_, branch_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args - ) - elif 'layer_forward' in self.hooks: - x_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, **kw_args + if 'layer_forward' in self.hooks: + layer_ret = self.hooks['layer_forward']( + x_, mask, layer_id=layer.layer_id, + **kw_args, **output_cross_layer, + output_this_layer={}, output_cross_layer={} ) else: - x_, output_this_layer = layer(x_, mask, **kw_args) + layer_ret = layer( + x_, mask, layer_id=layer.layer_id, + **kw_args, **output_cross_layer, + output_this_layer={}, output_cross_layer={} + ) + if torch.is_tensor(layer_ret): # only hidden_states + x_, output_this_layer, output_cross_layer = layer_ret, {}, {} + elif len(layer_ret) == 2: # hidden_states & output_this_layer + x_, output_this_layer = layer_ret + output_cross_layer = {} + elif len(layer_ret) == 3: + x_, output_this_layer, output_cross_layer = layer_ret + assert isinstance(output_this_layer, dict) + assert isinstance(output_cross_layer, dict) + if output_hidden_states: + output_this_layer['hidden_states'] = x_ output_per_layers_part.append(output_this_layer) - return x_, output_per_layers_part + + # flatten for re-aggregate keywords outputs + flat_outputs = [] + for output_this_layer in output_per_layers_part: + for k in output_this_layer: + # TODO add warning for depth>=2 grad tensors + flat_outputs.append(output_this_layer[k]) + output_this_layer[k] = len(flat_outputs) - 1 + for k in output_cross_layer: + flat_outputs.append(output_cross_layer[k]) + output_cross_layer[k] = len(flat_outputs) - 1 + # -------------------- + + return x_, output_per_layers_part, output_cross_layer, flat_outputs return custom_forward # prevent to lose requires_grad in checkpointing. # To save memory when only finetuning the final layers, don't use checkpointing. 
if self.training: hidden_states.requires_grad_(True) - + l, num_layers = 0, len(self.layers) chunk_length = self.checkpoint_num_layers while l < num_layers: args = [hidden_states, attention_mask] - if branch_input is not None: - hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input) - else: - hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args) - if output_hidden_states: - hidden_states_outputs.append(hidden_states) + # flatten kw_args and output_cross_layer + flat_inputs, kw_args_index, cross_layer_index = [], {}, {} + for k, v in kw_args.items(): + flat_inputs.append(v) + kw_args_index[k] = len(flat_inputs) - 1 + for k, v in output_cross_layer.items(): + flat_inputs.append(v) + cross_layer_index[k] = len(flat_inputs) - 1 + # -------------------- + + hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \ + checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs) + + # recover output_per_layers_part, output_cross_layer + for output_this_layer in output_per_layers_part: + for k in output_this_layer: + output_this_layer[k] = flat_outputs[output_this_layer[k]] + for k in output_cross_layer: + output_cross_layer[k] = flat_outputs[output_cross_layer[k]] + # -------------------- + output_per_layers.extend(output_per_layers_part) l += chunk_length else: for i, layer in enumerate(self.layers): args = [hidden_states, attention_mask] - if branch_input is not None: # customized layer_forward with branch_input - hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args) - elif 'layer_forward' in self.hooks: # customized layer_forward - hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args) + + if 'layer_forward' in self.hooks: # customized layer_forward + layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), + **kw_args, + **output_cross_layer, + output_this_layer={}, output_cross_layer={} + ) else: - hidden_states, output_this_layer = layer(*args, **kw_args) + layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, + output_this_layer={}, output_cross_layer={}) + if torch.is_tensor(layer_ret): # only hidden_states + hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {} + elif len(layer_ret) == 2: # hidden_states & output_this_layer + hidden_states, output_this_layer = layer_ret + output_cross_layer = {} + elif len(layer_ret) == 3: + hidden_states, output_this_layer, output_cross_layer = layer_ret + if output_hidden_states: - hidden_states_outputs.append(hidden_states) + output_this_layer['hidden_states'] = hidden_states output_per_layers.append(output_this_layer) # Final layer norm. 
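Note on the hunk above: the flatten/recover bookkeeping in the checkpointing branch exists because the checkpoint helper only re-feeds positional tensor arguments to the wrapped function, so dict-valued kw_args and cross-layer outputs have to be packed into a flat list plus a name-to-index map and rebuilt on the other side. A minimal standalone sketch of the same idea, with torch.utils.checkpoint standing in for the repo's mpu checkpoint helper and with illustrative names (run_block, custom_forward) that are not part of this patch:

import torch
from torch.utils.checkpoint import checkpoint

def run_block(block, hidden, **kw_args):
    # pack dict-valued keyword args into a positional list + name->index map,
    # since checkpoint() only forwards positional arguments
    flat_inputs, index = [], {}
    for k, v in kw_args.items():
        flat_inputs.append(v)
        index[k] = len(flat_inputs) - 1

    def custom_forward(x, *flat):
        # rebuild the keyword dict inside the checkpointed function
        recovered = {k: flat[i] for k, i in index.items()}
        return block(x, **recovered)

    return checkpoint(custom_forward, hidden, *flat_inputs)

# usage: gradients still flow through hidden even though 'scale' arrived as a keyword
layer = torch.nn.Linear(8, 8)
hidden = torch.randn(2, 8, requires_grad=True)
out = run_block(lambda x, scale: layer(x) * scale, hidden, scale=torch.tensor(2.0))
out.sum().backward()

The output side of the patch applies the same trick in reverse: each per-layer dict is flattened into flat_outputs so its tensors pass through checkpoint positionally, and the integer placeholders are swapped back for the real tensors after the call returns.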
@@ -576,18 +634,10 @@ class BaseTransformer(torch.nn.Module): logits_parallel = copy_to_model_parallel_region(logits) logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight) - # branch related embedding - if branch_input is None and 'branch_final_forward' in self.hooks: - branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args) - if not self.parallel_output: logits_parallel = gather_from_model_parallel_region(logits_parallel) outputs = [logits_parallel] - if branch_input is not None: - outputs.append(branch_input) - if output_hidden_states: - outputs.append(hidden_states_outputs) outputs.extend(output_per_layers) return outputs -- GitLab
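For downstream code, the user-visible change is the return convention: a hook writes whatever it wants to surface into kw_args['output_this_layer'], and the transformer hands those dicts back, one per layer, after the logits. The sketch below is illustrative and not part of the patch; KVNormMixin and kv_norm are invented names, and the import paths are assumed from the modules touched above.

import torch
from SwissArmyTransformer.model.base_model import BaseMixin, non_conflict
from SwissArmyTransformer.mpu.transformer import standard_attention

class KVNormMixin(BaseMixin):
    # hypothetical mixin written against the new hook convention
    def __init__(self):
        super().__init__()

    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
        # anything stored here comes back as output_per_layers[layer_id]['kv_norm']
        kw_args['output_this_layer']['kv_norm'] = k.norm(dim=-1).detach()
        return old_impl(q, k, v, mask, dropout_fn, **kw_args)

# caller side, mirroring the change in autoregressive_sampling.py:
#   logits, *output_per_layers = model(tokens, position_ids, attention_mask, **kw_args)
#   kv_norms = [o['kv_norm'] for o in output_per_layers]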