Commit 4b45416a authored by Ming Ding

finish package part, need modify t5

parent 8db7ecd7
@@ -100,7 +100,7 @@ def filling_sequence(
         else:
             log_attention_weights_part = None
-        logits, *mem_kv = model(
+        logits, *output_per_layers = model(
             tokens[:, index:],
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
@@ -108,6 +108,7 @@ def filling_sequence(
             log_attention_weights=log_attention_weights_part,
             **kw_args
         )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
...
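Note on the hunk above: the sampling loop now receives one keyword-output dict per layer instead of raw mem tensors, and pulls the cache out by key. A minimal, self-contained illustration with dummy tensors of the structure filling_sequence consumes (shapes follow the cache_kv view written by CachedAutoregressiveMixin further down; the numbers are only illustrative):

import torch

# Hypothetical stand-in for the new return structure: one dict per layer,
# with the autoregressive cache stored under 'mem_kv' as (batch, seqlen, 2 * hidden).
num_layers, b, seqlen, hidden = 2, 1, 5, 8
output_per_layers = [{'mem_kv': torch.zeros(b, seqlen, 2 * hidden)} for _ in range(num_layers)]

# filling_sequence now extracts the caches by key before calling update_mems:
mem_kv = [o['mem_kv'] for o in output_per_layers]
assert len(mem_kv) == num_layers and mem_kv[0].shape == (b, seqlen, 2 * hidden)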
@@ -39,9 +39,9 @@ class BaseMixin(torch.nn.Module):
     # Eg.,
     #
     # @non_conflict
-    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs):
+    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
     #     new_q, new_k, new_v = pre_hack(q, k, v)
-    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kwargs)
+    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
     #     attn_result = post_hack(attn_result)
     #     return attn_result
...
@@ -18,38 +18,24 @@ from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
 class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self):
         super().__init__()
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
-        attn_module = self.transformer.layers[layer_id].attention
-        mem = mems[layer_id] if mems is not None else None
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-            mixed_key_layer,
-            mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
+        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+        b, nh, seq_len, hidden_size = k.shape
+        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+        kw_args['output_this_layer']['mem_kv'] = cache_kv
         if mem is not None: # the first time, mem is None
-            b = mixed_key_layer.shape[0] # might change batch_size
-            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
-            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
-            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
-        # same as training
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-        # new mem this layer
-        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
-        return output, new_mem
+            # might change batch_size
+            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+            memk, memv = mem[0], mem[1]
+            k = torch.cat((memk, k), dim=2)
+            v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
 class CachedAutoregressiveModel(BaseModel):
     def __init__(self, args, transformer=None):
...
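A quick self-contained check (dummy shapes only, not library code) that the pack/unpack permutes in the hunk above are inverses, i.e. the flattened (batch, seqlen, 2*nh*hidden) cache written to output_this_layer['mem_kv'] recovers the original k and v:

import torch

# Dummy k, v of shape (batch, num_heads, seqlen, hidden_per_head), as inside attention_fn.
b, nh, seq_len, hidden_size = 2, 4, 3, 8
k = torch.randn(b, nh, seq_len, hidden_size)
v = torch.randn(b, nh, seq_len, hidden_size)

# Pack: stack to (2, b, nh, seq, h), permute to (b, seq, 2, nh, h),
# then flatten the last three dims into 2 * nh * h.
cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).contiguous().view(b, seq_len, nh * hidden_size * 2)

# Unpack as in the `mem is not None` branch: reshape back and move the k/v axis to the front.
mem = cache_kv.reshape(b, seq_len, 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
memk, memv = mem[0], mem[1]
assert torch.equal(memk, k) and torch.equal(memv, v)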
@@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel):
         output_1 = dense_plus(context_layer1)
         output = torch.cat((output_0, output_1), dim=1)
-        return output, None
+        return output
     def disable_untrainable_params(self):
         self.transformer.requires_grad_(False)
...
@@ -122,7 +122,7 @@ class SelfAttention(torch.nn.Module):
     def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
         else:
             attention_fn = standard_attention
             if 'attention_fn' in self.hooks:
@@ -139,7 +139,7 @@
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
-            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, **kw_args)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -149,7 +149,7 @@
             if self.training:
                 output = self.output_dropout(output)
-            return output, None
+            return output
 class CrossAttention(torch.nn.Module):
@@ -266,7 +266,7 @@ class MLP(torch.nn.Module):
     def forward(self, hidden_states, **kw_args):
         if 'mlp_forward' in self.hooks:
-            output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
+            output = self.hooks['mlp_forward'](hidden_states, **kw_args)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
             intermediate_parallel = self.activation_func(intermediate_parallel)
@@ -367,7 +367,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm at the begining of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
+        attention_output = self.attention(layernorm_output1, mask, **kw_args)
         # Third LayerNorm
         if self.sandwich_ln:
@@ -381,7 +381,7 @@
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+                cross_attention_mask = kw_args['cross_attention_mask']
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
                 # Residual connection.
@@ -399,7 +399,7 @@
         # Second residual connection.
         output = layernorm_input + mlp_output
-        return output, output_this_layer # temporally, output_this_layer is only from attention
+        return output, kw_args['output_this_layer'], kw_args['output_cross_layer']
 class BaseTransformer(torch.nn.Module):
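With these hunks, modules and hooks no longer return extras positionally: SelfAttention and MLP return only their output, attention-level extras such as mem_kv are written into kw_args['output_this_layer'], and BaseTransformerLayer returns the hidden states plus the two collected dicts. The sketch below is a toy layer and helper (hypothetical names, not the library's classes) showing how the three return forms are normalized, mirroring the dispatch added to BaseTransformer.forward further down:

import torch

def toy_layer(hidden_states, mask, *, output_this_layer, output_cross_layer, **kw_args):
    # a hook writes a per-layer extra into the dict instead of returning it positionally
    output_this_layer['mem_kv'] = hidden_states.detach()
    return hidden_states, output_this_layer, output_cross_layer

def normalize(layer_ret):
    # same branching BaseTransformer.forward applies to each layer's return value
    if torch.is_tensor(layer_ret):          # only hidden_states
        return layer_ret, {}, {}
    elif len(layer_ret) == 2:               # hidden_states & output_this_layer
        return layer_ret[0], layer_ret[1], {}
    return layer_ret                        # hidden_states, output_this_layer, output_cross_layer

x = torch.zeros(1, 4, 8)
x, this_layer, cross_layer = normalize(toy_layer(x, None, output_this_layer={}, output_cross_layer={}))
assert 'mem_kv' in this_layer and cross_layer == {}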
@@ -475,7 +475,7 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None,
+    def forward(self, input_ids, position_ids, attention_mask, *,
                 output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
@@ -486,11 +486,7 @@
             ) # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-        assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
-        # branch_input is a new part of input need layer-by-layer update,
-        # but with different hidden_dim and computational routine.
-        # In most cases, you can just ignore it.
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -507,64 +503,126 @@ class BaseTransformer(torch.nn.Module):
             hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
-        hidden_states_outputs = [hidden_states] if output_hidden_states else []
-        # branch related embedding
-        if branch_input is None and 'branch_embedding_forward' in self.hooks:
-            branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
+        # initial output_cross_layer
+        if 'cross_layer_embedding_forward' in self.hooks:
+            output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
+        else:
+            output_cross_layer = {}
-        # define custom_forward for checkpointing
         output_per_layers = []
         if self.checkpoint_activations:
-            def custom(start, end):
+            # define custom_forward for checkpointing
+            def custom(start, end, kw_args_index, cross_layer_index):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
                     x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]
+                    # recover kw_args and output_cross_layer
+                    flat_inputs = inputs[2:]
+                    kw_args, output_cross_layer = {}, {}
+                    for k, idx in kw_args_index.items():
+                        kw_args[k] = flat_inputs[idx]
+                    for k, idx in cross_layer_index.items():
+                        output_cross_layer[k] = flat_inputs[idx]
+                    # -----------------
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
-                            x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                        if 'layer_forward' in self.hooks:
+                            layer_ret = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
+                                output_this_layer={}, output_cross_layer={}
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            layer_ret = layer(
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
+                                output_this_layer={}, output_cross_layer={}
+                            )
+                        if torch.is_tensor(layer_ret): # only hidden_states
+                            x_, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                        elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                            x_, output_this_layer = layer_ret
+                            output_cross_layer = {}
+                        elif len(layer_ret) == 3:
+                            x_, output_this_layer, output_cross_layer = layer_ret
+                        assert isinstance(output_this_layer, dict)
+                        assert isinstance(output_cross_layer, dict)
+                        if output_hidden_states:
+                            output_this_layer['hidden_states'] = x_
                         output_per_layers_part.append(output_this_layer)
-                    return x_, output_per_layers_part
+                    # flatten for re-aggregate keywords outputs
+                    flat_outputs = []
+                    for output_this_layer in output_per_layers_part:
+                        for k in output_this_layer:
+                            # TODO add warning for depth>=2 grad tensors
+                            flat_outputs.append(output_this_layer[k])
+                            output_this_layer[k] = len(flat_outputs) - 1
+                    for k in output_cross_layer:
+                        flat_outputs.append(output_cross_layer[k])
+                        output_cross_layer[k] = len(flat_outputs) - 1
+                    # --------------------
+                    return x_, output_per_layers_part, output_cross_layer, flat_outputs
                 return custom_forward
             # prevent to lose requires_grad in checkpointing.
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
                 args = [hidden_states, attention_mask]
-                if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
-                else:
-                    hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
-                if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                # flatten kw_args and output_cross_layer
+                flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
+                for k, v in kw_args.items():
+                    flat_inputs.append(v)
+                    kw_args_index[k] = len(flat_inputs) - 1
+                for k, v in output_cross_layer.items():
+                    flat_inputs.append(v)
+                    cross_layer_index[k] = len(flat_inputs) - 1
+                # --------------------
+                hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \
+                    checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
+                # recover output_per_layers_part, output_cross_layer
+                for output_this_layer in output_per_layers_part:
+                    for k in output_this_layer:
+                        output_this_layer[k] = flat_outputs[output_this_layer[k]]
+                for k in output_cross_layer:
+                    output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
+                # --------------------
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask]
-                if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
-                elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
+                if 'layer_forward' in self.hooks: # customized layer_forward
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                        **kw_args,
+                        **output_cross_layer,
+                        output_this_layer={}, output_cross_layer={}
+                    )
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_args)
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
+                        output_this_layer={}, output_cross_layer={})
+                if torch.is_tensor(layer_ret): # only hidden_states
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                    hidden_states, output_this_layer = layer_ret
+                    output_cross_layer = {}
+                elif len(layer_ret) == 3:
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret
                 if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                    output_this_layer['hidden_states'] = hidden_states
                 output_per_layers.append(output_this_layer)
         # Final layer norm.
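The flatten/recover bookkeeping in the hunk above appears to exist because the re-entrant torch.utils.checkpoint only threads positional tensor arguments and top-level tensor outputs through the checkpointed function, so keyword tensors and dict-valued per-layer outputs are temporarily turned into flat lists plus index maps and rebuilt on the outside. A self-contained sketch of that round trip (toy helper names and data, not the library's code):

import torch

def flatten(*dicts):
    # move every dict value into one flat list, remembering each value's index per dict
    flat, indexed = [], []
    for d in dicts:
        idx = {}
        for k, v in d.items():
            flat.append(v)
            idx[k] = len(flat) - 1
        indexed.append(idx)
    return flat, indexed

def recover(indexed, flat):
    # rebuild the dicts from the index maps and the flat list
    return [{k: flat[i] for k, i in idx.items()} for idx in indexed]

kw_args = {'mems': torch.zeros(2, 3)}
output_cross_layer = {'encoder_outputs': torch.ones(2, 3)}
flat_inputs, (kw_args_index, cross_layer_index) = flatten(kw_args, output_cross_layer)
restored_kw_args, restored_cross = recover([kw_args_index, cross_layer_index], flat_inputs)
assert torch.equal(restored_kw_args['mems'], kw_args['mems'])
assert torch.equal(restored_cross['encoder_outputs'], output_cross_layer['encoder_outputs'])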
@@ -576,18 +634,10 @@ class BaseTransformer(torch.nn.Module):
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
-        # branch related embedding
-        if branch_input is None and 'branch_final_forward' in self.hooks:
-            branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
         if not self.parallel_output:
            logits_parallel = gather_from_model_parallel_region(logits_parallel)
         outputs = [logits_parallel]
-        if branch_input is not None:
-            outputs.append(branch_input)
-        if output_hidden_states:
-            outputs.append(hidden_states_outputs)
         outputs.extend(output_per_layers)
         return outputs
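With hidden_states_outputs and the branch outputs removed, per-layer hidden states (when output_hidden_states=True) now travel inside each layer's dict rather than as a separate list. A small illustration with dummy tensors of how a caller reads them from the flat outputs list:

import torch

num_layers = 3
output_per_layers = [{'hidden_states': torch.zeros(1, 4, 8)} for _ in range(num_layers)]
logits_parallel = torch.zeros(1, 4, 100)

outputs = [logits_parallel]
outputs.extend(output_per_layers)

logits, *per_layer = outputs
all_hidden = [o['hidden_states'] for o in per_layer]   # replaces the old hidden_states_outputs list
assert len(all_hidden) == num_layers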