Commit 4b45416a authored by Ming Ding

finish package part, need modify t5

parent 8db7ecd7
@@ -100,7 +100,7 @@ def filling_sequence(
         else:
             log_attention_weights_part = None
 
-        logits, *mem_kv = model(
+        logits, *output_per_layers = model(
             tokens[:, index:],
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
@@ -108,6 +108,7 @@ def filling_sequence(
             log_attention_weights=log_attention_weights_part,
             **kw_args
         )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
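The call-site change above is the consumer side of a new convention: the model now returns its logits followed by one dict per layer, and the k/v cache is pulled out of those dicts by key. A minimal sketch of that contract with a toy stand-in model (the real SwissArmyTransformer signatures are not reproduced here):

import torch

def toy_model(tokens, num_layers=2, nh=4, hd=8, vocab=100):
    # stand-in: returns logits plus one "output_this_layer" dict per layer,
    # each carrying a packed k/v cache under the 'mem_kv' key
    b, seqlen = tokens.shape
    logits = torch.randn(b, seqlen, vocab)
    output_per_layers = [{'mem_kv': torch.randn(b, seqlen, 2 * nh * hd)}
                         for _ in range(num_layers)]
    return (logits, *output_per_layers)

tokens = torch.zeros(1, 5, dtype=torch.long)
logits, *output_per_layers = toy_model(tokens)
mem_kv = [o['mem_kv'] for o in output_per_layers]   # same extraction as in the hunk above
print(len(mem_kv), mem_kv[0].shape)                 # 2 torch.Size([1, 5, 64])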
@@ -39,9 +39,9 @@ class BaseMixin(torch.nn.Module):
     # Eg.,
     #
     # @non_conflict
-    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs):
+    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
     #     new_q, new_k, new_v = pre_hack(q, k, v)
-    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kwargs)
+    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
     #     attn_result = post_hack(attn_result)
     #     return attn_result
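The renamed example above relies on the old_impl-chaining behaviour of @non_conflict. The sketch below imitates that contract with toy stand-ins for both the decorator and standard_attention; neither is the SwissArmyTransformer implementation.

import torch

def standard_attention(q, k, v, mask, dropout_fn, **kw_args):
    # toy attention stand-in, just enough to be callable
    scores = q @ k.transpose(-1, -2) + mask
    return scores.softmax(dim=-1) @ v

def non_conflict(fn):
    fn.non_conflict = True   # toy marker; the real framework uses its own mechanism
    return fn

@non_conflict
def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
    q, k, v = q * 1.0, k * 1.0, v * 1.0                        # pre_hack placeholder
    attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
    return attn_result                                         # post_hack placeholder

q = k = v = torch.randn(1, 2, 4, 8)
out = attention_fn(q, k, v, torch.zeros(1, 1, 4, 4), None)
print(out.shape)   # torch.Size([1, 2, 4, 8])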
@@ -18,38 +18,24 @@ from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
 class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self):
         super().__init__()
 
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
-        attn_module = self.transformer.layers[layer_id].attention
-        mem = mems[layer_id] if mems is not None else None
-
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-            mixed_key_layer,
-            mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
+        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+        b, nh, seq_len, hidden_size = k.shape
+
+        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+        kw_args['output_this_layer']['mem_kv'] = cache_kv
+
         if mem is not None: # the first time, mem is None
-            b = mixed_key_layer.shape[0] # might change batch_size
-            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
-            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
-            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+            # might change batch_size
+            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+            memk, memv = mem[0], mem[1]
+            k = torch.cat((memk, k), dim=2)
+            v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
 
-        # same as training
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-
-        # new mem this layer
-        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
-        return output, new_mem
 
 class CachedAutoregressiveModel(BaseModel):
     def __init__(self, args, transformer=None):
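The new attention_fn packs k and v of shape (batch, heads, seq_len, head_dim) into a single (batch, seq_len, 2*heads*head_dim) cache tensor, and the read path above undoes that packing before concatenating along the sequence dimension. A quick shape check of the round trip (sizes are arbitrary, and the expand over a grown batch is omitted):

import torch

b, nh, seq_len, hd = 2, 4, 5, 8
k = torch.randn(b, nh, seq_len, hd)
v = torch.randn(b, nh, seq_len, hd)

# write path: pack (k, v) into one cache tensor of shape (b, seq_len, 2*nh*hd)
cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).contiguous().view(b, seq_len, nh * hd * 2)

# read path: recover (2, b, nh, seq_len, hd) and split back into memk / memv
mem = cache_kv.reshape(b, seq_len, 2, nh, hd).permute(2, 0, 3, 1, 4)
memk, memv = mem[0], mem[1]

assert torch.equal(memk, k) and torch.equal(memv, v)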
@@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel):
         output_1 = dense_plus(context_layer1)
         output = torch.cat((output_0, output_1), dim=1)
 
-        return output, None
+        return output
 
     def disable_untrainable_params(self):
         self.transformer.requires_grad_(False)
@@ -122,7 +122,7 @@ class SelfAttention(torch.nn.Module):
 
     def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
         else:
             attention_fn = standard_attention
             if 'attention_fn' in self.hooks:
@@ -139,7 +139,7 @@ class SelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, **kw_args)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
 
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -149,7 +149,7 @@ class SelfAttention(torch.nn.Module):
             if self.training:
                 output = self.output_dropout(output)
 
-            return output, None
+            return output
 
 
 class CrossAttention(torch.nn.Module):
@@ -266,7 +266,7 @@ class MLP(torch.nn.Module):
 
     def forward(self, hidden_states, **kw_args):
         if 'mlp_forward' in self.hooks:
-            output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
+            output = self.hooks['mlp_forward'](hidden_states, **kw_args)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
             intermediate_parallel = self.activation_func(intermediate_parallel)
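The hook call sites above stop passing layer_id explicitly because it now travels inside kw_args, injected once per layer, and forward-style hooks return a bare tensor rather than (tensor, None). A small illustration of that calling convention, with made-up hook and variable names:

import torch

hooks = {}

def mlp_forward(hidden_states, layer_id=None, **kw_args):
    return hidden_states * 2                   # a bare tensor, no trailing None

hooks['mlp_forward'] = mlp_forward

kw_args = {'layer_id': torch.tensor(3)}        # injected once by the transformer layer
x = torch.ones(1, 4)
out = hooks['mlp_forward'](x, **kw_args)       # no explicit layer_id=... at the call site
print(out)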
@@ -367,7 +367,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
+        attention_output = self.attention(layernorm_output1, mask, **kw_args)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -381,7 +381,7 @@ class BaseTransformerLayer(torch.nn.Module):
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+                cross_attention_mask = kw_args['cross_attention_mask']
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
                 # Residual connection.
@@ -399,7 +399,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, output_this_layer # temporally, output_this_layer is only from attention
+        return output, kw_args['output_this_layer'], kw_args['output_cross_layer']
 
 
 class BaseTransformer(torch.nn.Module):
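The new three-value return works together with the output_this_layer / output_cross_layer dicts that the transformer passes in through kw_args: a hook writes into them, and the layer hands them back beside its output. A reduced stand-in (not BaseTransformerLayer itself) showing that flow:

import torch

def attention_with_cache_hook(hidden_states, **kw_args):
    # a hook surfaces a per-layer value by writing into the shared dict
    kw_args['output_this_layer']['mem_kv'] = hidden_states.detach()
    return hidden_states

def toy_layer_forward(hidden_states, mask, **kw_args):
    out = attention_with_cache_hook(hidden_states, **kw_args)
    # return the output plus the dicts that were threaded through kw_args
    return out, kw_args['output_this_layer'], kw_args['output_cross_layer']

x = torch.randn(1, 4)
out, this_layer, cross_layer = toy_layer_forward(
    x, None, output_this_layer={}, output_cross_layer={}
)
print(this_layer.keys())   # dict_keys(['mem_kv'])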
@@ -475,7 +475,7 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None,
+    def forward(self, input_ids, position_ids, attention_mask, *,
                 output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
@@ -486,11 +486,7 @@ class BaseTransformer(torch.nn.Module):
             ) # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-        assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
-        # branch_input is a new part of input need layer-by-layer update,
-        # but with different hidden_dim and computational routine.
-        # In most cases, you can just ignore it.
 
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -507,64 +503,126 @@ class BaseTransformer(torch.nn.Module):
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
-        hidden_states_outputs = [hidden_states] if output_hidden_states else []
-        # branch related embedding
-        if branch_input is None and 'branch_embedding_forward' in self.hooks:
-            branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
-        # define custom_forward for checkpointing
+        # initial output_cross_layer
+        if 'cross_layer_embedding_forward' in self.hooks:
+            output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
+        else:
+            output_cross_layer = {}
 
         output_per_layers = []
         if self.checkpoint_activations:
-            def custom(start, end):
+            # define custom_forward for checkpointing
+            def custom(start, end, kw_args_index, cross_layer_index):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
                     x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]
+
+                    # recover kw_args and output_cross_layer
+                    flat_inputs = inputs[2:]
+                    kw_args, output_cross_layer = {}, {}
+                    for k, idx in kw_args_index.items():
+                        kw_args[k] = flat_inputs[idx]
+                    for k, idx in cross_layer_index.items():
+                        output_cross_layer[k] = flat_inputs[idx]
+                    # -----------------
+
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
-                            x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                        if 'layer_forward' in self.hooks:
+                            layer_ret = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
+                                output_this_layer={}, output_cross_layer={}
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            layer_ret = layer(
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
+                                output_this_layer={}, output_cross_layer={}
+                            )
+                        if torch.is_tensor(layer_ret): # only hidden_states
+                            x_, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                        elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                            x_, output_this_layer = layer_ret
+                            output_cross_layer = {}
+                        elif len(layer_ret) == 3:
+                            x_, output_this_layer, output_cross_layer = layer_ret
+                        assert isinstance(output_this_layer, dict)
+                        assert isinstance(output_cross_layer, dict)
+                        if output_hidden_states:
+                            output_this_layer['hidden_states'] = x_
                         output_per_layers_part.append(output_this_layer)
-                    return x_, output_per_layers_part
+
+                    # flatten for re-aggregate keywords outputs
+                    flat_outputs = []
+                    for output_this_layer in output_per_layers_part:
+                        for k in output_this_layer:
+                            # TODO add warning for depth>=2 grad tensors
+                            flat_outputs.append(output_this_layer[k])
+                            output_this_layer[k] = len(flat_outputs) - 1
+                    for k in output_cross_layer:
+                        flat_outputs.append(output_cross_layer[k])
+                        output_cross_layer[k] = len(flat_outputs) - 1
+                    # --------------------
+
+                    return x_, output_per_layers_part, output_cross_layer, flat_outputs
                 return custom_forward
 
             # prevent to lose requires_grad in checkpointing.
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
 
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
                 args = [hidden_states, attention_mask]
-                if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
-                else:
-                    hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
-                if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+
+                # flatten kw_args and output_cross_layer
+                flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
+                for k, v in kw_args.items():
+                    flat_inputs.append(v)
+                    kw_args_index[k] = len(flat_inputs) - 1
+                for k, v in output_cross_layer.items():
+                    flat_inputs.append(v)
+                    cross_layer_index[k] = len(flat_inputs) - 1
+                # --------------------
+
+                hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \
+                    checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
+
+                # recover output_per_layers_part, output_cross_layer
+                for output_this_layer in output_per_layers_part:
+                    for k in output_this_layer:
+                        output_this_layer[k] = flat_outputs[output_this_layer[k]]
+                for k in output_cross_layer:
+                    output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
+                # --------------------
+
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask]
-                if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
-                elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
+
+                if 'layer_forward' in self.hooks: # customized layer_forward
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                                                            **kw_args,
+                                                            **output_cross_layer,
+                                                            output_this_layer={}, output_cross_layer={}
+                                                            )
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_args)
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
+                                      output_this_layer={}, output_cross_layer={})
+                if torch.is_tensor(layer_ret): # only hidden_states
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                    hidden_states, output_this_layer = layer_ret
+                    output_cross_layer = {}
+                elif len(layer_ret) == 3:
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret
+
                 if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                    output_this_layer['hidden_states'] = hidden_states
                 output_per_layers.append(output_this_layer)
 
         # Final layer norm.
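The flatten/recover blocks above exist because torch.utils.checkpoint replays the wrapped function from positional tensor arguments, so dict-shaped keyword inputs and per-layer dict outputs are converted to a flat tensor list plus index maps and rebuilt outside the checkpointed call. A reduced sketch of the same dance, with no transformer layers and invented names:

import torch
from torch.utils.checkpoint import checkpoint

def make_custom(kw_args_index):
    def custom_forward(x, *flat_inputs):
        # rebuild the keyword dict from positional tensors + index map
        kw_args = {k: flat_inputs[idx] for k, idx in kw_args_index.items()}
        y = x * kw_args['scale'] + kw_args['bias']
        output_this = {'stat': y.detach().mean()}
        # flatten dict-valued outputs: tensors leave positionally,
        # the dict keeps only their indices
        flat_outputs = []
        for k in output_this:
            flat_outputs.append(output_this[k])
            output_this[k] = len(flat_outputs) - 1
        return y, output_this, flat_outputs
    return custom_forward

x = torch.randn(3, requires_grad=True)
kw_args = {'scale': torch.tensor(2.0), 'bias': torch.tensor(1.0)}

flat_inputs, kw_args_index = [], {}
for k, v in kw_args.items():
    flat_inputs.append(v)
    kw_args_index[k] = len(flat_inputs) - 1

y, output_this, flat_outputs = checkpoint(make_custom(kw_args_index), x, *flat_inputs)
output_this = {k: flat_outputs[idx] for k, idx in output_this.items()}   # recover
print(output_this['stat'])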
@@ -576,18 +634,10 @@ class BaseTransformer(torch.nn.Module):
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
 
-        # branch related embedding
-        if branch_input is None and 'branch_final_forward' in self.hooks:
-            branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
-
         if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
 
         outputs = [logits_parallel]
-        if branch_input is not None:
-            outputs.append(branch_input)
-        if output_hidden_states:
-            outputs.append(hidden_states_outputs)
         outputs.extend(output_per_layers)
 
         return outputs
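The other half of the new protocol is the cross-layer channel: whatever one layer leaves in output_cross_layer is splatted into the next layer's keyword arguments (the commit message's pending T5 work, e.g. a shared relative position bias, is the intended use). A toy two-layer loop showing the hand-off; the layer functions and the position_bias name are illustrative only:

import torch

def layer_a(x, position_bias=None, **kw_args):
    # produce something for the *following* layers to consume
    kw_args['output_cross_layer']['position_bias'] = torch.arange(x.shape[-1], dtype=x.dtype)
    return x, kw_args['output_this_layer'], kw_args['output_cross_layer']

def layer_b(x, position_bias=None, **kw_args):
    if position_bias is not None:
        x = x + position_bias          # consume what the previous layer produced
    return x, kw_args['output_this_layer'], kw_args['output_cross_layer']

x, output_cross_layer = torch.zeros(1, 4), {}
for layer in (layer_a, layer_b):
    x, _, output_cross_layer = layer(
        x, **output_cross_layer,
        output_this_layer={}, output_cross_layer={}
    )
print(x)   # tensor([[0., 1., 2., 3.]])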