Unverified commit 7a5652df authored by Sleepy_chord, committed by GitHub

Merge pull request #6 from THUDM/finer-attn-hooks

Finer attn hooks
parents 8f2800cf 981fdabd
Showing 946 additions and 415 deletions
......@@ -33,6 +33,8 @@ def add_model_config_args(parser):
help='num of transformer attention heads')
group.add_argument('--hidden-size', type=int, default=1024,
help='transformer hidden size')
group.add_argument('--inner-hidden-size', type=int, default=None)
group.add_argument('--hidden-size-per-attention-head', type=int, default=None)
group.add_argument('--num-layers', type=int, default=24,
help='num decoder layers')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
......@@ -129,6 +131,8 @@ def add_training_args(parser):
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode')
group.add_argument('--bf16', action='store_true',
help='Run model in bf16 mode')
return parser
......@@ -146,7 +150,8 @@ def add_evaluation_args(parser):
'validation/test for')
group.add_argument('--eval-interval', type=int, default=1000,
help='interval between running evaluation on validation set')
group.add_argument('--strict-eval', action='store_true',
help='do not enlarge or randomly map eval-data; evaluate on the full eval-data.')
return parser
......@@ -297,25 +302,28 @@ def get_args(args_list=None):
print('using world size: {} and model-parallel size: {} '.format(
args.world_size, args.model_parallel_size))
if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
with open(args.deepspeed_config) as file:
deepspeed_config = json.load(file)
if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
args.fp16 = True
else:
args.fp16 = False
if hasattr(args, "deepspeed") and args.deepspeed:
if args.checkpoint_activations:
args.deepspeed_activation_checkpointing = True
if "train_micro_batch_size_per_gpu" in deepspeed_config:
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
if "gradient_accumulation_steps" in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
else:
args.gradient_accumulation_steps = None
if "optimizer" in deepspeed_config:
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
args.lr = optimizer_params_config.get("lr", args.lr)
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
args.deepspeed_activation_checkpointing = False
if args.deepspeed_config is not None:
with open(args.deepspeed_config) as file:
deepspeed_config = json.load(file)
if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
args.fp16 = True
else:
args.fp16 = False
if "train_micro_batch_size_per_gpu" in deepspeed_config:
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
if "gradient_accumulation_steps" in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
else:
args.gradient_accumulation_steps = None
if "optimizer" in deepspeed_config:
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
args.lr = optimizer_params_config.get("lr", args.lr)
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
return args
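A minimal sketch of the kind of DeepSpeed config these branches read, assuming only the keys consumed by get_args() above; the optimizer type, the values, and the file name are illustrative, not taken from the commit.

import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,        # copied into args.batch_size
    "gradient_accumulation_steps": 2,           # copied into args.gradient_accumulation_steps
    "fp16": {"enabled": True},                  # sets args.fp16 = True
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-4, "weight_decay": 0.01},  # override args.lr / args.weight_decay
    },
}

with open("ds_config.json", "w") as f:          # then launch with --deepspeed --deepspeed_config ds_config.json
    json.dump(ds_config, f, indent=2)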
......@@ -24,14 +24,17 @@ from .samplers import DistributedBatchSampler
from SwissArmyTransformer import mpu
def make_data_loader(dataset, batch_size, num_iters, args):
def make_data_loader(dataset, batch_size, args):
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
distributed = world_size > 1
sampler = torch.utils.data.SequentialSampler(dataset)
drop_last = distributed
# drop_last = distributed
drop_last = True # TODO always drop the last batch to keep consistency;
# otherwise, how should the last (smaller) eval batch be averaged?
# the GPUs in the same model parallel group receive the same data
if distributed: # TODO reformat this, but it is not urgent
gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
......@@ -52,7 +55,7 @@ def make_data_loader(dataset, batch_size, num_iters, args):
return data_loader
def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
def make_dataset_full(path, split, args, create_dataset_function, random_mapping=True, **kwargs):
"""function to create datasets+tokenizers for common options"""
print('make dataset ...', path)
if split is None:
......@@ -67,12 +70,8 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
ds = ConcatDataset(ds)
if should_split(split):
ds = split_ds(ds, split, block_size=args.block_size)
else:
elif random_mapping:
ds = RandomMappingDataset(ds)
# if should_split(split):
# ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
# FIXME this will merge valid set and train set.
return ds
def make_loaders(args, create_dataset_function):
......@@ -115,25 +114,25 @@ def make_loaders(args, create_dataset_function):
# make training and val dataset if necessary
if valid is None and args.valid_data is not None:
eval_set_args['path'] = args.valid_data
valid = make_dataset(**eval_set_args, args=args)
valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
if test is None and args.test_data is not None:
eval_set_args['path'] = args.test_data
test = make_dataset(**eval_set_args, args=args)
test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
# wrap datasets with data loader
if train is not None and args.batch_size > 0:
train = make_data_loader(train, batch_size, args.train_iters, args)
train = make_data_loader(train, batch_size, args)
args.do_train = True
else:
args.do_train = False
eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
if valid is not None:
valid = make_data_loader(valid, eval_batch_size, args.train_iters, args)
valid = make_data_loader(valid, eval_batch_size, args)
args.do_valid = True
else:
args.do_valid = False
if test is not None:
test = make_data_loader(test, eval_batch_size, len(test) // eval_batch_size + 1, args)
test = make_data_loader(test, eval_batch_size, args)
args.do_test = True
else:
args.do_test = False
......
......@@ -65,3 +65,18 @@ class BinaryDataset(Dataset):
def __getitem__(self, index):
return self.process_fn(self.bin[index])
class TSVDataset(Dataset):
def __init__(self, path, process_fn, with_heads=True, **kwargs):
self.process_fn = process_fn
with open(path, 'r') as fin:
if with_heads:
self.heads = fin.readline().split('\t')
else:
self.heads = None
self.items = [line.split('\t') for line in fin]
def __len__(self):
return len(self.items)
def __getitem__(self, index):
return self.process_fn(self.items[index])
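A minimal usage sketch for the new TSVDataset, assuming a tab-separated file with a header row; the file path, the process_fn, and the import path are hypothetical.

from SwissArmyTransformer.data_utils.datasets import TSVDataset  # import path assumed

def process_fn(fields):
    # fields is one raw line split on '\t'; strip the trailing newline of the last field
    return [f.strip() for f in fields]

ds = TSVDataset('data/train.tsv', process_fn, with_heads=True)
print(ds.heads)        # column names from the first line
print(len(ds), ds[0])  # rows are already passed through process_fn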
......@@ -56,7 +56,8 @@ def filling_sequence(
max_memory_length=100000,
log_attention_weights=None,
get_masks_and_position_ids=get_masks_and_position_ids_default,
mems=None
mems=None,
**kw_args
):
'''
seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
......@@ -99,13 +100,15 @@ def filling_sequence(
else:
log_attention_weights_part = None
logits, *mem_kv = model(
logits, *output_per_layers = model(
tokens[:, index:],
position_ids[..., index: counter+1],
attention_mask[..., index: counter+1, :counter+1], # TODO memlen
mems=mems,
log_attention_weights=log_attention_weights_part
log_attention_weights=log_attention_weights_part,
**kw_args
)
mem_kv = [o['mem_kv'] for o in output_per_layers]
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
index = counter
......
......@@ -2,4 +2,5 @@ from .base_model import BaseModel
from .cached_autoregressive_model import CachedAutoregressiveModel
from .cuda2d_model import Cuda2dModel
from .glm_model import GLMModel
from .encoder_decoder_model import EncoderDecoderModel
\ No newline at end of file
from .encoder_decoder_model import EncoderDecoderModel
from .t5_model import T5Model
......@@ -7,6 +7,7 @@
'''
# here put the import lib
from functools import partial
import os
import sys
import math
......@@ -14,19 +15,39 @@ import random
import torch
from SwissArmyTransformer.mpu import BaseTransformer
from SwissArmyTransformer.mpu.transformer import standard_attention
def non_conflict(func):
func.non_conflict = True
return func
class BaseMixin(torch.nn.Module):
def __init__(self):
super(BaseMixin, self).__init__()
# define new params
def reinit(self, *pre_mixins):
# reload the initial params from previous trained modules
pass
# can also define hook-functions here
# can define hook-functions here
# ...
# If the hook is just a pre- or post- transformation,
# you can use @non_conflict to mark it,
# and call `old_impl` inside it to stay compatible with other mixins.
# E.g. (a complete, runnable sketch follows this file's diff):
#
# @non_conflict
# def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
#     new_q, new_k, new_v = pre_hack(q, k, v)
#     attn_result = old_impl(new_q, new_k, new_v, mask, dropout_fn, **kw_args)
#     attn_result = post_hack(attn_result)
#     return attn_result
class BaseModel(torch.nn.Module):
def __init__(self, args, transformer=None, parallel_output=True):
def __init__(self, args, transformer=None, **kwargs):
super(BaseModel, self).__init__()
self.mixins = torch.nn.ModuleDict()
self.collect_hooks_()
......@@ -42,14 +63,16 @@ class BaseModel(torch.nn.Module):
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
inner_hidden_size=args.inner_hidden_size,
hidden_size_per_attention_head=args.hidden_size_per_attention_head,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
sandwich_ln=args.sandwich_ln,
parallel_output=parallel_output,
hooks=self.hooks
hooks=self.hooks,
**kwargs
)
def reinit(self): # will be called when loading model
def reinit(self): # will be called when loading model
# if some mixins are loaded, override this function
for m in self.mixins.values():
m.reinit(self.transformer)
......@@ -58,11 +81,11 @@ class BaseModel(torch.nn.Module):
assert name not in self.mixins
assert isinstance(new_mixin, BaseMixin)
self.mixins[name] = new_mixin # will auto-register parameters
object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
self.mixins[name] = new_mixin # will auto-register parameters
object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
if reinit:
new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
self.collect_hooks_()
def del_mixin(self, name):
......@@ -82,18 +105,25 @@ class BaseModel(torch.nn.Module):
def collect_hooks_(self):
names = ['word_embedding_forward', 'position_embedding_forward',
'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
'branch_embedding_forward', 'branch_final_forward'
]
'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
'cross_layer_embedding_forward',
'attention_fn'
]
hooks = {}
hook_origins = {}
for name in names:
for mixin_name, m in self.mixins.items():
if hasattr(m, name):
if name in hooks: # conflict
raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
hooks[name] = getattr(m, name)
hook_origins[name] = mixin_name
if name in hooks: # if this hook name is already registered
if hasattr(getattr(m, name), 'non_conflict'):
hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
else: # conflict
raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
else: # new hook
hooks[name] = getattr(m, name)
hook_origins[name] = mixin_name
if hasattr(self, name):
# if name in hooks: # defined in mixins, can override
# print(f'Override {name} in {hook_origins[name]}...')
......@@ -104,4 +134,4 @@ class BaseModel(torch.nn.Module):
return hooks
def disable_untrainable_params(self):
pass
\ No newline at end of file
pass
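A minimal sketch of how the @non_conflict chaining in collect_hooks_ plays out, assuming only the imports shown in this file; the two mixins below are made up. The later-registered hook receives the earlier one as old_impl (via functools.partial), and the earliest falls back to standard_attention.

from SwissArmyTransformer.mpu.transformer import standard_attention
from SwissArmyTransformer.model.base_model import BaseMixin, non_conflict

class ScaleQueryMixin(BaseMixin):       # a pre-transformation
    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
        return old_impl(q * 0.5, k, v, mask, dropout_fn, **kw_args)

class ClampOutputMixin(BaseMixin):      # a post-transformation
    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
        out = old_impl(q, k, v, mask, dropout_fn, **kw_args)
        return out.clamp(-10, 10)

# model.add_mixin('scale-q', ScaleQueryMixin())
# model.add_mixin('clamp-out', ClampOutputMixin())
# resulting chain: clamp-out -> scale-q -> standard_attention
# (hook_origins['attention_fn'] records 'clamp-out -> scale-q')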
......@@ -13,43 +13,31 @@ import math
import random
import torch
from .base_model import BaseModel, BaseMixin
from .base_model import BaseModel, BaseMixin, non_conflict
from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
class CachedAutoregressiveMixin(BaseMixin):
def __init__(self):
super().__init__()
def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
attn_module = self.transformer.layers[layer_id].attention
mem = mems[layer_id] if mems is not None else None
mixed_raw_layer = attn_module.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
if mem is not None: # the first time, mem is None
b = mixed_key_layer.shape[0] # might change batch_size
memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
super().__init__()
@non_conflict
def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
**kw_args):
if not cross_attention:
mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
b, nh, seq_len, hidden_size = k.shape
cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
kw_args['output_this_layer']['mem_kv'] = cache_kv
if mem is not None: # the first time, mem is None
# might change batch_size
mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
memk, memv = mem[0], mem[1]
k = torch.cat((memk, k), dim=2)
v = torch.cat((memv, v), dim=2)
return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
# same as training
query_layer = attn_module._transpose_for_scores(mixed_query_layer)
key_layer = attn_module._transpose_for_scores(mixed_key_layer)
value_layer = attn_module._transpose_for_scores(mixed_value_layer)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = attn_module.dense(context_layer)
# new mem this layer
new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
return output, new_mem
class CachedAutoregressiveModel(BaseModel):
def __init__(self, args, transformer=None):
......
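A minimal sketch of one cached decoding step with the rewritten mixin, assuming model, tokens, position_ids, and attention_mask already exist; it mirrors how filling_sequence (earlier in this diff) now collects 'mem_kv' from each layer's output dict instead of a second return value of attention_forward.

from SwissArmyTransformer.model.cached_autoregressive_model import CachedAutoregressiveMixin

model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

mems = None                                              # no cache on the first step
logits, *output_per_layers = model(tokens, position_ids, attention_mask, mems=mems)
mem_kv = [o['mem_kv'] for o in output_per_layers]        # one [b, seq_len, 2 * nh * hn] tensor per layer
# mems = update_mems(mem_kv, mems, max_memory_length=...) as in filling_sequence above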
# -*- encoding: utf-8 -*-
'''
@File : components.py
@Time : 2021/11/23 18:20:22
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
from SwissArmyTransformer.mpu.transformer import standard_attention, LayerNorm
class CrossAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
super(CrossAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
if inner_hidden_size is None:
inner_hidden_size = hidden_size
self.inner_hidden_size = inner_hidden_size
if enc_hidden_size is None:
enc_hidden_size = hidden_size
self.enc_hidden_size = enc_hidden_size
# To keep this easier to understand, model parallelism is temporarily not supported here
world_size = 1
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
# To map encoder outputs
self.kv_linear = torch.nn.Linear(
enc_hidden_size, inner_hidden_size * 2
)
init_method(self.kv_linear.weight)
# To map self
self.q_linear = torch.nn.Linear(
hidden_size, inner_hidden_size
)
init_method(self.q_linear.weight)
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
self.dense = torch.nn.Linear(
inner_hidden_size,
hidden_size,
)
output_layer_init_method(self.dense.weight)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, mask, encoder_outputs, **kw_args):
query_layer = self.q_linear(hidden_states)
key_layer, value_layer = split_tensor_along_last_dim(self.kv_linear(encoder_outputs), 2)
dropout_fn = self.attention_dropout if self.training else None
query_layer = self._transpose_for_scores(query_layer)
key_layer = self._transpose_for_scores(key_layer)
value_layer = self._transpose_for_scores(value_layer)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
if self.training:
output = self.output_dropout(output)
return output
......@@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel):
output_1 = dense_plus(context_layer1)
output = torch.cat((output_0, output_1), dim=1)
return output, None
return output
def disable_untrainable_params(self):
self.transformer.requires_grad_(False)
......
......@@ -14,131 +14,83 @@ import random
import torch
import argparse
from .base_model import BaseModel, BaseMixin
from .common_layers import CrossAttention, LayerNorm
from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
class CrossAttentionMixin(BaseMixin):
def __init__(self, num_layers, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
super().__init__()
self.cross_attentions = torch.nn.ModuleList(
[CrossAttention(
hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
output_layer_init_method=output_layer_init_method
) for layer_id in range(num_layers)]
) # Just copy args
self.cross_lns = torch.nn.ModuleList(
[LayerNorm(hidden_size, 1e-5)
for layer_id in range(num_layers)]
)
def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
layer = self.transformer.layers[layer_id]
encoder_outputs = kw_args['encoder_outputs']
'''
hidden_states: [batch, seq_len, hidden_size]
mask: [(1, 1), seq_len, seq_len]
encoder_outputs: [batch, enc_seq_len, enc_hidden_size]
'''
# Layer norm at the beginning of the transformer layer.
layernorm_output = layer.input_layernorm(hidden_states)
attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args)
# Third LayerNorm
if layer.sandwich_ln:
attention_output = layer.third_layernorm(attention_output)
# Residual connection.
hidden_states = hidden_states + attention_output
# Cross attention.
layernorm_output = self.cross_lns[layer_id](hidden_states)
cross_attn_output = self.cross_attentions[layer_id](
layernorm_output,
torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
encoder_outputs
)
hidden_states = hidden_states + cross_attn_output
# Layer norm post the layer attention.
layernorm_output = layer.post_attention_layernorm(hidden_states)
# MLP.
mlp_output = layer.mlp(layernorm_output, **kw_args)
class EncoderFinalMixin(BaseMixin):
def final_forward(self, logits, **kwargs):
logits = copy_to_model_parallel_region(logits)
return logits
# Fourth LayerNorm
if layer.sandwich_ln:
mlp_output = layer.fourth_layernorm(mlp_output)
output = hidden_states + mlp_output
return output, output_this_layer
class DecoderModel(BaseModel):
def __init__(self, args, transformer=None):
dec_args = argparse.Namespace(**vars(args))
dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn
override_attrs = ['num_layers', 'vocab_size',
'hidden_size', 'num_attention_heads',
'max_sequence_length', 'sandwich_ln' # TODO
]
for name in override_attrs:
dec_attr = getattr(dec_args, 'dec_' + name, None)
if dec_attr is not None: # else use encoder-config
setattr(dec_args, name, dec_attr)
super().__init__(dec_args, transformer=transformer)
self.add_mixin('cross_attention',
CrossAttentionMixin(
dec_args.num_layers,
dec_args.hidden_size, dec_args.num_attention_heads,
dec_args.attention_dropout, dec_args.hidden_dropout,
self.transformer.init_method,
enc_hidden_size=dec_args.enc_hidden_size,
inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None),
output_layer_init_method=self.transformer.output_layer_init_method
)
)
class EncoderDecoderModel(torch.nn.Module):
def __init__(self, args, encoder=None, decoder=None):
def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, parallel_output=False, **kwargs):
super(EncoderDecoderModel, self).__init__()
if encoder is not None:
assert isinstance(encoder, BaseModel)
self.encoder = encoder
else:
self.encoder = BaseModel(args)
self.encoder = BaseModel(args, **kwargs)
self.encoder.add_mixin("final", EncoderFinalMixin())
if decoder is not None:
assert isinstance(decoder, BaseModel)
self.decoder = decoder
else:
self.decoder = DecoderModel(args)
dec_args = argparse.Namespace(**vars(args))
dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn
override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
for name in override_attrs:
dec_attr = getattr(dec_args, 'dec_' + name, None)
if dec_attr is not None: # else use encoder-config
setattr(dec_args, name, dec_attr)
self.decoder = BaseModel(args, is_decoder=True, parallel_output=parallel_output, **kwargs)
self.tie_word_embeddings = tie_word_embeddings
if tie_word_embeddings:
self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
def reinit(self):
self.encoder.reinit()
self.decoder.reinit()
def disable_untrainable_params(self):
self.encoder.disable_untrainable_params()
self.decoder.disable_untrainable_params()
def encode(self, input_ids, position_ids, attention_mask=None, **kw_args):
encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args)
return encoder_outputs
def decode(self, input_ids, position_ids, attention_mask, encoder_outputs,cross_attention_mask=None, **kw_args):
if attention_mask is None:
batch_size, seq_length = input_ids.size()[:2]
seq_ids = torch.arange(seq_length, device=input_ids.device)
attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype)
attention_mask = attention_mask[:, None, :, :]
# If no context, please explicitly pass `encoder_outputs=None`
return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, branch_input=None, **kw_args):
mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype)
enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input, **kw_args)
dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args)
return enc_outputs, dec_outputs, *dec_mems
def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
# Please use self.decoder for auto-regressive generation.
batch_size, seq_length = enc_input_ids.size()[:2]
if enc_attention_mask is None:
enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device)
if cross_attention_mask is None:
cross_attention_mask = enc_attention_mask
encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
return encoder_outputs, decoder_outputs, *mems
@classmethod
def add_model_specific_args(cls, parser):
group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart')
group.add_argument("--dec_num_layers", type=int, default=None)
group.add_argument("--dec_vocab_size", type=int, default=None)
group.add_argument("--dec_hidden_size", type=int, default=None)
group.add_argument("--dec_num_attention_heads", type=int, default=None)
group.add_argument("--dec_max_sequence_length", type=int, default=None)
group.add_argument("--dec_sandwich_ln", action='store_true')
group.add_argument("--dec_inner_hidden_size", type=int, default=None)
return parser
\ No newline at end of file
group.add_argument("--dec-num-layers", type=int, default=None)
group.add_argument("--dec-hidden-size", type=int, default=None)
group.add_argument("--dec-num-attention-heads", type=int, default=None)
group.add_argument("--dec-max-sequence-length", type=int, default=None)
group.add_argument("--dec-inner-hidden-size", type=int, default=None)
group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None)
return parser
from .mlp_head import MLPHeadMixin
from .prompt_tuning import PrefixTuningMixin, PTuningV2Mixin
\ No newline at end of file
# -*- encoding: utf-8 -*-
'''
@File : mlp_head.py
@Time : 2021/12/12 20:44:09
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
class MLPHeadMixin(BaseMixin):
def __init__(self, hidden_size, *output_sizes, bias=True, activation_func=torch.nn.functional.relu, init_mean=0, init_std=0.005):
super().__init__()
self.activation_func = activation_func
last_size = hidden_size
self.layers = torch.nn.ModuleList()
for sz in output_sizes:
this_layer = torch.nn.Linear(last_size, sz, bias=bias)
last_size = sz
torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std)
self.layers.append(this_layer)
def final_forward(self, logits, **kw_args):
for i, layer in enumerate(self.layers):
if i > 0:
logits = self.activation_func(logits)
logits = layer(logits)
return logits
\ No newline at end of file
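A minimal usage sketch for MLPHeadMixin, assuming model and args exist and that the import path matches the finetune __init__ shown above; once registered, its final_forward hook replaces the default vocabulary projection with this small MLP.

from SwissArmyTransformer.model.finetune import MLPHeadMixin  # import path assumed

# hidden_size -> 2048 -> 1: one score per token, e.g. read off a pooled position for classification
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))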
# -*- encoding: utf-8 -*-
'''
@File : prompt_tuning.py
@Time : 2021/12/12 20:45:18
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
from SwissArmyTransformer.mpu.transformer import standard_attention
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict
class PrefixTuningMixin(BaseMixin):
def __init__(self, num_layers, hidden_size_per_attention_head, num_attention_heads, prefix_len):
super().__init__()
self.prefix = torch.nn.ParameterList([
torch.nn.Parameter(torch.randn(2, num_attention_heads, prefix_len, hidden_size_per_attention_head)*0.01)
for layer_id in range(num_layers)
])
self.prefix_len = prefix_len
@non_conflict
def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
prefix_k, prefix_v = self.prefix[kw_args['layer_id']]
b, nh, seq_len, hidden_size = k.shape
prefix_k = prefix_k.unsqueeze(0).expand(b, nh, -1, hidden_size)
prefix_v = prefix_v.unsqueeze(0).expand(b, nh, -1, hidden_size)
k = torch.cat((k, prefix_k), dim=2)
v = torch.cat((v, prefix_v), dim=2)
if mask.numel() > 1:
mask_prefixed = torch.ones(self.prefix_len, device=mask.device, dtype=mask.dtype)
mask_prefixed = mask_prefixed.expand(*(mask.size()[:-1]), -1)
mask = torch.cat((mask, mask_prefixed), dim=-1)
return old_impl(q, k, v, mask, dropout_fn, **kw_args)
PTuningV2Mixin = PrefixTuningMixin
\ No newline at end of file
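A minimal usage sketch for PrefixTuningMixin, assuming model and args exist and the same import path as above; the learned prefix keys/values are concatenated inside attention_fn, so the inputs and position ids stay untouched.

from SwissArmyTransformer.model.finetune import PrefixTuningMixin  # import path assumed

model.transformer.requires_grad_(False)          # freeze the backbone, train only the prefixes
model.add_mixin('prefix-tuning', PrefixTuningMixin(
    args.num_layers,
    args.hidden_size // args.num_attention_heads,  # per-head hidden size (assuming the default split)
    args.num_attention_heads,
    prefix_len=16,
))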
......@@ -17,41 +17,45 @@ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
from SwissArmyTransformer.mpu.transformer import unscaled_init_method
from .base_model import BaseMixin
from .cached_autoregressive_model import CachedAutoregressiveMixin
from .finetune import *
class PositionEmbeddingMixin(BaseMixin):
def __init__(self, additional_sequence_length, hidden_size,
init_method_std=0.02, reinit_slice=slice(-1024, None)
):
def __init__(self, additional_sequence_length, hidden_size,
init_method_std=0.02, reinit_slice=slice(-1024, None)
):
super(PositionEmbeddingMixin, self).__init__()
self.reinit_slice = reinit_slice
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
def reinit(self, *pre_mixins):
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
old_len, hidden_size = old_weights.shape
assert hidden_size == self.position_embeddings.weight.shape[-1]
self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
class AttentionMixin(BaseMixin):
def __init__(self, num_layers,
hidden_size,
init_method=unscaled_init_method(0.02),
output_layer_init_method=unscaled_init_method(0.02)
):
hidden_size,
init_method=unscaled_init_method(0.02),
output_layer_init_method=unscaled_init_method(0.02)
):
super(AttentionMixin, self).__init__()
self.num_layers = num_layers # replace attention in the LAST n layers
self.num_layers = num_layers # replace attention in the LAST n layers
self.query_key_value = torch.nn.ModuleList(
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
gather_output=False,init_method=init_method)
for layer_id in range(num_layers)
])
[ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
gather_output=False, init_method=init_method)
for layer_id in range(num_layers)
])
self.dense = torch.nn.ModuleList(
[RowParallelLinear(hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
for layer_id in range(num_layers)
])
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
for layer_id in range(num_layers)
])
def reinit(self, *pre_mixins):
start_layer = len(self.transformer.layers) - self.num_layers
assert start_layer >= 0
......
import math
import torch
import torch.nn.functional as F
from .mixins import BaseMixin
from .encoder_decoder_model import EncoderDecoderModel
from .base_model import non_conflict
from SwissArmyTransformer.mpu import get_model_parallel_world_size
from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding
class T5PositionEmbeddingMixin(BaseMixin):
def position_embedding_forward(self, position_ids, **kw_args):
return None
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style: no bias and no subtraction of mean.
"""
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# layer norm should always be calculated in float32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into float16 or bfloat16 if necessary
if self.weight.dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
elif self.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
return self.weight * hidden_states
class T5AttentionMixin(BaseMixin):
def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False):
super().__init__()
self.relative_attention_num_buckets = relative_attention_num_buckets
world_size = get_model_parallel_world_size()
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets,
self.num_attention_heads_per_partition)
self.is_decoder = is_decoder
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
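# Worked example (a sketch, not from the original code): with bidirectional=True,
# num_buckets=32 and max_distance=128, num_buckets is first halved to 16 and max_exact = 8.
#   relative_position = -1  -> |rp| = 1 < 8, exact bucket 1
#   relative_position = +1  -> the positive side adds 16, so bucket 17
#   relative_position = -20 -> 8 + floor(log(20/8) / log(128/8) * 8) = 8 + 2 = 10
#   relative_position = +20 -> 16 + 10 = 26
# Any |rp| >= 128 saturates at bucket 15 (31 on the positive side).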
def compute_bias(self, query_length, key_length):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
# shape (query_length, key_length, num_heads)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
@non_conflict
def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
cross_attention=False, **kw_args):
log_attention_weights = None
if not cross_attention:
if position_bias is None:
seq_length = q.size(2)
key_length = k.size(2)
position_bias = self.compute_bias(key_length, key_length)
position_bias = position_bias[:, :, -seq_length:, :]
kw_args['output_cross_layer']['position_bias'] = position_bias
log_attention_weights = position_bias
return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
class T5DecoderFinalMixin(BaseMixin):
def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
super().__init__()
self.hidden_size = hidden_size
self.tie_word_embeddings = tie_word_embeddings
if not tie_word_embeddings:
self.lm_head = VocabParallelEmbedding(
vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
def final_forward(self, logits, **kwargs):
logits_parallel = copy_to_model_parallel_region(logits)
if self.tie_word_embeddings:
logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
else:
logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
return logits_parallel
def t5_gelu(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5GatedGeluMLPMixin(BaseMixin):
def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
super().__init__()
self.hidden_size = hidden_size
if inner_hidden_size is None:
inner_hidden_size = 4 * hidden_size
self.inner_hidden_size = inner_hidden_size
self.init_method_std = init_method_std
self.gated_h_to_4h_list = torch.nn.ModuleList([
ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=self._init_weights,
bias=bias,
module=self,
name="gated_h_to_4h"
)
for layer_id in range(num_layers)])
def _init_weights(self, weight, **kwargs):
torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
mlp_module = self.transformer.layers[layer_id].mlp
hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
hidden_states = hidden_gelu * hidden_linear
output = mlp_module.dense_4h_to_h(hidden_states)
if self.training:
output = mlp_module.dropout(output)
return output
class T5Model(EncoderDecoderModel):
def __init__(self, args, **kwargs):
self.init_method_std = args.init_method_std
super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu,
init_method=self._init_weights)
self.encoder.add_mixin(
"t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
)
self.encoder.add_mixin(
"t5-position", T5PositionEmbeddingMixin()
)
del self.encoder.transformer.position_embeddings
num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None else args.num_attention_heads
self.decoder.add_mixin(
"t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads, is_decoder=True)
)
self.decoder.add_mixin(
"t5-position", T5PositionEmbeddingMixin()
)
self.decoder.add_mixin(
"t5-final",
T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
)
del self.decoder.transformer.position_embeddings
if args.gated_gelu_mlp:
self.encoder.add_mixin(
"gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
inner_hidden_size=args.inner_hidden_size, bias=False)
)
self.decoder.add_mixin(
"gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
inner_hidden_size=args.inner_hidden_size, bias=False)
)
def _init_weights(self, weight, module, name):
init_method_std = self.init_method_std
if isinstance(module, MLP):
if name == "dense_h_to_4h":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
elif name == "dense_4h_to_h":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
else:
raise NotImplementedError(name)
elif isinstance(module, SelfAttention):
if name == "query_key_value":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * (
(module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
elif name == "dense":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
else:
raise NotImplementedError(name)
elif isinstance(module, CrossAttention):
if name == "query":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (
(module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
elif name == "key_value":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
elif name == "dense":
torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
else:
raise NotImplementedError(name)
else:
raise NotImplementedError(module)
@classmethod
def add_model_specific_args(cls, parser):
super().add_model_specific_args(parser)
parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
parser.add_argument("--init-method-std", type=float, default=0.02)
parser.add_argument("--gated-gelu-mlp", action='store_true')
parser.add_argument("--no-share-embeddings", action='store_true')
def encode(self, input_ids, attention_mask=None, **kw_args):
return super().encode(input_ids, None, attention_mask, **kw_args)
def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args):
return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
cross_attention_mask=cross_attention_mask, **kw_args)
def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
cross_attention_mask=None, **kw_args):
batch_size, seq_length = enc_input_ids.size()[:2]
if enc_attention_mask is None:
enc_attention_mask = torch.ones(1, 1, 1, seq_length,
dtype=self.encoder.transformer.word_embeddings.weight.dtype,
device=enc_input_ids.device)
if cross_attention_mask is None:
cross_attention_mask = enc_attention_mask
encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args)
decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask,
encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask,
**kw_args)
return encoder_outputs, decoder_outputs, *mems
......@@ -37,7 +37,7 @@ from .utils import VocabUtility
def _initialize_affine_weight(weight, output_size, input_size,
per_partition_size, partition_dim, init_method,
stride=1, return_master_weight=False):
stride=1, return_master_weight=False, module=None, name=None):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
......@@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
# If we only use 1 process for model parallelism, bypass scatter.
world_size = get_model_parallel_world_size()
if world_size == 1:
init_method(weight)
init_method(weight, module=module, name=name)
if return_master_weight:
return weight
return None
......@@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
master_weight = torch.empty(output_size, input_size,
dtype=weight.dtype,
requires_grad=False)
init_method(master_weight)
init_method(master_weight, module=module, name=name)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
......@@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module):
"""
def __init__(self, input_size, output_size, bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
keep_master_weight_for_test=False, module=None, name=None):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
......@@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
def forward(self, input_):
# Set up backprop all-reduce.
......@@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module):
def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
keep_master_weight_for_test=False, module=None, name=None):
super(RowParallelLinear, self).__init__()
# Keep input parameters
......@@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module):
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
def forward(self, input_):
# Set up backprop all-reduce.
......
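A sketch of an init_method that uses the module= and name= keywords now threaded through _initialize_affine_weight, in the spirit of T5Model._init_weights above; the scaling rule here is purely illustrative.

import torch

def per_name_init(weight, module=None, name=None):
    # give output projections a smaller std than input projections (illustrative rule only)
    std = 0.01 if name in ("dense", "dense_4h_to_h") else 0.02
    torch.nn.init.normal_(weight, mean=0.0, std=std)

# e.g. ColumnParallelLinear(1024, 4096, init_method=per_name_init, module=some_mlp, name="dense_h_to_4h")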
......@@ -22,55 +22,62 @@ import torch
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm
from SwissArmyTransformer import mpu
from .initialize import get_model_parallel_world_size
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
from .utils import split_tensor_along_last_dim
class LayerNorm(FusedLayerNorm):
def __init__(self, *args, pb_relax=False, **kwargs):
super().__init__(*args, **kwargs)
self.pb_relax = pb_relax
def forward(self, x):
if not self.pb_relax:
return super().forward(x)
return super().forward(x / (x.abs().max().detach()/8))
return super().forward(x / (x.abs().max().detach() / 8))
def standard_attention(query_layer, key_layer, value_layer, attention_mask,
attention_dropout=None, log_attention_weights=None):
attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
# We disable PB-relax-Attention and only change the order of computation, because it is enough for most training.
# The implementation in the paper can be added easily if you really need it to train very deep transformers.
attention_scores = torch.matmul(
query_layer / math.sqrt(query_layer.shape[-1]),
key_layer.transpose(-1, -2)
)
if scaling_attention_score:
query_layer = query_layer / math.sqrt(query_layer.shape[-1])
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if log_attention_weights is not None:
attention_scores += log_attention_weights
if not(attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
# if auto-regressive, skip
attention_scores = torch.mul(attention_scores, attention_mask) - \
10000.0 * (1.0 - attention_mask)
10000.0 * (1.0 - attention_mask)
attention_probs = F.softmax(attention_scores, dim=-1)
if attention_dropout is not None:
with get_cuda_rng_tracker().fork():
if mpu.get_cuda_rng_tracker is not None:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = attention_dropout(attention_probs)
else:
attention_probs = attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
return context_layer
class SelfAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, layer_id, output_layer_init_method=None,
hooks={}):
attention_dropout_prob, output_dropout_prob,
init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True,
hooks={}):
super(SelfAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
......@@ -79,47 +86,61 @@ class SelfAttention(torch.nn.Module):
self.layer_id = layer_id
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
self.hidden_size = hidden_size
if hidden_size_per_attention_head is None:
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
else:
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(
hidden_size,
3*hidden_size,
3 * self.inner_hidden_size,
stride=3,
gather_output=False,
init_method=init_method
init_method=init_method,
bias=bias,
module=self,
name="query_key_value"
)
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
self.dense = RowParallelLinear(
hidden_size,
self.inner_hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method
init_method=output_layer_init_method,
bias=bias,
module=self,
name="dense"
)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, mask, **kw_args):
def forward(self, hidden_states, mask, *args, **kw_args):
if 'attention_forward' in self.hooks:
return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
else:
attention_fn = standard_attention
if 'attention_fn' in self.hooks:
attention_fn = self.hooks['attention_fn']
mixed_raw_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
dropout_fn = self.attention_dropout if self.training else None
......@@ -127,7 +148,8 @@ class SelfAttention(torch.nn.Module):
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
......@@ -136,40 +158,137 @@ class SelfAttention(torch.nn.Module):
if self.training:
output = self.output_dropout(output)
return output, None
return output
class CrossAttention(torch.nn.Module):
"""Parallel cross-attention layer for Transformer"""
def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, hooks={}):
super().__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
self.layer_id = layer_id
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size = hidden_size
if hidden_size_per_attention_head is None:
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
else:
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
# Strided linear layer.
self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size,
gather_output=False,
init_method=init_method, bias=bias, module=self, name="query")
self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size,
stride=2,
gather_output=False,
init_method=init_method, bias=bias, module=self, name="key_value")
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(
self.inner_hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method, bias=bias, module=self, name="dense")
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
# hidden_states: [b, s, h]
if 'cross_attention_forward' in self.hooks:
return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
else:
attention_fn = standard_attention
if 'attention_fn' in self.hooks:
attention_fn = self.hooks['attention_fn']
mixed_query_layer = self.query(hidden_states)
mixed_x_layer = self.key_value(encoder_outputs)
(mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
dropout_fn = self.attention_dropout if self.training else None
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
cross_attention=True, **kw_args)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
output = self.dense(context_layer)
if self.training:
output = self.output_dropout(output)
return output
class MLP(torch.nn.Module):
def __init__(self, hidden_size, output_dropout_prob, init_method,
output_layer_init_method=None, layer_id=None, hooks={}):
def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
output_layer_init_method=None, layer_id=None, hooks={}, bias=True, activation_func=gelu):
super(MLP, self).__init__()
self.layer_id = layer_id
self.activation_func = activation_func
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
# Project to 4h.
self.hidden_size = hidden_size
if inner_hidden_size is None:
inner_hidden_size = 4 * hidden_size
self.inner_hidden_size = inner_hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4*hidden_size,
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method
init_method=init_method,
bias=bias,
module=self,
name="dense_h_to_4h"
)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
4*hidden_size,
hidden_size,
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method
init_method=output_layer_init_method,
bias=bias,
module=self,
name="dense_4h_to_h"
)
self.dropout = torch.nn.Dropout(output_dropout_prob)
def forward(self, hidden_states, **kw_args):
if 'mlp_forward' in self.hooks:
output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
output = self.hooks['mlp_forward'](hidden_states, **kw_args)
else:
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = gelu(intermediate_parallel)
intermediate_parallel = self.activation_func(intermediate_parallel)
output = self.dense_4h_to_h(intermediate_parallel)
if self.training:
......@@ -179,27 +298,34 @@ class MLP(torch.nn.Module):
class BaseTransformerLayer(torch.nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
layer_id,
output_layer_init_method=None,
sandwich_ln=True,
hooks={}
self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
output_layer_init_method=None,
sandwich_ln=True,
layernorm=LayerNorm,
is_decoder=False,
use_bias=True,
activation_func=gelu,
hooks={}
):
super(BaseTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.layer_id = layer_id
self.is_decoder = is_decoder
self.hooks = hooks
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# Self attention.
self.attention = SelfAttention(
......@@ -209,37 +335,57 @@ class BaseTransformerLayer(torch.nn.Module):
output_dropout_prob,
init_method,
layer_id,
hidden_size_per_attention_head=hidden_size_per_attention_head,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
hooks=hooks
)
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
self.sandwich_ln = sandwich_ln
if sandwich_ln:
self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# Cross attention.
if self.is_decoder:
self.cross_attention = CrossAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
layer_id,
hidden_size_per_attention_head=hidden_size_per_attention_head,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
hooks=hooks
)
self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# MLP
self.mlp = MLP(
hidden_size,
output_dropout_prob,
init_method,
inner_hidden_size=inner_hidden_size,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
layer_id=layer_id,
activation_func=activation_func,
hooks=hooks
)
def forward(self, hidden_states, mask, **kw_args):
def forward(self, hidden_states, mask, *args, **kw_args):
'''
hidden_states: [batch, seq_len, hidden_size]
mask: [(1, 1), seq_len, seq_len]
'''
# Layer norm at the beginning of the transformer layer.
layernorm_output1 = self.input_layernorm(hidden_states)
# Self attention.
attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
attention_output = self.attention(layernorm_output1, mask, **kw_args)
# Third LayerNorm
if self.sandwich_ln:
......@@ -249,6 +395,18 @@ class BaseTransformerLayer(torch.nn.Module):
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.is_decoder:
encoder_outputs = kw_args['encoder_outputs']
if encoder_outputs is not None:
assert 'cross_attention_mask' in kw_args
# Cross attention
attention_output = self.cross_attention(layernorm_output, **kw_args)
# Residual connection.
layernorm_input = layernorm_input + attention_output
# Layer norm post the cross attention
layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output, **kw_args)
......@@ -259,7 +417,8 @@ class BaseTransformerLayer(torch.nn.Module):
# Second residual connection.
output = layernorm_input + mlp_output
return output, output_this_layer # temporarily, output_this_layer is only from attention
return output, kw_args['output_this_layer'], kw_args['output_cross_layer']
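A layer (or a 'layer_forward' hook) may now return a bare tensor, a (hidden_states, output_this_layer) pair, or a (hidden_states, output_this_layer, output_cross_layer) triple; the third element is merged into the keyword arguments passed to the next layer. A hedged sketch of a hook that follows this contract (the computation and dict keys are illustrative):

def my_layer_forward(hidden_states, mask, layer_id=None,
                     output_this_layer=None, output_cross_layer=None, **kw_args):
    new_hidden_states = hidden_states                     # placeholder for the real per-layer computation
    output_this_layer = {'layer_id': layer_id}            # collected into the model's per-layer outputs
    output_cross_layer = {'memory': new_hidden_states}    # forwarded as kw_args to the next layer
    return new_hidden_states, output_this_layer, output_cross_layer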
class BaseTransformer(torch.nn.Module):
def __init__(self,
......@@ -275,18 +434,26 @@ class BaseTransformer(torch.nn.Module):
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
sandwich_ln=True,
parallel_output=True,
is_decoder=False,
use_bias=True,
activation_func=gelu,
layernorm=LayerNorm,
init_method=None,
hooks={}
):
super(BaseTransformer, self).__init__()
# recording parameters
self.is_decoder = is_decoder
self.parallel_output = parallel_output
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.max_sequence_length = max_sequence_length
self.hooks = copy.copy(hooks) # hooks will be updated each forward
self.hooks = copy.copy(hooks) # hooks will be updated each forward
# create embedding parameters
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
......@@ -298,8 +465,13 @@ class BaseTransformer(torch.nn.Module):
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
# create all layers
self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
if init_method is None:
self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
else:
self.output_layer_init_method = init_method
self.init_method = init_method
def get_layer(layer_id):
return BaseTransformerLayer(
hidden_size,
......@@ -309,32 +481,39 @@ class BaseTransformer(torch.nn.Module):
layernorm_epsilon,
self.init_method,
layer_id,
inner_hidden_size=inner_hidden_size,
hidden_size_per_attention_head=hidden_size_per_attention_head,
output_layer_init_method=self.output_layer_init_method,
is_decoder=self.is_decoder,
sandwich_ln=sandwich_ln,
layernorm=layernorm,
use_bias=use_bias,
activation_func=activation_func,
hooks=self.hooks
)
)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
**kw_args):
# sanity check
def forward(self, input_ids, position_ids, attention_mask, *,
output_hidden_states=False, **kw_args):
# sanity check
assert len(input_ids.shape) == 2
batch_size, query_length = input_ids.shape
if attention_mask is None:
attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
next(self.parameters())
) # None means full attention
assert len(attention_mask.shape) == 2 or \
len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
# branch_input is a new part of input need layer-by-layer update,
# but with different hidden_dim and computational routine.
# In most cases, you can just ignore it.
len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
# embedding part
if 'word_embedding_forward' in self.hooks:
hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
else: # default
else: # default
hidden_states = self.word_embeddings(input_ids)
if 'position_embedding_forward' in self.hooks:
......@@ -343,62 +522,132 @@ class BaseTransformer(torch.nn.Module):
assert len(position_ids.shape) <= 2
assert position_ids.shape[-1] == query_length
position_embeddings = self.position_embeddings(position_ids)
hidden_states = hidden_states + position_embeddings
if position_embeddings is not None:
hidden_states = hidden_states + position_embeddings
hidden_states = self.embedding_dropout(hidden_states)
hidden_states_outputs = [hidden_states] if output_hidden_states else []
# branch related embedding
if branch_input is None and 'branch_embedding_forward' in self.hooks:
branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
# initial output_cross_layer
if 'cross_layer_embedding_forward' in self.hooks:
output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
else:
output_cross_layer = {}
# define custom_forward for checkpointing
output_per_layers = []
if self.checkpoint_activations:
def custom(start, end):
# define custom_forward for checkpointing
def custom(start, end, kw_args_index, cross_layer_index):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_, mask = inputs[0], inputs[1]
if len(inputs) > 2: # have branch_input
branch_ = inputs[2]
# recover kw_args and output_cross_layer
flat_inputs = inputs[2:]
kw_args, output_cross_layer = {}, {}
for k, idx in kw_args_index.items():
kw_args[k] = flat_inputs[idx]
for k, idx in cross_layer_index.items():
output_cross_layer[k] = flat_inputs[idx]
# -----------------
output_per_layers_part = []
for i, layer in enumerate(layers_):
if len(inputs) > 2:
x_, branch_, output_this_layer = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
)
elif 'layer_forward' in self.hooks:
x_, output_this_layer = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id, **kw_args
if 'layer_forward' in self.hooks:
layer_ret = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id,
**kw_args, **output_cross_layer,
output_this_layer={}, output_cross_layer={}
)
else:
x_, output_this_layer = layer(x_, mask, **kw_args)
layer_ret = layer(
x_, mask, layer_id=layer.layer_id,
**kw_args, **output_cross_layer,
output_this_layer={}, output_cross_layer={}
)
if torch.is_tensor(layer_ret): # only hidden_states
x_, output_this_layer, output_cross_layer = layer_ret, {}, {}
elif len(layer_ret) == 2: # hidden_states & output_this_layer
x_, output_this_layer = layer_ret
output_cross_layer = {}
elif len(layer_ret) == 3:
x_, output_this_layer, output_cross_layer = layer_ret
assert isinstance(output_this_layer, dict)
assert isinstance(output_cross_layer, dict)
if output_hidden_states:
output_this_layer['hidden_states'] = x_
output_per_layers_part.append(output_this_layer)
return x_, output_per_layers_part
# flatten for re-aggregate keywords outputs
flat_outputs = []
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
# TODO add warning for depth>=2 grad tensors
flat_outputs.append(output_this_layer[k])
output_this_layer[k] = len(flat_outputs) - 1
for k in output_cross_layer:
flat_outputs.append(output_cross_layer[k])
output_cross_layer[k] = len(flat_outputs) - 1
# --------------------
return x_, output_per_layers_part, output_cross_layer, flat_outputs
return custom_forward
# prevent losing requires_grad in checkpointing.
# To save memory when only finetuning the final layers, don't use checkpointing.
if self.training:
hidden_states.requires_grad_(True)
l, num_layers = 0, len(self.layers)
chunk_length = self.checkpoint_num_layers
output_this_layer = []
while l < num_layers:
args = [hidden_states, attention_mask]
if branch_input is not None:
hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
else:
hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
if output_hidden_states:
hidden_states_outputs.append(hidden_states)
# flatten kw_args and output_cross_layer
flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
for k, v in kw_args.items():
flat_inputs.append(v)
kw_args_index[k] = len(flat_inputs) - 1
for k, v in output_cross_layer.items():
flat_inputs.append(v)
cross_layer_index[k] = len(flat_inputs) - 1
# --------------------
hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \
checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
# recover output_per_layers_part, output_cross_layer
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
output_this_layer[k] = flat_outputs[output_this_layer[k]]
for k in output_cross_layer:
output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
# --------------------
output_per_layers.extend(output_per_layers_part)
l += chunk_length
else:
output_this_layer = []
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
if branch_input is not None: # customized layer_forward with branch_input
hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
elif 'layer_forward' in self.hooks: # customized layer_forward
hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
if 'layer_forward' in self.hooks: # customized layer_forward
layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
**kw_args,
**output_cross_layer,
output_this_layer={}, output_cross_layer={}
)
else:
hidden_states, output_this_layer = layer(*args, **kw_args)
layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
output_this_layer={}, output_cross_layer={})
if torch.is_tensor(layer_ret): # only hidden_states
hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
elif len(layer_ret) == 2: # hidden_states & output_this_layer
hidden_states, output_this_layer = layer_ret
output_cross_layer = {}
elif len(layer_ret) == 3:
hidden_states, output_this_layer, output_cross_layer = layer_ret
if output_hidden_states:
hidden_states_outputs.append(hidden_states)
output_this_layer['hidden_states'] = hidden_states
output_per_layers.append(output_this_layer)
# Final layer norm.
......@@ -410,19 +659,10 @@ class BaseTransformer(torch.nn.Module):
logits_parallel = copy_to_model_parallel_region(logits)
logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
# branch related embedding
if branch_input is None and 'branch_final_forward' in self.hooks:
branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
if not self.parallel_output:
logits_parallel = gather_from_model_parallel_region(logits_parallel)
outputs = [logits_parallel]
if branch_input is not None:
outputs.append(branch_input)
if output_hidden_states:
outputs.append(hidden_states_outputs)
outputs.extend(output_per_layers)
return outputs
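A hedged usage sketch of the forward interface after these changes, assuming `model` is an already-constructed BaseTransformer and the inputs are the usual [batch, seq_len] integer tensors; the exact packing of the optional entries depends on the flags and hooks in use:

outputs = model(input_ids, position_ids, attention_mask, output_hidden_states=True)
logits = outputs[0]       # [batch, seq_len, vocab]; gathered when parallel_output is False
extras = outputs[1:]      # remaining entries: per-layer output dicts (hook outputs, hidden states, ...)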
......@@ -75,7 +75,7 @@ def sqrt(x):
def unscaled_init_method(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
def init_(tensor, **kwargs):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
......@@ -83,7 +83,7 @@ def unscaled_init_method(sigma):
def scaled_init_method(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
def init_(tensor, **kwargs):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
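Both helpers now accept **kwargs so callers can pass extra metadata (e.g. module/name) without breaking them. A small illustrative check of their effect, assuming the two functions are imported from this module:

import math
import torch

sigma, num_layers = 0.02, 24
w = torch.empty(1024, 1024)
unscaled_init_method(sigma)(w)            # fills w from N(0, 0.02)
scaled_init_method(sigma, num_layers)(w)  # fills w from N(0, 0.02 / sqrt(2 * 24)) ~= N(0, 0.0029)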
......
......@@ -29,7 +29,8 @@ def _export_vocab_size_to_args(args, original_num_tokens):
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
before, after - before, after))
args.vocab_size = after
if not args.vocab_size:
args.vocab_size = after
print_rank_0("prepare tokenizer done")
return tokenizer
......@@ -51,7 +52,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
from .cogview import UnifiedTokenizer
get_tokenizer.tokenizer = UnifiedTokenizer(
args.img_tokenizer_path,
# txt_tokenizer_type=args.tokenizer_type,
txt_tokenizer_type='cogview',
device=torch.cuda.current_device()
)
elif args.tokenizer_type.startswith('glm'):
......@@ -63,6 +64,10 @@ def get_tokenizer(args=None, outer_tokenizer=None):
elif args.tokenizer_type == "glm_ChineseSPTokenizer":
from .glm import ChineseSPTokenizer
get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
elif args.tokenizer_type.startswith('hf'):
from .hf_tokenizer import HFT5Tokenizer
if args.tokenizer_type == "hf_T5Tokenizer":
get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir)
else:
assert args.vocab_size > 0
get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
......
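The new branch dispatches to the Hugging Face T5 tokenizer when args.tokenizer_type starts with 'hf'. A hedged sketch of the arguments involved (attribute names are taken from the code above; the concrete model name and the unset fields are hypothetical, and other fields may be required depending on the code path):

from argparse import Namespace

args = Namespace(
    tokenizer_type='hf_T5Tokenizer',
    tokenizer_model_type='t5-base',  # hypothetical HF checkpoint name
    cache_dir=None,
    vocab_size=0,                    # left to be filled in by the tokenizer setup
)
tokenizer = get_tokenizer(args)      # ends up constructing HFT5Tokenizer(args.tokenizer_model_type, cache_dir=...)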
......@@ -22,6 +22,8 @@ python setup.py install
PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
'embed_assets', 'chinese_sentencepiece/cog-pretrain.model')
PRETRAINED_MODEL_FILE_ICE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
'embed_assets', 'chinese_sentencepiece/ice.model') # merge xlnet 3,2000 En tokens
def get_pairs(word):
......@@ -148,5 +150,10 @@ def get_encoder(encoder_file, bpe_file):
)
def from_pretrained():
return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
def from_pretrained(tokenizer_type='cogview'):
if tokenizer_type == 'cogview_ICE':
return get_encoder(PRETRAINED_MODEL_FILE_ICE, "")
elif tokenizer_type == 'cogview':
return get_encoder(PRETRAINED_MODEL_FILE, "")
else:
raise ValueError('Unknown cogview tokenizer.')
\ No newline at end of file