diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py index dfa75c910489d592262a7483c6b0cfa2d901660b..83e3322219a3d8f17dd4f1ced5e78223744ca91d 100755 --- a/SwissArmyTransformer/arguments.py +++ b/SwissArmyTransformer/arguments.py @@ -33,6 +33,8 @@ def add_model_config_args(parser): help='num of transformer attention heads') group.add_argument('--hidden-size', type=int, default=1024, help='tansformer hidden size') + group.add_argument('--inner-hidden-size', type=int, default=None) + group.add_argument('--hidden-size-per-attention-head', type=int, default=None) group.add_argument('--num-layers', type=int, default=24, help='num decoder layers') group.add_argument('--layernorm-epsilon', type=float, default=1e-5, @@ -129,6 +131,8 @@ def add_training_args(parser): group.add_argument('--fp16', action='store_true', help='Run model in fp16 mode') + group.add_argument('--bf16', action='store_true', + help='Run model in bf16 mode') return parser @@ -146,7 +150,8 @@ def add_evaluation_args(parser): 'validation/test for') group.add_argument('--eval-interval', type=int, default=1000, help='interval between running evaluation on validation set') - + group.add_argument('--strict-eval', action='store_true', + help='do not enlarge or randomly map eval-data; evaluate on the full eval-data.') return parser @@ -297,25 +302,28 @@ def get_args(args_list=None): print('using world size: {} and model-parallel size: {} '.format( args.world_size, args.model_parallel_size)) - if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None: - with open(args.deepspeed_config) as file: - deepspeed_config = json.load(file) - if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]: - args.fp16 = True - else: - args.fp16 = False + if hasattr(args, "deepspeed") and args.deepspeed: if args.checkpoint_activations: args.deepspeed_activation_checkpointing = True - if "train_micro_batch_size_per_gpu" in deepspeed_config: - args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"] - if "gradient_accumulation_steps" in deepspeed_config: - args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"] else: - args.gradient_accumulation_steps = None - if "optimizer" in deepspeed_config: - optimizer_params_config = deepspeed_config["optimizer"].get("params", {}) - args.lr = optimizer_params_config.get("lr", args.lr) - args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay) + args.deepspeed_activation_checkpointing = False + if args.deepspeed_config is not None: + with open(args.deepspeed_config) as file: + deepspeed_config = json.load(file) + if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]: + args.fp16 = True + else: + args.fp16 = False + if "train_micro_batch_size_per_gpu" in deepspeed_config: + args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"] + if "gradient_accumulation_steps" in deepspeed_config: + args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"] + else: + args.gradient_accumulation_steps = None + if "optimizer" in deepspeed_config: + optimizer_params_config = deepspeed_config["optimizer"].get("params", {}) + args.lr = optimizer_params_config.get("lr", args.lr) + args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay) return args diff --git a/SwissArmyTransformer/data_utils/configure_data.py b/SwissArmyTransformer/data_utils/configure_data.py index
b7d6299c69ef5d14030876fb6462b89b2daec654..5d53cded92c2264a82021ffb16912d8be8a45ee9 100755 --- a/SwissArmyTransformer/data_utils/configure_data.py +++ b/SwissArmyTransformer/data_utils/configure_data.py @@ -24,14 +24,17 @@ from .samplers import DistributedBatchSampler from SwissArmyTransformer import mpu -def make_data_loader(dataset, batch_size, num_iters, args): +def make_data_loader(dataset, batch_size, args): world_size = torch.distributed.get_world_size( group=mpu.get_data_parallel_group()) rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) distributed = world_size > 1 sampler = torch.utils.data.SequentialSampler(dataset) - drop_last = distributed + # drop_last = distributed + drop_last = True # TODO will always drop last to keep the consistency. + # or, how to avg in eval last batch? + # the GPUs in the same model parallel group receive the same data if distributed: # TODO reformat this, but it is not urgent gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1) @@ -52,7 +55,7 @@ def make_data_loader(dataset, batch_size, num_iters, args): return data_loader -def make_dataset_full(path, split, args, create_dataset_function, **kwargs): +def make_dataset_full(path, split, args, create_dataset_function, random_mapping=True, **kwargs): """function to create datasets+tokenizers for common options""" print('make dataset ...', path) if split is None: @@ -67,12 +70,8 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs): ds = ConcatDataset(ds) if should_split(split): ds = split_ds(ds, split, block_size=args.block_size) - else: + elif random_mapping: ds = RandomMappingDataset(ds) - - # if should_split(split): - # ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping - # FIXME this will merge valid set and train set. 
return ds def make_loaders(args, create_dataset_function): @@ -115,25 +114,25 @@ def make_loaders(args, create_dataset_function): # make training and val dataset if necessary if valid is None and args.valid_data is not None: eval_set_args['path'] = args.valid_data - valid = make_dataset(**eval_set_args, args=args) + valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval) if test is None and args.test_data is not None: eval_set_args['path'] = args.test_data - test = make_dataset(**eval_set_args, args=args) + test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval) # wrap datasets with data loader if train is not None and args.batch_size > 0: - train = make_data_loader(train, batch_size, args.train_iters, args) + train = make_data_loader(train, batch_size, args) args.do_train = True else: args.do_train = False eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size if valid is not None: - valid = make_data_loader(valid, eval_batch_size, args.train_iters, args) + valid = make_data_loader(valid, eval_batch_size, args) args.do_valid = True else: args.do_valid = False if test is not None: - test = make_data_loader(test, eval_batch_size, len(test) // eval_batch_size + 1, args) + test = make_data_loader(test, eval_batch_size, args) args.do_test = True else: args.do_test = False diff --git a/SwissArmyTransformer/data_utils/datasets.py b/SwissArmyTransformer/data_utils/datasets.py index bec1207112194e70221c8f52852f2c9855ed951e..ef46e2da7bc185465320792a89b6c2804dce1ab6 100755 --- a/SwissArmyTransformer/data_utils/datasets.py +++ b/SwissArmyTransformer/data_utils/datasets.py @@ -65,3 +65,18 @@ class BinaryDataset(Dataset): def __getitem__(self, index): return self.process_fn(self.bin[index]) +class TSVDataset(Dataset): + def __init__(self, path, process_fn, with_heads=True, **kwargs): + self.process_fn = process_fn + with open(path, 'r') as fin: + if with_heads: + self.heads = fin.readline().split('\t') + else: + self.heads = None + self.items = [line.split('\t') for line in fin] + + def __len__(self): + return len(self.items) + + def __getitem__(self, index): + return self.process_fn(self.items[index]) diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py index 37a24a88a004bfff63a87ebdcf3c909b29256a20..eb5acf42c303db3340e64249108f7f8e697293b4 100644 --- a/SwissArmyTransformer/generation/autoregressive_sampling.py +++ b/SwissArmyTransformer/generation/autoregressive_sampling.py @@ -56,7 +56,8 @@ def filling_sequence( max_memory_length=100000, log_attention_weights=None, get_masks_and_position_ids=get_masks_and_position_ids_default, - mems=None + mems=None, + **kw_args ): ''' seq: [2, 3, 5, ..., -1(to be generated), -1, ...] 
@@ -99,13 +100,15 @@ def filling_sequence( else: log_attention_weights_part = None - logits, *mem_kv = model( + logits, *output_per_layers = model( tokens[:, index:], position_ids[..., index: counter+1], attention_mask[..., index: counter+1, :counter+1], # TODO memlen mems=mems, - log_attention_weights=log_attention_weights_part + log_attention_weights=log_attention_weights_part, + **kw_args ) + mem_kv = [o['mem_kv'] for o in output_per_layers] mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length) counter += 1 index = counter diff --git a/SwissArmyTransformer/model/__init__.py b/SwissArmyTransformer/model/__init__.py index 32f46e4bce6d09b3f4780f9adbef6120960158be..4fbcd53865ddcd51433a352cec079a8562558b5b 100755 --- a/SwissArmyTransformer/model/__init__.py +++ b/SwissArmyTransformer/model/__init__.py @@ -2,4 +2,5 @@ from .base_model import BaseModel from .cached_autoregressive_model import CachedAutoregressiveModel from .cuda2d_model import Cuda2dModel from .glm_model import GLMModel -from .encoder_decoder_model import EncoderDecoderModel \ No newline at end of file +from .encoder_decoder_model import EncoderDecoderModel +from .t5_model import T5Model diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py index c9e1c9017782b073546fa12423690b71888a51b0..76e8c6a9e78527f6ba431c8c428add2245795917 100644 --- a/SwissArmyTransformer/model/base_model.py +++ b/SwissArmyTransformer/model/base_model.py @@ -7,6 +7,7 @@ ''' # here put the import lib +from functools import partial import os import sys import math @@ -14,19 +15,39 @@ import random import torch from SwissArmyTransformer.mpu import BaseTransformer +from SwissArmyTransformer.mpu.transformer import standard_attention + +def non_conflict(func): + func.non_conflict = True + return func class BaseMixin(torch.nn.Module): def __init__(self): super(BaseMixin, self).__init__() # define new params + def reinit(self, *pre_mixins): # reload the initial params from previous trained modules pass - # can also define hook-functions here + + # can define hook-functions here # ... + # If the hook is just a pre- or post- transformation, + # You can use @non_conflict to mark it, + # and run `old_impl` to make it compatible with other mixins. 
+ # Eg., + # + # @non_conflict + # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args): + # new_q, new_k, new_v = pre_hack(q, k, v) + # attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args) + # attn_result = post_hack(attn_result) + # return attn_result + + class BaseModel(torch.nn.Module): - def __init__(self, args, transformer=None, parallel_output=True): + def __init__(self, args, transformer=None, **kwargs): super(BaseModel, self).__init__() self.mixins = torch.nn.ModuleDict() self.collect_hooks_() @@ -42,14 +63,16 @@ class BaseModel(torch.nn.Module): embedding_dropout_prob=args.hidden_dropout, attention_dropout_prob=args.attention_dropout, output_dropout_prob=args.hidden_dropout, + inner_hidden_size=args.inner_hidden_size, + hidden_size_per_attention_head=args.hidden_size_per_attention_head, checkpoint_activations=args.checkpoint_activations, checkpoint_num_layers=args.checkpoint_num_layers, sandwich_ln=args.sandwich_ln, - parallel_output=parallel_output, - hooks=self.hooks + hooks=self.hooks, + **kwargs ) - def reinit(self): # will be called when loading model + def reinit(self): # will be called when loading model # if some mixins are loaded, overrides this function for m in self.mixins.values(): m.reinit(self.transformer) @@ -58,11 +81,11 @@ class BaseModel(torch.nn.Module): assert name not in self.mixins assert isinstance(new_mixin, BaseMixin) - self.mixins[name] = new_mixin # will auto-register parameters - object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr + self.mixins[name] = new_mixin # will auto-register parameters + object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr if reinit: - new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins + new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins self.collect_hooks_() def del_mixin(self, name): @@ -82,18 +105,25 @@ class BaseModel(torch.nn.Module): def collect_hooks_(self): names = ['word_embedding_forward', 'position_embedding_forward', - 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward', - 'branch_embedding_forward', 'branch_final_forward' - ] + 'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward', + 'cross_layer_embedding_forward', + 'attention_fn' + ] hooks = {} hook_origins = {} for name in names: for mixin_name, m in self.mixins.items(): if hasattr(m, name): - if name in hooks: # conflict - raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.') - hooks[name] = getattr(m, name) - hook_origins[name] = mixin_name + if name in hooks: # if this hook name is already registered + if hasattr(getattr(m, name), 'non_conflict'): + hooks[name] = partial(getattr(m, name), old_impl=hooks[name]) + hook_origins[name] = mixin_name + ' -> ' + hook_origins[name] + else: # conflict + raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.') + else: # new hook + hooks[name] = getattr(m, name) + hook_origins[name] = mixin_name + if hasattr(self, name): # if name in hooks: # defined in mixins, can override # print(f'Override {name} in {hook_origins[name]}...') @@ -104,4 +134,4 @@ class BaseModel(torch.nn.Module): return hooks def disable_untrainable_params(self): - pass \ No newline at end of file + pass diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py index 
a0e1699e923e9002e7a46ac69692ae7934032a57..8caed663b27e4b6c1b4382090ddd754a1c875d3e 100755 --- a/SwissArmyTransformer/model/cached_autoregressive_model.py +++ b/SwissArmyTransformer/model/cached_autoregressive_model.py @@ -13,43 +13,31 @@ import math import random import torch -from .base_model import BaseModel, BaseMixin +from .base_model import BaseModel, BaseMixin, non_conflict from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim class CachedAutoregressiveMixin(BaseMixin): def __init__(self): - super().__init__() - - def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs): - attn_module = self.transformer.layers[layer_id].attention - mem = mems[layer_id] if mems is not None else None - - mixed_raw_layer = attn_module.query_key_value(hidden_states) - (mixed_query_layer, - mixed_key_layer, - mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) - - if mem is not None: # the first time, mem is None - b = mixed_key_layer.shape[0] # might change batch_size - memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2) - mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1) - mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1) + super().__init__() + + @non_conflict + def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention, + **kw_args): + if not cross_attention: + mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size + b, nh, seq_len, hidden_size = k.shape + + cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2) + kw_args['output_this_layer']['mem_kv'] = cache_kv + + if mem is not None: # the first time, mem is None + # might change batch_size + mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4) + memk, memv = mem[0], mem[1] + k = torch.cat((memk, k), dim=2) + v = torch.cat((memv, v), dim=2) + return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args) - # same as training - query_layer = attn_module._transpose_for_scores(mixed_query_layer) - key_layer = attn_module._transpose_for_scores(mixed_key_layer) - value_layer = attn_module._transpose_for_scores(mixed_value_layer) - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - output = attn_module.dense(context_layer) - - # new mem this layer - new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous() - - return output, new_mem class CachedAutoregressiveModel(BaseModel): def __init__(self, args, transformer=None): diff --git a/SwissArmyTransformer/model/common_layers.py b/SwissArmyTransformer/model/common_layers.py deleted file mode 100644 index 90d24334511192b727ed0a0b4eeff588b8fdb960..0000000000000000000000000000000000000000 --- a/SwissArmyTransformer/model/common_layers.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- encoding: utf-8 -*- -''' -@File : components.py -@Time : 2021/11/23 18:20:22 -@Author : Ming Ding -@Contact : dm18@mails.tsinghua.edu.cn -''' - -# here put the import lib -import os -import sys -import math -import random 
-import torch -from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim -from SwissArmyTransformer.mpu.transformer import standard_attention, LayerNorm - -class CrossAttention(torch.nn.Module): - def __init__(self, hidden_size, num_attention_heads, - attention_dropout_prob, output_dropout_prob, - init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None): - super(CrossAttention, self).__init__() - # Set output layer initialization if not provided. - if output_layer_init_method is None: - output_layer_init_method = init_method - if inner_hidden_size is None: - inner_hidden_size = hidden_size - self.inner_hidden_size = inner_hidden_size - if enc_hidden_size is None: - enc_hidden_size = hidden_size - self.enc_hidden_size = enc_hidden_size - - # To make user understand better, temporally not support model parallel - world_size = 1 - self.hidden_size_per_partition = divide(hidden_size, world_size) - self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) - self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) - - # To map encoder outputs - self.kv_linear = torch.nn.Linear( - enc_hidden_size, inner_hidden_size * 2 - ) - init_method(self.kv_linear.weight) - - # To map self - self.q_linear = torch.nn.Linear( - hidden_size, inner_hidden_size - ) - init_method(self.q_linear.weight) - - self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) - - self.dense = torch.nn.Linear( - inner_hidden_size, - hidden_size, - ) - output_layer_init_method(self.dense.weight) - self.output_dropout = torch.nn.Dropout(output_dropout_prob) - - - def _transpose_for_scores(self, tensor): - """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with - size [b, np, s, hn]. 
- """ - new_tensor_shape = tensor.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - tensor = tensor.view(*new_tensor_shape) - return tensor.permute(0, 2, 1, 3) - - def forward(self, hidden_states, mask, encoder_outputs, **kw_args): - - query_layer = self.q_linear(hidden_states) - key_layer, value_layer = split_tensor_along_last_dim(self.kv_linear(encoder_outputs), 2) - - dropout_fn = self.attention_dropout if self.training else None - - query_layer = self._transpose_for_scores(query_layer) - key_layer = self._transpose_for_scores(key_layer) - value_layer = self._transpose_for_scores(value_layer) - - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - output = self.dense(context_layer) - - if self.training: - output = self.output_dropout(output) - - return output diff --git a/SwissArmyTransformer/model/cuda2d_model.py b/SwissArmyTransformer/model/cuda2d_model.py index cda027c5d8bcdb46cb0d343f4c0a0338dfccb5e0..4fdc87d3bee12f1b132f25f979d33c2582d36fa8 100644 --- a/SwissArmyTransformer/model/cuda2d_model.py +++ b/SwissArmyTransformer/model/cuda2d_model.py @@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel): output_1 = dense_plus(context_layer1) output = torch.cat((output_0, output_1), dim=1) - return output, None + return output def disable_untrainable_params(self): self.transformer.requires_grad_(False) diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py index e70245368e68e6d43ee06d97a959868194b712d6..341a800083d5a317dfc55cde1592915a8296c4ca 100644 --- a/SwissArmyTransformer/model/encoder_decoder_model.py +++ b/SwissArmyTransformer/model/encoder_decoder_model.py @@ -14,131 +14,83 @@ import random import torch import argparse from .base_model import BaseModel, BaseMixin -from .common_layers import CrossAttention, LayerNorm +from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region -class CrossAttentionMixin(BaseMixin): - def __init__(self, num_layers, hidden_size, num_attention_heads, - attention_dropout_prob, output_dropout_prob, - init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None): - super().__init__() - - self.cross_attentions = torch.nn.ModuleList( - [CrossAttention( - hidden_size, num_attention_heads, - attention_dropout_prob, output_dropout_prob, - init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size, - output_layer_init_method=output_layer_init_method - ) for layer_id in range(num_layers)] - ) # Just copy args - self.cross_lns = torch.nn.ModuleList( - [LayerNorm(hidden_size, 1e-5) - for layer_id in range(num_layers)] - ) - - - def layer_forward(self, hidden_states, mask, layer_id, **kw_args): - layer = self.transformer.layers[layer_id] - encoder_outputs = kw_args['encoder_outputs'] - ''' - hidden_states: [batch, seq_len, hidden_size] - mask: [(1, 1), seq_len, seq_len] - encoder_outputs: [batch, enc_seq_len, enc_hidden_size] - ''' - # Layer norm at the begining of the transformer layer. 
- layernorm_output = layer.input_layernorm(hidden_states) - attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args) - # Third LayerNorm - if layer.sandwich_ln: - attention_output = layer.third_layernorm(attention_output) - # Residual connection. - hidden_states = hidden_states + attention_output - - # Cross attention. - layernorm_output = self.cross_lns[layer_id](hidden_states) - cross_attn_output = self.cross_attentions[layer_id]( - layernorm_output, - torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype), - encoder_outputs - ) - hidden_states = hidden_states + cross_attn_output - - # Layer norm post the layer attention. - layernorm_output = layer.post_attention_layernorm(hidden_states) - # MLP. - mlp_output = layer.mlp(layernorm_output, **kw_args) +class EncoderFinalMixin(BaseMixin): + def final_forward(self, logits, **kwargs): + logits = copy_to_model_parallel_region(logits) + return logits - # Fourth LayerNorm - if layer.sandwich_ln: - mlp_output = layer.fourth_layernorm(mlp_output) - output = hidden_states + mlp_output - - return output, output_this_layer - - -class DecoderModel(BaseModel): - def __init__(self, args, transformer=None): - dec_args = argparse.Namespace(**vars(args)) - dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn - override_attrs = ['num_layers', 'vocab_size', - 'hidden_size', 'num_attention_heads', - 'max_sequence_length', 'sandwich_ln' # TODO - ] - for name in override_attrs: - dec_attr = getattr(dec_args, 'dec_' + name, None) - if dec_attr is not None: # else use encoder-config - setattr(dec_args, name, dec_attr) - - super().__init__(dec_args, transformer=transformer) - self.add_mixin('cross_attention', - CrossAttentionMixin( - dec_args.num_layers, - dec_args.hidden_size, dec_args.num_attention_heads, - dec_args.attention_dropout, dec_args.hidden_dropout, - self.transformer.init_method, - enc_hidden_size=dec_args.enc_hidden_size, - inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None), - output_layer_init_method=self.transformer.output_layer_init_method - ) - ) class EncoderDecoderModel(torch.nn.Module): - def __init__(self, args, encoder=None, decoder=None): + def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, parallel_output=False, **kwargs): super(EncoderDecoderModel, self).__init__() if encoder is not None: assert isinstance(encoder, BaseModel) self.encoder = encoder else: - self.encoder = BaseModel(args) + self.encoder = BaseModel(args, **kwargs) + self.encoder.add_mixin("final", EncoderFinalMixin()) if decoder is not None: assert isinstance(decoder, BaseModel) self.decoder = decoder else: - self.decoder = DecoderModel(args) + dec_args = argparse.Namespace(**vars(args)) + dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn + override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads', + 'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head'] + for name in override_attrs: + dec_attr = getattr(dec_args, 'dec_' + name, None) + if dec_attr is not None: # else use encoder-config + setattr(dec_args, name, dec_attr) + self.decoder = BaseModel(dec_args, is_decoder=True, parallel_output=parallel_output, **kwargs) + + self.tie_word_embeddings = tie_word_embeddings + if tie_word_embeddings: + self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings def reinit(self): self.encoder.reinit() self.decoder.reinit() - + def disable_untrainable_params(self):
self.encoder.disable_untrainable_params() self.decoder.disable_untrainable_params() + + def encode(self, input_ids, position_ids, attention_mask=None, **kw_args): + encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args) + return encoder_outputs + + def decode(self, input_ids, position_ids, attention_mask, encoder_outputs,cross_attention_mask=None, **kw_args): + if attention_mask is None: + batch_size, seq_length = input_ids.size()[:2] + seq_ids = torch.arange(seq_length, device=input_ids.device) + attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype) + attention_mask = attention_mask[:, None, :, :] + # If no context, please explicitly pass ``encoder_outputs=None'' + return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args) - def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, branch_input=None, **kw_args): - mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype) - enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input, **kw_args) - dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args) - return enc_outputs, dec_outputs, *dec_mems + def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args): + # Please use self.decoder for auto-regressive generation. 
+ batch_size, seq_length = enc_input_ids.size()[:2] + if enc_attention_mask is None: + enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device) + if cross_attention_mask is None: + cross_attention_mask = enc_attention_mask + encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args) + decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args) + return encoder_outputs, decoder_outputs, *mems @classmethod def add_model_specific_args(cls, parser): group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart') - group.add_argument("--dec_num_layers", type=int, default=None) - group.add_argument("--dec_vocab_size", type=int, default=None) - group.add_argument("--dec_hidden_size", type=int, default=None) - group.add_argument("--dec_num_attention_heads", type=int, default=None) - group.add_argument("--dec_max_sequence_length", type=int, default=None) - group.add_argument("--dec_sandwich_ln", action='store_true') - group.add_argument("--dec_inner_hidden_size", type=int, default=None) - return parser \ No newline at end of file + group.add_argument("--dec-num-layers", type=int, default=None) + group.add_argument("--dec-hidden-size", type=int, default=None) + group.add_argument("--dec-num-attention-heads", type=int, default=None) + group.add_argument("--dec-max-sequence-length", type=int, default=None) + group.add_argument("--dec-inner-hidden-size", type=int, default=None) + group.add_argument("--dec-hidden-size-per-attention-head", type=int, default=None) + return parser diff --git a/SwissArmyTransformer/model/finetune/__init__.py b/SwissArmyTransformer/model/finetune/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2cc6a0915d3df1762edd4f202b2f96b81a5609 --- /dev/null +++ b/SwissArmyTransformer/model/finetune/__init__.py @@ -0,0 +1,2 @@ +from .mlp_head import MLPHeadMixin +from .prompt_tuning import PrefixTuningMixin, PTuningV2Mixin \ No newline at end of file diff --git a/SwissArmyTransformer/model/finetune/mlp_head.py b/SwissArmyTransformer/model/finetune/mlp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e25d661d0d36a24887614dbcbb07faaa445956e2 --- /dev/null +++ b/SwissArmyTransformer/model/finetune/mlp_head.py @@ -0,0 +1,36 @@ + +# -*- encoding: utf-8 -*- +''' +@File : mlp_head.py +@Time : 2021/12/12 20:44:09 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict + +class MLPHeadMixin(BaseMixin): + def __init__(self, hidden_size, *output_sizes, bias=True, activation_func=torch.nn.functional.relu, init_mean=0, init_std=0.005): + super().__init__() + self.activation_func = activation_func + last_size = hidden_size + self.layers = torch.nn.ModuleList() + for sz in output_sizes: + this_layer = torch.nn.Linear(last_size, sz, bias=bias) + last_size = sz + torch.nn.init.normal_(this_layer.weight, mean=init_mean, std=init_std) + self.layers.append(this_layer) + + def final_forward(self, logits, **kw_args): + for i, layer in enumerate(self.layers): + if i > 0: + logits = self.activation_func(logits) + logits = layer(logits) + return logits \ No newline at end of file diff --git 
a/SwissArmyTransformer/model/finetune/prompt_tuning.py b/SwissArmyTransformer/model/finetune/prompt_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..51cd1326b2a42cbcb89a327ab3b15892e3c0a2c1 --- /dev/null +++ b/SwissArmyTransformer/model/finetune/prompt_tuning.py @@ -0,0 +1,45 @@ +# -*- encoding: utf-8 -*- +''' +@File : prompt_tuning.py +@Time : 2021/12/12 20:45:18 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch + +from SwissArmyTransformer.mpu.transformer import standard_attention +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict + + +class PrefixTuningMixin(BaseMixin): + def __init__(self, num_layers, hidden_size_per_attention_head, num_attention_heads, prefix_len): + super().__init__() + self.prefix = torch.nn.ParameterList([ + torch.nn.Parameter(torch.randn(2, num_attention_heads, prefix_len, hidden_size_per_attention_head)*0.01) + for layer_id in range(num_layers) + ]) + self.prefix_len = prefix_len + + @non_conflict + def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args): + prefix_k, prefix_v = self.prefix[kw_args['layer_id']] + + b, nh, seq_len, hidden_size = k.shape + prefix_k = prefix_k.unsqueeze(0).expand(b, nh, -1, hidden_size) + prefix_v = prefix_v.unsqueeze(0).expand(b, nh, -1, hidden_size) + + k = torch.cat((k, prefix_k), dim=2) + v = torch.cat((v, prefix_v), dim=2) + if mask.numel() > 1: + mask_prefixed = torch.ones(self.prefix_len, device=mask.device, dtype=mask.dtype) + mask_prefixed = mask_prefixed.expand(*(mask.size()[:-1]), -1) + mask = torch.cat((mask, mask_prefixed), dim=-1) + return old_impl(q, k, v, mask, dropout_fn, **kw_args) + +PTuningV2Mixin = PrefixTuningMixin \ No newline at end of file diff --git a/SwissArmyTransformer/model/mixins.py b/SwissArmyTransformer/model/mixins.py index 2a76b8038265ff00b7dee078df70b169ebb34105..d1fcf89f7311f2f52835e54285dadfd016714fde 100644 --- a/SwissArmyTransformer/model/mixins.py +++ b/SwissArmyTransformer/model/mixins.py @@ -17,41 +17,45 @@ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear from SwissArmyTransformer.mpu.transformer import unscaled_init_method from .base_model import BaseMixin from .cached_autoregressive_model import CachedAutoregressiveMixin +from .finetune import * class PositionEmbeddingMixin(BaseMixin): - def __init__(self, additional_sequence_length, hidden_size, - init_method_std=0.02, reinit_slice=slice(-1024, None) - ): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(-1024, None) + ): super(PositionEmbeddingMixin, self).__init__() self.reinit_slice = reinit_slice self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + def reinit(self, *pre_mixins): old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] old_len, hidden_size = old_weights.shape assert hidden_size == self.position_embeddings.weight.shape[-1] self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + class AttentionMixin(BaseMixin): def __init__(self, num_layers, - hidden_size, - init_method=unscaled_init_method(0.02), - output_layer_init_method=unscaled_init_method(0.02) - ): + hidden_size, + init_method=unscaled_init_method(0.02), + 
output_layer_init_method=unscaled_init_method(0.02) + ): super(AttentionMixin, self).__init__() - self.num_layers = num_layers # replace attention in the LAST n layers + self.num_layers = num_layers # replace attention in the LAST n layers self.query_key_value = torch.nn.ModuleList( - [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, - gather_output=False,init_method=init_method) - for layer_id in range(num_layers) - ]) + [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3, + gather_output=False, init_method=init_method) + for layer_id in range(num_layers) + ]) self.dense = torch.nn.ModuleList( [RowParallelLinear(hidden_size, - hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method) - for layer_id in range(num_layers) - ]) + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + for layer_id in range(num_layers) + ]) + def reinit(self, *pre_mixins): start_layer = len(self.transformer.layers) - self.num_layers assert start_layer >= 0 diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a50951ea352babcd763a0178ed2de2013695dc --- /dev/null +++ b/SwissArmyTransformer/model/t5_model.py @@ -0,0 +1,287 @@ +import math +import torch +import torch.nn.functional as F +from .mixins import BaseMixin +from .encoder_decoder_model import EncoderDecoderModel +from .base_model import non_conflict +from SwissArmyTransformer.mpu import get_model_parallel_world_size +from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP +from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region +from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method +from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding + + +class T5PositionEmbeddingMixin(BaseMixin): + def position_embedding_forward(self, position_ids, **kw_args): + return None + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 or bfloat16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + elif self.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) + return self.weight * hidden_states + + +class T5AttentionMixin(BaseMixin): + def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False): + super().__init__() + self.relative_attention_num_buckets = relative_attention_num_buckets + world_size = get_model_parallel_world_size() + self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, + self.num_attention_heads_per_partition) + self.is_decoder = is_decoder + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + # shape (query_length, key_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + @non_conflict + def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention, + cross_attention=False, **kw_args): + log_attention_weights = None + if not cross_attention: + if position_bias is None: + seq_length = q.size(2) + key_length = k.size(2) + position_bias = self.compute_bias(key_length, key_length) + position_bias = position_bias[:, :, -seq_length:, :] + kw_args['output_cross_layer']['position_bias'] = position_bias + log_attention_weights = position_bias + return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias, + log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args) + + +class T5DecoderFinalMixin(BaseMixin): + def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True): + super().__init__() + self.hidden_size = hidden_size + self.tie_word_embeddings = tie_word_embeddings + if not tie_word_embeddings: + self.lm_head = VocabParallelEmbedding( + vocab_size, hidden_size, init_method=unscaled_init_method(0.02)) + + def final_forward(self, logits, **kwargs): + logits_parallel = copy_to_model_parallel_region(logits) + if self.tie_word_embeddings: + logits_parallel = logits_parallel * (self.hidden_size ** 
-0.5) + logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight) + else: + logits_parallel = F.linear(logits_parallel, self.lm_head.weight) + return logits_parallel + + +def t5_gelu(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5GatedGeluMLPMixin(BaseMixin): + def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02): + super().__init__() + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.init_method_std = init_method_std + self.gated_h_to_4h_list = torch.nn.ModuleList([ + ColumnParallelLinear( + self.hidden_size, + self.inner_hidden_size, + gather_output=False, + init_method=self._init_weights, + bias=bias, + module=self, + name="gated_h_to_4h" + ) + for layer_id in range(num_layers)]) + + def _init_weights(self, weight, **kwargs): + torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5)) + + def mlp_forward(self, hidden_states, layer_id=None, **kw_args): + mlp_module = self.transformer.layers[layer_id].mlp + hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states)) + hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states) + hidden_states = hidden_gelu * hidden_linear + output = mlp_module.dense_4h_to_h(hidden_states) + + if self.training: + output = mlp_module.dropout(output) + return output + + +class T5Model(EncoderDecoderModel): + def __init__(self, args, **kwargs): + self.init_method_std = args.init_method_std + super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False, + layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu, + init_method=self._init_weights) + self.encoder.add_mixin( + "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads) + ) + self.encoder.add_mixin( + "t5-position", T5PositionEmbeddingMixin() + ) + del self.encoder.transformer.position_embeddings + num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None else args.num_attention_heads + self.decoder.add_mixin( + "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads, is_decoder=True) + ) + self.decoder.add_mixin( + "t5-position", T5PositionEmbeddingMixin() + ) + self.decoder.add_mixin( + "t5-final", + T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings) + ) + del self.decoder.transformer.position_embeddings + if args.gated_gelu_mlp: + self.encoder.add_mixin( + "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std, + inner_hidden_size=args.inner_hidden_size, bias=False) + ) + self.decoder.add_mixin( + "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std, + inner_hidden_size=args.inner_hidden_size, bias=False) + ) + + def _init_weights(self, weight, module, name): + init_method_std = self.init_method_std + if isinstance(module, MLP): + if name == "dense_h_to_4h": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5)) + elif name == "dense_4h_to_h": + torch.nn.init.normal_(weight, mean=0, 
std=init_method_std * (module.inner_hidden_size ** -0.5)) + else: + raise NotImplementedError(name) + elif isinstance(module, SelfAttention): + if name == "query_key_value": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5)) + torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * ( + (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5)) + elif name == "dense": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5)) + else: + raise NotImplementedError(name) + elif isinstance(module, CrossAttention): + if name == "query": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * ( + (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5)) + elif name == "key_value": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5)) + elif name == "dense": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5)) + else: + raise NotImplementedError(name) + else: + raise NotImplementedError(module) + + @classmethod + def add_model_specific_args(cls, parser): + super().add_model_specific_args(parser) + parser.add_argument("--relative-attention-num-buckets", type=int, default=None) + parser.add_argument("--init-method-std", type=float, default=0.02) + parser.add_argument("--gated-gelu-mlp", action='store_true') + parser.add_argument("--no-share-embeddings", action='store_true') + + def encode(self, input_ids, attention_mask=None, **kw_args): + return super().encode(input_ids, None, attention_mask, **kw_args) + + def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args): + return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs, + cross_attention_mask=cross_attention_mask, **kw_args) + + def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None, + cross_attention_mask=None, **kw_args): + batch_size, seq_length = enc_input_ids.size()[:2] + if enc_attention_mask is None: + enc_attention_mask = torch.ones(1, 1, 1, seq_length, + dtype=self.encoder.transformer.word_embeddings.weight.dtype, + device=enc_input_ids.device) + if cross_attention_mask is None: + cross_attention_mask = enc_attention_mask + encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args) + decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask, + encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, + **kw_args) + return encoder_outputs, decoder_outputs, *mems diff --git a/SwissArmyTransformer/mpu/layers.py b/SwissArmyTransformer/mpu/layers.py index 4af3f33180aae46bba28efa213a9a941df5cd3ce..8d7cf370e75f2629ae96b1b2f442f4dfe6bae70f 100755 --- a/SwissArmyTransformer/mpu/layers.py +++ b/SwissArmyTransformer/mpu/layers.py @@ -37,7 +37,7 @@ from .utils import VocabUtility def _initialize_affine_weight(weight, output_size, input_size, per_partition_size, partition_dim, init_method, - stride=1, return_master_weight=False): + stride=1, return_master_weight=False, module=None, name=None): """Initialize affine weight for model parallel. Build the master weight on all processes and scatter @@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size, # If we only use 1 process for model parallelism, bypass scatter. 
world_size = get_model_parallel_world_size() if world_size == 1: - init_method(weight) + init_method(weight, module=module, name=name) if return_master_weight: return weight return None @@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size, master_weight = torch.empty(output_size, input_size, dtype=weight.dtype, requires_grad=False) - init_method(master_weight) + init_method(master_weight, module=module, name=name) # Split and copy per_partition_per_stride_size = divide(per_partition_size, stride) @@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module): """ def __init__(self, input_size, output_size, bias=True, gather_output=True, init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False): + keep_master_weight_for_test=False, module=None, name=None): super(ColumnParallelLinear, self).__init__() # Keep input parameters @@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module): self.master_weight = _initialize_affine_weight( self.weight, self.output_size, self.input_size, self.output_size_per_partition, 0, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) + stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name) def forward(self, input_): # Set up backprop all-reduce. @@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module): def __init__(self, input_size, output_size, bias=True, input_is_parallel=False, init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False): + keep_master_weight_for_test=False, module=None, name=None): super(RowParallelLinear, self).__init__() # Keep input parameters @@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module): self.master_weight = _initialize_affine_weight( self.weight, self.output_size, self.input_size, self.input_size_per_partition, 1, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) + stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name) def forward(self, input_): # Set up backprop all-reduce. 
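Note: the new `module=` / `name=` keywords threaded through `_initialize_affine_weight`, `ColumnParallelLinear`, and `RowParallelLinear` above let an `init_method` choose an initialization per sub-layer; this is how `T5Model._init_weights` in t5_model.py uses them. A minimal sketch of a compatible callable follows; the helper name, the 0.02 base std, and the example sizes are illustrative assumptions, not part of this patch.

import torch

def scaled_init_by_name(weight, module=None, name=None):
    # Hypothetical init_method: it receives the owning module and the sub-layer
    # name that ColumnParallelLinear/RowParallelLinear now forward to it, and
    # scales output projections ("dense") down; both constants are assumptions.
    std = 0.02
    if name == "dense":
        std *= 0.5
    torch.nn.init.normal_(weight, mean=0.0, std=std)

# usage sketch, e.g. inside a module that owns the layer:
# self.dense_h_to_4h = ColumnParallelLinear(1024, 4096, gather_output=False,
#     init_method=scaled_init_by_name, module=self, name="dense_h_to_4h")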
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index 3764f0552b1c4163536b061700b8d5d7b6e3ff79..289056948feff75964127fd5a2e4b0f4e8ed3944 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -22,55 +22,62 @@ import torch import torch.nn.functional as F from apex.normalization.fused_layer_norm import FusedLayerNorm +from SwissArmyTransformer import mpu from .initialize import get_model_parallel_world_size from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region -from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker +from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu from .utils import split_tensor_along_last_dim + class LayerNorm(FusedLayerNorm): def __init__(self, *args, pb_relax=False, **kwargs): super().__init__(*args, **kwargs) self.pb_relax = pb_relax + def forward(self, x): if not self.pb_relax: return super().forward(x) - return super().forward(x / (x.abs().max().detach()/8)) + return super().forward(x / (x.abs().max().detach() / 8)) + def standard_attention(query_layer, key_layer, value_layer, attention_mask, - attention_dropout=None, log_attention_weights=None): + attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs): # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. - attention_scores = torch.matmul( - query_layer / math.sqrt(query_layer.shape[-1]), - key_layer.transpose(-1, -2) - ) + if scaling_attention_score: + query_layer = query_layer / math.sqrt(query_layer.shape[-1]) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if log_attention_weights is not None: attention_scores += log_attention_weights - - if not(attention_mask.shape[-2] == 1 and (attention_mask > 0).all()): + + if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()): # if auto-regressive, skip attention_scores = torch.mul(attention_scores, attention_mask) - \ - 10000.0 * (1.0 - attention_mask) + 10000.0 * (1.0 - attention_mask) attention_probs = F.softmax(attention_scores, dim=-1) if attention_dropout is not None: - with get_cuda_rng_tracker().fork(): + if mpu.get_cuda_rng_tracker is not None: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = attention_dropout(attention_probs) + else: attention_probs = attention_dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) return context_layer + class SelfAttention(torch.nn.Module): def __init__(self, hidden_size, num_attention_heads, - attention_dropout_prob, output_dropout_prob, - init_method, layer_id, output_layer_init_method=None, - hooks={}): + attention_dropout_prob, output_dropout_prob, + init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, + hooks={}): super(SelfAttention, self).__init__() # Set output layer initialization if not provided. if output_layer_init_method is None: @@ -79,47 +86,61 @@ class SelfAttention(torch.nn.Module): self.layer_id = layer_id # Per attention head and per partition values. 
world_size = get_model_parallel_world_size() - self.hidden_size_per_partition = divide(hidden_size, world_size) - self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) + self.hidden_size = hidden_size + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition # Strided linear layer. self.query_key_value = ColumnParallelLinear( hidden_size, - 3*hidden_size, + 3 * self.inner_hidden_size, stride=3, gather_output=False, - init_method=init_method + init_method=init_method, + bias=bias, + module=self, + name="query_key_value" ) self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) self.dense = RowParallelLinear( - hidden_size, + self.inner_hidden_size, hidden_size, input_is_parallel=True, - init_method=output_layer_init_method + init_method=output_layer_init_method, + bias=bias, + module=self, + name="dense" ) self.output_dropout = torch.nn.Dropout(output_dropout_prob) - def _transpose_for_scores(self, tensor): """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """ new_tensor_shape = tensor.size()[:-1] + \ - (self.num_attention_heads_per_partition, + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) tensor = tensor.view(*new_tensor_shape) return tensor.permute(0, 2, 1, 3) - def forward(self, hidden_states, mask, **kw_args): + def forward(self, hidden_states, mask, *args, **kw_args): if 'attention_forward' in self.hooks: - return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id) + return self.hooks['attention_forward'](hidden_states, mask, **kw_args) else: + attention_fn = standard_attention + if 'attention_fn' in self.hooks: + attention_fn = self.hooks['attention_fn'] + mixed_raw_layer = self.query_key_value(hidden_states) (mixed_query_layer, - mixed_key_layer, - mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) dropout_fn = self.attention_dropout if self.training else None @@ -127,7 +148,8 @@ class SelfAttention(torch.nn.Module): key_layer = self._transpose_for_scores(mixed_key_layer) value_layer = self._transpose_for_scores(mixed_value_layer) - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) + context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) @@ -136,40 +158,137 @@ class SelfAttention(torch.nn.Module): if self.training: output = self.output_dropout(output) - return output, None + return output + + +class CrossAttention(torch.nn.Module): + """Parallel cross-attention layer for Transformer""" + + def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method, + layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, hooks={}): + super().__init__() + # Set output layer 
initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + self.hooks = hooks + self.layer_id = layer_id + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size = hidden_size + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition + # Strided linear layer. + self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size, + gather_output=False, + init_method=init_method, bias=bias, module=self, name="query") + self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size, + stride=2, + gather_output=False, + init_method=init_method, bias=bias, module=self, name="key_value") + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + # Output. + self.dense = RowParallelLinear( + self.inner_hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, bias=bias, module=self, name="dense") + self.output_dropout = torch.nn.Dropout(output_dropout_prob) + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args): + # hidden_states: [b, s, h] + if 'cross_attention_forward' in self.hooks: + return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args) + else: + attention_fn = standard_attention + if 'attention_fn' in self.hooks: + attention_fn = self.hooks['attention_fn'] + + mixed_query_layer = self.query(hidden_states) + mixed_x_layer = self.key_value(encoder_outputs) + (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2) + + dropout_fn = self.attention_dropout if self.training else None + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn, + cross_attention=True, **kw_args) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. 
[b, s, h] + output = self.dense(context_layer) + if self.training: + output = self.output_dropout(output) + + return output class MLP(torch.nn.Module): - def __init__(self, hidden_size, output_dropout_prob, init_method, - output_layer_init_method=None, layer_id=None, hooks={}): + def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None, + output_layer_init_method=None, layer_id=None, hooks={}, bias=True, activation_func=gelu): super(MLP, self).__init__() self.layer_id = layer_id + self.activation_func = activation_func # Set output layer initialization if not provided. if output_layer_init_method is None: output_layer_init_method = init_method self.hooks = hooks # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size self.dense_h_to_4h = ColumnParallelLinear( - hidden_size, - 4*hidden_size, + self.hidden_size, + self.inner_hidden_size, gather_output=False, - init_method=init_method + init_method=init_method, + bias=bias, + module=self, + name="dense_h_to_4h" ) # Project back to h. self.dense_4h_to_h = RowParallelLinear( - 4*hidden_size, - hidden_size, + self.inner_hidden_size, + self.hidden_size, input_is_parallel=True, - init_method=output_layer_init_method + init_method=output_layer_init_method, + bias=bias, + module=self, + name="dense_4h_to_h" ) self.dropout = torch.nn.Dropout(output_dropout_prob) def forward(self, hidden_states, **kw_args): if 'mlp_forward' in self.hooks: - output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id) + output = self.hooks['mlp_forward'](hidden_states, **kw_args) else: intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = gelu(intermediate_parallel) + intermediate_parallel = self.activation_func(intermediate_parallel) output = self.dense_4h_to_h(intermediate_parallel) if self.training: @@ -179,27 +298,34 @@ class MLP(torch.nn.Module): class BaseTransformerLayer(torch.nn.Module): def __init__( - self, - hidden_size, - num_attention_heads, - attention_dropout_prob, - output_dropout_prob, - layernorm_epsilon, - init_method, - layer_id, - output_layer_init_method=None, - sandwich_ln=True, - hooks={} + self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + init_method, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + output_layer_init_method=None, + sandwich_ln=True, + layernorm=LayerNorm, + is_decoder=False, + use_bias=True, + activation_func=gelu, + hooks={} ): super(BaseTransformerLayer, self).__init__() # Set output layer initialization if not provided. if output_layer_init_method is None: output_layer_init_method = init_method self.layer_id = layer_id + self.is_decoder = is_decoder self.hooks = hooks # Layernorm on the input data. - self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) # Self attention. self.attention = SelfAttention( @@ -209,37 +335,57 @@ class BaseTransformerLayer(torch.nn.Module): output_dropout_prob, init_method, layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, output_layer_init_method=output_layer_init_method, + bias=use_bias, hooks=hooks ) # Layernorm on the input data. 
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) self.sandwich_ln = sandwich_ln if sandwich_ln: - self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) - self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + # Cross attention. + if self.is_decoder: + self.cross_attention = CrossAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + output_layer_init_method=output_layer_init_method, + bias=use_bias, + hooks=hooks + ) + self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) # MLP self.mlp = MLP( hidden_size, output_dropout_prob, init_method, + inner_hidden_size=inner_hidden_size, output_layer_init_method=output_layer_init_method, + bias=use_bias, layer_id=layer_id, + activation_func=activation_func, hooks=hooks ) - def forward(self, hidden_states, mask, **kw_args): + def forward(self, hidden_states, mask, *args, **kw_args): ''' hidden_states: [batch, seq_len, hidden_size] mask: [(1, 1), seq_len, seq_len] ''' - # Layer norm at the begining of the transformer layer. layernorm_output1 = self.input_layernorm(hidden_states) # Self attention. - attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args) + attention_output = self.attention(layernorm_output1, mask, **kw_args) # Third LayerNorm if self.sandwich_ln: @@ -249,6 +395,18 @@ class BaseTransformerLayer(torch.nn.Module): layernorm_input = hidden_states + attention_output # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) + + if self.is_decoder: + encoder_outputs = kw_args['encoder_outputs'] + if encoder_outputs is not None: + assert 'cross_attention_mask' in kw_args + # Cross attention + attention_output = self.cross_attention(layernorm_output, **kw_args) + # Residual connection. + layernorm_input = layernorm_input + attention_output + # Layer norm post the cross attention + layernorm_output = self.post_cross_attention_layernorm(layernorm_input) + # MLP. mlp_output = self.mlp(layernorm_output, **kw_args) @@ -259,7 +417,8 @@ class BaseTransformerLayer(torch.nn.Module): # Second residual connection. 
output = layernorm_input + mlp_output - return output, output_this_layer # temporally, output_this_layer is only from attention + return output, kw_args['output_this_layer'], kw_args['output_cross_layer'] + class BaseTransformer(torch.nn.Module): def __init__(self, @@ -275,18 +434,26 @@ class BaseTransformer(torch.nn.Module): checkpoint_num_layers=1, layernorm_epsilon=1.0e-5, init_method_std=0.02, + inner_hidden_size=None, + hidden_size_per_attention_head=None, sandwich_ln=True, parallel_output=True, + is_decoder=False, + use_bias=True, + activation_func=gelu, + layernorm=LayerNorm, + init_method=None, hooks={} ): super(BaseTransformer, self).__init__() # recording parameters + self.is_decoder = is_decoder self.parallel_output = parallel_output self.checkpoint_activations = checkpoint_activations self.checkpoint_num_layers = checkpoint_num_layers self.max_sequence_length = max_sequence_length - self.hooks = copy.copy(hooks) # hooks will be updated each forward + self.hooks = copy.copy(hooks) # hooks will be updated each forward # create embedding parameters self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) @@ -298,8 +465,13 @@ class BaseTransformer(torch.nn.Module): torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) # create all layers - self.output_layer_init_method = scaled_init_method(init_method_std, num_layers) - self.init_method = unscaled_init_method(init_method_std) + if init_method is None: + self.output_layer_init_method = scaled_init_method(init_method_std, num_layers) + self.init_method = unscaled_init_method(init_method_std) + else: + self.output_layer_init_method = init_method + self.init_method = init_method + def get_layer(layer_id): return BaseTransformerLayer( hidden_size, @@ -309,32 +481,39 @@ class BaseTransformer(torch.nn.Module): layernorm_epsilon, self.init_method, layer_id, + inner_hidden_size=inner_hidden_size, + hidden_size_per_attention_head=hidden_size_per_attention_head, output_layer_init_method=self.output_layer_init_method, + is_decoder=self.is_decoder, sandwich_ln=sandwich_ln, + layernorm=layernorm, + use_bias=use_bias, + activation_func=activation_func, hooks=self.hooks - ) + ) + self.layers = torch.nn.ModuleList( [get_layer(layer_id) for layer_id in range(num_layers)]) # Final layer norm before output. - self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False, - **kw_args): - # sanity check + def forward(self, input_ids, position_ids, attention_mask, *, + output_hidden_states=False, **kw_args): + # sanity check assert len(input_ids.shape) == 2 batch_size, query_length = input_ids.shape + if attention_mask is None: + attention_mask = torch.ones(1, 1, device=input_ids.device).type_as( + next(self.parameters()) + ) # None means full attention assert len(attention_mask.shape) == 2 or \ - len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 - assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor) - # branch_input is a new part of input need layer-by-layer update, - # but with different hidden_dim and computational routine. - # In most cases, you can just ignore it. 
+ len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 # embedding part if 'word_embedding_forward' in self.hooks: hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args) - else: # default + else: # default hidden_states = self.word_embeddings(input_ids) if 'position_embedding_forward' in self.hooks: @@ -343,62 +522,132 @@ class BaseTransformer(torch.nn.Module): assert len(position_ids.shape) <= 2 assert position_ids.shape[-1] == query_length position_embeddings = self.position_embeddings(position_ids) - hidden_states = hidden_states + position_embeddings + if position_embeddings is not None: + hidden_states = hidden_states + position_embeddings hidden_states = self.embedding_dropout(hidden_states) - hidden_states_outputs = [hidden_states] if output_hidden_states else [] - # branch related embedding - if branch_input is None and 'branch_embedding_forward' in self.hooks: - branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args) + # initial output_cross_layer + if 'cross_layer_embedding_forward' in self.hooks: + output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args) + else: + output_cross_layer = {} - # define custom_forward for checkpointing output_per_layers = [] if self.checkpoint_activations: - def custom(start, end): + # define custom_forward for checkpointing + def custom(start, end, kw_args_index, cross_layer_index): def custom_forward(*inputs): layers_ = self.layers[start:end] x_, mask = inputs[0], inputs[1] - if len(inputs) > 2: # have branch_input - branch_ = inputs[2] + + # recover kw_args and output_cross_layer + flat_inputs = inputs[2:] + kw_args, output_cross_layer = {}, {} + for k, idx in kw_args_index.items(): + kw_args[k] = flat_inputs[idx] + for k, idx in cross_layer_index.items(): + output_cross_layer[k] = flat_inputs[idx] + # ----------------- + output_per_layers_part = [] for i, layer in enumerate(layers_): - if len(inputs) > 2: - x_, branch_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args - ) - elif 'layer_forward' in self.hooks: - x_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, **kw_args + if 'layer_forward' in self.hooks: + layer_ret = self.hooks['layer_forward']( + x_, mask, layer_id=layer.layer_id, + **kw_args, **output_cross_layer, + output_this_layer={}, output_cross_layer={} ) else: - x_, output_this_layer = layer(x_, mask, **kw_args) + layer_ret = layer( + x_, mask, layer_id=layer.layer_id, + **kw_args, **output_cross_layer, + output_this_layer={}, output_cross_layer={} + ) + if torch.is_tensor(layer_ret): # only hidden_states + x_, output_this_layer, output_cross_layer = layer_ret, {}, {} + elif len(layer_ret) == 2: # hidden_states & output_this_layer + x_, output_this_layer = layer_ret + output_cross_layer = {} + elif len(layer_ret) == 3: + x_, output_this_layer, output_cross_layer = layer_ret + assert isinstance(output_this_layer, dict) + assert isinstance(output_cross_layer, dict) + if output_hidden_states: + output_this_layer['hidden_states'] = x_ output_per_layers_part.append(output_this_layer) - return x_, output_per_layers_part + + # flatten for re-aggregate keywords outputs + flat_outputs = [] + for output_this_layer in output_per_layers_part: + for k in output_this_layer: + # TODO add warning for depth>=2 grad tensors + flat_outputs.append(output_this_layer[k]) + output_this_layer[k] = len(flat_outputs) - 1 + for k in output_cross_layer: + 
flat_outputs.append(output_cross_layer[k]) + output_cross_layer[k] = len(flat_outputs) - 1 + # -------------------- + + return x_, output_per_layers_part, output_cross_layer, flat_outputs return custom_forward + # prevent to lose requires_grad in checkpointing. + # To save memory when only finetuning the final layers, don't use checkpointing. + if self.training: + hidden_states.requires_grad_(True) + l, num_layers = 0, len(self.layers) chunk_length = self.checkpoint_num_layers + output_this_layer = [] while l < num_layers: args = [hidden_states, attention_mask] - if branch_input is not None: - hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input) - else: - hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args) - if output_hidden_states: - hidden_states_outputs.append(hidden_states) + # flatten kw_args and output_cross_layer + flat_inputs, kw_args_index, cross_layer_index = [], {}, {} + for k, v in kw_args.items(): + flat_inputs.append(v) + kw_args_index[k] = len(flat_inputs) - 1 + for k, v in output_cross_layer.items(): + flat_inputs.append(v) + cross_layer_index[k] = len(flat_inputs) - 1 + # -------------------- + + hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \ + checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs) + + # recover output_per_layers_part, output_cross_layer + for output_this_layer in output_per_layers_part: + for k in output_this_layer: + output_this_layer[k] = flat_outputs[output_this_layer[k]] + for k in output_cross_layer: + output_cross_layer[k] = flat_outputs[output_cross_layer[k]] + # -------------------- + output_per_layers.extend(output_per_layers_part) l += chunk_length else: + output_this_layer = [] for i, layer in enumerate(self.layers): args = [hidden_states, attention_mask] - if branch_input is not None: # customized layer_forward with branch_input - hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args) - elif 'layer_forward' in self.hooks: # customized layer_forward - hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args) + + if 'layer_forward' in self.hooks: # customized layer_forward + layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), + **kw_args, + **output_cross_layer, + output_this_layer={}, output_cross_layer={} + ) else: - hidden_states, output_this_layer = layer(*args, **kw_args) + layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, + output_this_layer={}, output_cross_layer={}) + if torch.is_tensor(layer_ret): # only hidden_states + hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {} + elif len(layer_ret) == 2: # hidden_states & output_this_layer + hidden_states, output_this_layer = layer_ret + output_cross_layer = {} + elif len(layer_ret) == 3: + hidden_states, output_this_layer, output_cross_layer = layer_ret + if output_hidden_states: - hidden_states_outputs.append(hidden_states) + output_this_layer['hidden_states'] = hidden_states output_per_layers.append(output_this_layer) # Final layer norm. 
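The hunks above replace the direct call to standard_attention with an optional 'attention_fn' hook: SelfAttention.forward and CrossAttention.forward both look up self.hooks['attention_fn'] and fall back to standard_attention, and the refactored layer loop accepts a bare tensor, a (hidden_states, output_this_layer) pair, or a full (hidden_states, output_this_layer, output_cross_layer) triple from each layer or 'layer_forward' hook. Below is a minimal sketch of a hook compatible with the attention_fn call sites, assuming the usual scaled-dot-product formulation with a 0/1 mask; the exact convention inside standard_attention is not shown in this diff.

import torch
import torch.nn.functional as F

# Sketch of a callable usable as hooks['attention_fn']. The call sites above pass
# (query, key, value, mask, dropout_fn) positionally plus pass-through keyword
# arguments such as cross_attention=True; shapes are [b, np, s, hn] as produced by
# _transpose_for_scores. The -10000 additive mask fill is an assumption for illustration.
def scaled_dot_product_attention_fn(query_layer, key_layer, value_layer, mask,
                                    dropout_fn, cross_attention=False, **kw_args):
    scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    scores = scores / (query_layer.shape[-1] ** 0.5)
    scores = scores * mask - 10000.0 * (1.0 - mask)   # mask == 1 keeps, 0 blocks
    probs = F.softmax(scores, dim=-1)
    if dropout_fn is not None:                        # None at eval time (see forward)
        probs = dropout_fn(probs)
    return torch.matmul(probs, value_layer)           # [b, np, s, hn]

Installing such a function under the 'attention_fn' key routes both self-attention and cross-attention through it; CrossAttention additionally passes cross_attention=True so the hook can distinguish the two cases.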
@@ -410,19 +659,10 @@ class BaseTransformer(torch.nn.Module): logits_parallel = copy_to_model_parallel_region(logits) logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight) - # branch related embedding - if branch_input is None and 'branch_final_forward' in self.hooks: - branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args) - if not self.parallel_output: logits_parallel = gather_from_model_parallel_region(logits_parallel) outputs = [logits_parallel] - if branch_input is not None: - outputs.append(branch_input) - if output_hidden_states: - outputs.append(hidden_states_outputs) outputs.extend(output_per_layers) return outputs - diff --git a/SwissArmyTransformer/mpu/utils.py b/SwissArmyTransformer/mpu/utils.py index c83f501889ccabb33cce33260f7d4fee9eadcab7..2f4f29f08227269634af831b91d2fe9b301b4e9e 100755 --- a/SwissArmyTransformer/mpu/utils.py +++ b/SwissArmyTransformer/mpu/utils.py @@ -75,7 +75,7 @@ def sqrt(x): def unscaled_init_method(sigma): """Init method based on N(0, sigma).""" - def init_(tensor): + def init_(tensor, **kwargs): return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) return init_ @@ -83,7 +83,7 @@ def unscaled_init_method(sigma): def scaled_init_method(sigma, num_layers): """Init method based on N(0, sigma/sqrt(2*num_layers).""" std = sigma / math.sqrt(2.0 * num_layers) - def init_(tensor): + def init_(tensor, **kwargs): return torch.nn.init.normal_(tensor, mean=0.0, std=std) return init_ diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py index fa93f3fd575bb4049c0b3bbbddd7d66e196d1e2b..2279ee051c622f17991f0337206933d2a05c3653 100644 --- a/SwissArmyTransformer/tokenization/__init__.py +++ b/SwissArmyTransformer/tokenization/__init__.py @@ -29,7 +29,8 @@ def _export_vocab_size_to_args(args, original_num_tokens): print_rank_0('> padded vocab (size: {}) with {} dummy ' 'tokens (new size: {})'.format( before, after - before, after)) - args.vocab_size = after + if not args.vocab_size: + args.vocab_size = after print_rank_0("prepare tokenizer done") return tokenizer @@ -51,7 +52,7 @@ def get_tokenizer(args=None, outer_tokenizer=None): from .cogview import UnifiedTokenizer get_tokenizer.tokenizer = UnifiedTokenizer( args.img_tokenizer_path, - # txt_tokenizer_type=args.tokenizer_type, + txt_tokenizer_type='cogview', device=torch.cuda.current_device() ) elif args.tokenizer_type.startswith('glm'): @@ -63,6 +64,10 @@ def get_tokenizer(args=None, outer_tokenizer=None): elif args.tokenizer_type == "glm_ChineseSPTokenizer": from .glm import ChineseSPTokenizer get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs) + elif args.tokenizer_type.startswith('hf'): + from .hf_tokenizer import HFT5Tokenizer + if args.tokenizer_type == "hf_T5Tokenizer": + get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir) else: assert args.vocab_size > 0 get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size) diff --git a/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py b/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py index ee8c907742dea9df7a715d1f44fb3ad067ddadb0..c9e731e6fdcc7fc52c6a0cb90d1c14f3ca0306fb 100755 --- a/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py +++ b/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py @@ -22,6 +22,8 @@ python setup.py install PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'embed_assets', 'chinese_sentencepiece/cog-pretrain.model') 
+PRETRAINED_MODEL_FILE_ICE = os.path.join(os.path.dirname(os.path.dirname(__file__)), + 'embed_assets', 'chinese_sentencepiece/ice.model') # merge xlnet 3,2000 En tokens def get_pairs(word): @@ -148,5 +150,10 @@ def get_encoder(encoder_file, bpe_file): ) -def from_pretrained(): - return get_encoder(PRETRAINED_MODEL_FILE, "") \ No newline at end of file +def from_pretrained(tokenizer_type='cogview'): + if tokenizer_type == 'cogview_ICE': + return get_encoder(PRETRAINED_MODEL_FILE_ICE, "") + elif tokenizer_type == 'cogview': + return get_encoder(PRETRAINED_MODEL_FILE, "") + else: + raise ValueError('Unknown cogview tokenizer.') \ No newline at end of file diff --git a/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py b/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py index 8a06004d332955aa4b2f1fde712b52f188bc2a59..7b17f751aa8ca1d583f4e2641f43270c8667fe13 100755 --- a/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py +++ b/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py @@ -20,10 +20,10 @@ from .sp_tokenizer import from_pretrained from .vqvae_tokenizer import VQVAETokenizer, sqrt_int class UnifiedTokenizer(object): - def __init__(self, img_tokenizer_path, device): + def __init__(self, img_tokenizer_path, txt_tokenizer_type, device): self.device = device self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device) - self.txt_tokenizer = from_pretrained() + self.txt_tokenizer = from_pretrained(txt_tokenizer_type) self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens self.raw_command_tokens = [ ('[PAD]', 0), diff --git a/SwissArmyTransformer/tokenization/glm/tokenization.py b/SwissArmyTransformer/tokenization/glm/tokenization.py index 67be818a0f1d42b0d35696c24577b5e5391ff3b4..9b9a8abc0202b46a95efc8a8338e90e0da9fd60f 100644 --- a/SwissArmyTransformer/tokenization/glm/tokenization.py +++ b/SwissArmyTransformer/tokenization/glm/tokenization.py @@ -312,11 +312,11 @@ class Tokenizer(object): tokenization.tokenization = [self.IdToToken(idx) for idx in tokenization.tokenization] return tokenization - def IdToToken(self, Id): + def IdToToken(self, idx): """convert Id to token accounting for command tokens""" - if isinstance(Id, CommandToken): - return Id.token - return self.tokens[Id] + if isinstance(idx, CommandToken): + return idx.token + return self.tokens[idx] def TokenToId(self, token): """convert token to Id accounting for command tokens""" @@ -324,16 +324,16 @@ class Tokenizer(object): return token.Id return self.vocab[token] - def DecodeIds(self, Ids): + def DecodeIds(self, ids): """ convert Ids to tokens accounting for command tokens, tokens are joined and returned as a string. """ rtn_strs = [] current_str = [] - if isinstance(Ids, Tokenization): - Ids = Ids.tokenization - for Id in Ids: + if isinstance(ids, Tokenization): + ids = ids.tokenization + for Id in ids: if isinstance(Id, CommandToken): rtn_strs.append(self._decode(current_str)) current_str = [] @@ -353,11 +353,11 @@ class Tokenizer(object): output = self.clean_up_tokenization(output) return output - def DecodeTokens(self, Tokens): + def DecodeTokens(self, tokens): """ convert tokens to a string accounting for command and type tokens. 
""" - Ids = [self.TokenToId(token) for token in Tokens] + Ids = [self.TokenToId(token) for token in tokens] return self.DecodeIds(Ids) diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..790b60d56441267b9ccd883b15cc0b11ab857018 --- /dev/null +++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py @@ -0,0 +1,80 @@ +from transformers import T5Tokenizer +from .glm.tokenization import Tokenization, CommandToken + + +PRETRAINED_VOCAB_FILES_MAP = { + "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small", + "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base", + "t5-large": "/dataset/fd5061f6/yanan/huggingface_models/t5-large", + "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b", + "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b" +} + +class HFTokenizer: + def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None): + if model_type_or_path in PRETRAINED_VOCAB_FILES_MAP: + model_type_or_path = PRETRAINED_VOCAB_FILES_MAP[model_type_or_path] + self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir) + self.num_tokens = len(self.text_tokenizer) + self._command_tokens = [] + self.command_name_map = {} + self.command_token_map = {} + self.command_id_map = {} + + def __len__(self): + return len(self.text_tokenizer) + + @property + def command_tokens(self): + return self._command_tokens + + @command_tokens.setter + def command_tokens(self, command_tokens): + self._command_tokens = command_tokens + self.command_name_map = {tok.name: tok for tok in self.command_tokens} + self.command_token_map = {tok.token: tok for tok in self.command_tokens} + self.command_id_map = {tok.Id: tok for tok in self.command_tokens} + + def get_command(self, name): + """get command token corresponding to `name`""" + return self.command_name_map[name] + + def EncodeAsIds(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + ids = self.text_tokenizer.encode(processed_text, add_special_tokens=False) + tokenization = Tokenization(ids, processed_text, text) + return tokenization + + def DecodeIds(self, ids): + if isinstance(ids, Tokenization): + ids = ids.tokenization + return self.text_tokenizer.decode(ids) + + def DecodeTokens(self, tokens): + return self.text_tokenizer.convert_tokens_to_string(tokens) + + def IdToToken(self, Id): + if isinstance(Id, CommandToken): + return Id.token + return self.text_tokenizer.convert_ids_to_tokens(Id) + + def TokenToId(self, token): + if isinstance(token, CommandToken): + return token.Id + return self.text_tokenizer.convert_tokens_to_ids(token) + + +class HFT5Tokenizer(HFTokenizer): + def __init__(self, model_type_or_path=None, cache_dir=None): + super().__init__(T5Tokenizer, model_type_or_path=model_type_or_path, cache_dir=cache_dir) + command_tokens = [ + CommandToken('eos', '</s>', self.TokenToId("</s>")), + CommandToken('pad', '<pad>', self.TokenToId("<pad>")), + CommandToken('sop', '<pad>', self.TokenToId("<pad>")) + ] + for i in range(100): + command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>'))) + self.command_tokens = command_tokens + diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index 
c7860e3a9883811732a968b168894b2a668f29e4..beacffc39d8dfce4609c25fcf40c37c1d78e6c6e 100644 --- a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -100,14 +100,6 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio if val_data is not None: start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval val_data.batch_sampler.start_iter = start_iter_val % len(val_data) - if train_data is not None: - train_data_iterator = iter(train_data) - else: - train_data_iterator = None - if val_data is not None: - val_data_iterator = iter(val_data) - else: - val_data_iterator = None # init hook before training if hooks['init_function'] is not None: @@ -121,16 +113,12 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio def save_on_exit(args_, model_, optimizer_, lr_scheduler_): save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_) iteration, skipped = train(model, optimizer, - lr_scheduler, - train_data_iterator, - val_data_iterator, - timers, args, summary_writer=summary_writer, - hooks=hooks - ) - if args.do_valid: - prefix = 'the end of training for val data' - val_loss = evaluate_and_print_results(prefix, val_data_iterator, - model, args, timers, False, hooks=hooks) + lr_scheduler, + train_data, + val_data, + timers, args, summary_writer=summary_writer, + hooks=hooks + ) # final save if args.save and iteration != 0: # TODO save @@ -140,7 +128,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio if args.do_test and test_data is not None: prefix = 'the end of training for test data' evaluate_and_print_results(prefix, iter(test_data), - model, args, timers, True, hooks=hooks) + model, len(test_data) if args.strict_eval else args.eval_iters, args, timers, True, hooks=hooks) def get_model(args, model_cls): @@ -156,6 +144,8 @@ def get_model(args, model_cls): if args.fp16: model.half() + elif args.bf16: + model.bfloat16() model.cuda(torch.cuda.current_device()) return model @@ -258,11 +248,21 @@ def get_learning_rate_scheduler(optimizer, iteration, args, def train(model, optimizer, lr_scheduler, - train_data_iterator, val_data_iterator, timers, args, - summary_writer=None, hooks={}): + train_data, val_data, timers, args, + summary_writer=None, hooks={}): """Train the model.""" + if train_data is not None: + train_data_iterator = iter(train_data) + else: + train_data_iterator = None + if val_data is not None: + val_data_iterator = iter(val_data) + else: + val_data_iterator = None + # Turn on training mode which enables dropout. model.train() + # Tracking loss. 
total_lm_loss = 0.0 @@ -316,10 +316,14 @@ def train(model, optimizer, lr_scheduler, # Evaluation if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid: + if args.strict_eval: + val_data_iterator = iter(val_data) + eval_iters = len(val_data) + else: + eval_iters = args.eval_iters prefix = 'iteration {}'.format(args.iteration) evaluate_and_print_results( - prefix, val_data_iterator, model, args, timers, False, step=args.iteration, - summary_writer=summary_writer, hooks=hooks) + prefix, val_data_iterator, model, eval_iters, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks) if args.exit_interval and args.iteration % args.exit_interval == 0: torch.distributed.barrier() @@ -412,21 +416,19 @@ def backward_step(optimizer, model, loss, args, timers): return - -def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}): +def evaluate(data_iterator, model, eval_iters, args, timers, verbose=False, hooks={}): """Evaluation.""" forward_step = hooks['forward_step'] - # Turn on evaluation mode which disables dropout. model.eval() - total_lm_loss = 0 + total_lm_loss, metrics_total = 0, {} with torch.no_grad(): iteration = 0 - while iteration < args.eval_iters: + while iteration < eval_iters: iteration += 1 if verbose and iteration % args.log_interval == 0: - print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters)) + print_rank_0('Evaluating iter {}/{}'.format(iteration, eval_iters)) # Forward evaluation. lm_loss, metrics = forward_step(data_iterator, model, args, timers) '''when contiguous memory optimizations are enabled, the buffers @@ -436,27 +438,31 @@ def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}): if args.deepspeed and args.deepspeed_activation_checkpointing: deepspeed.checkpointing.reset() total_lm_loss += lm_loss.data.detach().float().item() + for name in metrics: + if name not in metrics_total: + metrics_total[name] = 0.0 + metrics_total[name] += metrics[name] # Move model back to the train mode. 
model.train() - total_lm_loss /= args.eval_iters - return total_lm_loss + total_lm_loss /= eval_iters + metrics_avg = {key: value / eval_iters for key, value in metrics_total.items()} + return total_lm_loss, metrics_avg - -def evaluate_and_print_results(prefix, data_iterator, model, - args, timers, verbose=False, step=None, summary_writer=None, hooks={}): +def evaluate_and_print_results(prefix, data_iterator, model, eval_iters, + args, timers, verbose=False, step=None, summary_writer=None, hooks={}): """Helper function to evaluate and dump results on screen.""" # import line_profiler # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward) # profile.enable() # torch.cuda.empty_cache() - lm_loss = evaluate(data_iterator, model, args, timers, verbose, hooks=hooks) + lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, verbose, hooks=hooks) # profile.disable() # import sys # profile.print_stats(sys.stdout) lm_ppl = math.exp(min(20, lm_loss)) - report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step) + report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step, metrics) return lm_loss @@ -480,10 +486,12 @@ def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, summary_writer.add_scalar('Train/'+key, avg_metrics[key], step) -def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step): +def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step, avg_metrics): string = ' validation loss at {} | '.format(prefix) string += 'LM loss: {:.6E} | '.format(loss) string += 'LM PPL: {:.6E}'.format(ppl) + for key in avg_metrics: + string += ' {} {:.6E} |'.format(key, avg_metrics[key]) length = len(string) + 1 print_rank_0('-' * 100) print_rank_0('-' * length) @@ -492,7 +500,9 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step): if summary_writer is not None: summary_writer.add_scalar(f'Train/valid_ppl', ppl, step) summary_writer.add_scalar(f'Train/valid_loss', loss, step) - + for key in avg_metrics: + summary_writer.add_scalar('Train/valid_'+key, avg_metrics[key], step) + ''' Optional DeepSpeed Activation Checkpointing features @@ -538,7 +548,8 @@ def initialize_distributed(args): # Optional DeepSpeed Activation Checkpointing Features if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing: set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed - + else: + mpu.get_cuda_rng_tracker = None def set_random_seed(seed): """Set random seed for reproducability.""" diff --git a/examples/glm/finetune_glm_sst2.py b/examples/glm/finetune_glm_sst2.py new file mode 100755 index 0000000000000000000000000000000000000000..14ef44d9cfa606d78fa9ea461f609453b24d8e73 --- /dev/null +++ b/examples/glm/finetune_glm_sst2.py @@ -0,0 +1,109 @@ +# -*- encoding: utf-8 -*- +''' +@File : finetune_glm_sst2.py +@Time : 2021/12/12 20:53:28 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +from SwissArmyTransformer.data_utils.datasets import TSVDataset +import torch +import argparse +import numpy as np + +from SwissArmyTransformer import mpu, get_args, get_tokenizer +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin, non_conflict +from SwissArmyTransformer.training.deepspeed_training import training_main +from SwissArmyTransformer.data_utils import TSVDataset +from SwissArmyTransformer.model import GLMModel +from 
SwissArmyTransformer.mpu.transformer import standard_attention +from SwissArmyTransformer.model.mixins import MLPHeadMixin, PrefixTuningMixin + +class ClassificationModel(GLMModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1)) + self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len)) + def disable_untrainable_params(self): + self.transformer.word_embeddings.requires_grad_(False) + # for layer_id in range(len(self.transformer.layers)): + # self.transformer.layers[layer_id].requires_grad_(False) + +def get_batch(data_iterator, args, timers): + # Items and their type. + keys = ['sentence', 'label'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + data_b = mpu.broadcast_data(keys, data, datatype) + # Unpack. + tokens = data_b['sentence'].long() + labels = data_b['label'].long() + batch_size, seq_length = tokens.size() + + position_ids = torch.zeros(2, seq_length, device=tokens.device, dtype=torch.long) + torch.arange(0, seq_length, out=position_ids[0, :seq_length]) + position_ids = position_ids.unsqueeze(0) + + attention_mask = torch.ones((batch_size, 1, seq_length, seq_length), device=tokens.device) + + attention_mask[...,:seq_length] -= (tokens==-1).view(batch_size, 1, 1, seq_length).float() + # Convert + if args.fp16: + attention_mask = attention_mask.half() + return tokens, labels, attention_mask, position_ids, (tokens!=-1) + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. 
+ timers('batch generator').start() + tokens, labels, attention_mask, position_ids, loss_mask = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + + logits, *mems = model(tokens, position_ids, attention_mask) + pred = ((logits.contiguous().float().squeeze(-1)) * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1) + loss = torch.nn.functional.binary_cross_entropy_with_logits( + pred, + labels.float() + ) + acc = ((pred > 0.).long() == labels).sum() / labels.numel() + return loss, {'acc': acc} + +def create_dataset_function(path, args): + tokenizer = get_tokenizer() + def process_fn(row): + sentence, label = tokenizer._encode(row[0]), int(row[1]) + sentence = [tokenizer.get_command('ENC').Id] + sentence + [tokenizer.get_command('eos').Id] + if len(sentence) >= args.sample_length: + sentence = sentence[:args.sample_length] + else: + sentence.extend([-1] * (args.sample_length-len(sentence))) + return {'sentence': np.array(sentence, dtype=np.int64), 'label': label} + return TSVDataset(path, process_fn, with_heads=True) + +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--new_hyperparam', type=str, default=None) + py_parser.add_argument('--sample_length', type=int, default=80) + py_parser.add_argument('--prefix_len', type=int, default=16) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + # from cogdata.utils.ice_tokenizer import get_tokenizer as get_ice + # tokenizer = get_tokenizer(args=args, outer_tokenizer=get_ice()) + training_main(args, model_cls=ClassificationModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/examples/glm/scripts/ds_config_ft.json b/examples/glm/scripts/ds_config_ft.json new file mode 100755 index 0000000000000000000000000000000000000000..17effe048977096633ce31ef75a1cd0d9b2d813b --- /dev/null +++ b/examples/glm/scripts/ds_config_ft.json @@ -0,0 +1,30 @@ +{ + "train_micro_batch_size_per_gpu":64, + "gradient_accumulation_steps": 1, + "steps_per_print": 10, + "gradient_clipping": 0.1, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 400, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00001, + "betas": [ + 0.9, + 0.95 + ], + "eps": 1e-8, + "weight_decay": 0 + } + }, + "activation_checkpointing": { + "partition_activations": false, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false +} diff --git a/examples/glm/scripts/finetune_sst2.sh b/examples/glm/scripts/finetune_sst2.sh new file mode 100755 index 0000000000000000000000000000000000000000..4e4809a07557b96afa37b00e1ea74a5830d8608e --- /dev/null +++ b/examples/glm/scripts/finetune_sst2.sh @@ -0,0 +1,62 @@ +#! 
/bin/bash + +# Change for multinode config +CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=1 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) +source $main_dir/config/model_glm_roberta_large.sh + +OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +HOST_FILE_PATH="hostfile_single" + +en_data="/dataset/fd5061f6/english_data/glue_data/SST-2/train.tsv" +eval_data="/dataset/fd5061f6/english_data/glue_data/SST-2/dev.tsv" + + +config_json="$script_dir/ds_config_ft.json" +gpt_options=" \ + --experiment-name finetune-glm-sst2 \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --train-iters 6000 \ + --resume-dataloader \ + $MODEL_ARGS \ + --train-data ${en_data} \ + --valid-data ${eval_data} \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .02 \ + --checkpoint-activations \ + --fp16 \ + --save-interval 6000 \ + --eval-interval 100 \ + --save /root/checkpoints \ + --split 1 \ + --strict-eval \ + --eval-batch-size 8 +" + # --load /root/checkpoints/pretrain-bert-mid-std-fulltrain12-02-06-10 + # \ --sandwich-ln + # --split 949,50,1 \ + # --load /root/checkpoints/pretrain-bert-mid11-28-15-38 \ + + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_glm_sst2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/t5/config/config_t5_large.json b/examples/t5/config/config_t5_large.json new file mode 100644 index 0000000000000000000000000000000000000000..1a4bffcaff366d29421f5d1d605b819457586b78 --- /dev/null +++ b/examples/t5/config/config_t5_large.json @@ -0,0 +1,31 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "contiguous_gradients": false, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 50000000, + "allgather_bucket_size": 500000000 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "weight_decay": 0.1, + "betas": [ + 0.9, + 0.98 + ], + "eps": 1e-6 + } + }, + "activation_checkpointing": { + "partition_activations": false, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/examples/t5/config/model_t5_large.sh b/examples/t5/config/model_t5_large.sh new file mode 100644 index 0000000000000000000000000000000000000000..4e97805010050a195b521b1b08fe58c6cf1ec703 --- /dev/null +++ b/examples/t5/config/model_t5_large.sh @@ -0,0 +1,14 @@ +MODEL_TYPE="t5-large" +MODEL_ARGS="--block-lm \ + --cloze-eval \ + --vocab-size 32128 \ + --num-layers 24 \ + --hidden-size 1024 \ + --inner-hidden-size 4096 \ + --num-attention-heads 16 \ + --hidden-size-per-attention-head 64 \ + --max-sequence-length 513 \ + --relative-attention-num-buckets 32 \ + --layernorm-epsilon 1e-6 \ + --tokenizer-type hf_T5Tokenizer \ + --tokenizer-model-type t5-large" \ No newline at end of file diff --git a/examples/t5/finetune_t5.py b/examples/t5/finetune_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..d64423947cdcab706524fe766f3b000c825a6d66 --- /dev/null +++ b/examples/t5/finetune_t5.py @@ -0,0 +1,83 @@ +# -*- encoding: utf-8 -*- +''' +@File : finetune_t5.py +@Time : 2021/12/11 02:39:13 
+@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +import torch +import argparse +import numpy as np + +from SwissArmyTransformer import mpu, get_args, get_tokenizer +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin +from SwissArmyTransformer.training.deepspeed_training import training_main +from SwissArmyTransformer.data_utils import TSVDataset +from SwissArmyTransformer.model import T5Model + +def get_batch(data_iterator, args, timers): + # Items and their type. + keys = ['sentence', 'label'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + data = data[0].to('cuda') + attention_mask = torch.ones((1, data.shape[-1], data.shape[-1]), device=data.device) + attention_mask.tril_() + attention_mask.unsqueeze_(1) + + return data, torch.arange(data.shape[-1], device=data.device), attention_mask + + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + input_ids, position_ids, mask = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + # Forward model. + + enc, logits, *mems = model( + enc_input_ids=input_ids, + dec_input_ids=input_ids, + enc_position_ids=position_ids, + dec_position_ids=position_ids, + dec_attention_mask=mask) + # logits, *mems = model(tokens, position_ids, attention_mask) + loss = logits.mean() + return loss, {} + +def create_dataset_function(path, args): + + return torch.utils.data.TensorDataset( + torch.ones(100000, 20, dtype=torch.long) + ) + +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--new_hyperparam', type=str, default=None) + py_parser.add_argument('--sample_length', type=int, default=80) + py_parser.add_argument('--prefix_len', type=int, default=16) + py_parser.add_argument('--cache-dir', type=str, default='/root/some_cache', + help='hf cache') + T5Model.add_model_specific_args(py_parser) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + training_main(args, model_cls=T5Model, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..8286aeee411a7621fd6404980b051f9e08b84033 --- /dev/null +++ b/examples/t5/inference_t5.py @@ -0,0 +1,79 @@ +# -*- encoding: utf-8 -*- +''' +@File : inference_glm.py +@Time : 2021/10/22 19:41:58 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +from functools import partial +import os +import sys +import random +import time +from datetime import datetime +import torch +import torch.nn.functional as F +import argparse +import stat +from functools import partial + +from SwissArmyTransformer import mpu, get_args, get_tokenizer, load_checkpoint, initialize_distributed, set_random_seed + +from SwissArmyTransformer.model import T5Model +from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin +from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity, get_masks_and_position_ids_default +from SwissArmyTransformer.generation.sampling_strategies import 
BeamSearchStrategy, BaseStrategy +from SwissArmyTransformer.generation.utils import timed_name, generate_continually +from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer + + + +def main(args): + args.do_train = False + initialize_distributed(args) + tokenizer = get_tokenizer(args) + # build model + model = T5Model(args) + # model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) + if args.fp16: + model = model.half() + model = model.to(args.device) + load_checkpoint(model, args) + set_random_seed(args.seed) + model.eval() + + # test correctness + input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization + input_ids = input_ids + [tokenizer.get_command("eos").Id] + input_ids = torch.tensor(input_ids, device='cuda', dtype=torch.long) + decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization + decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id] + decoder_input_ids = torch.tensor(decoder_input_ids, device='cuda', dtype=torch.long) + + input_ids, _mask, enc_position_ids = get_masks_and_position_ids_default(input_ids) + + decoder_input_ids, dec_attention_mask, dec_position_ids = get_masks_and_position_ids_default(decoder_input_ids) + + encoder_outputs, decoder_outputs, *mems = model( + enc_input_ids=input_ids, + dec_input_ids=decoder_input_ids, + dec_attention_mask=dec_attention_mask + ) + breakpoint() + + +if __name__ == "__main__": + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy', + help='type name of sampling strategy') + py_parser.add_argument('--cache-dir', type=str, default='/root/some_cache', + help='hf cache') + T5Model.add_model_specific_args(py_parser) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + with torch.no_grad(): + main(args) diff --git a/examples/t5/scripts/config_t5_tmp.json b/examples/t5/scripts/config_t5_tmp.json new file mode 100644 index 0000000000000000000000000000000000000000..e87fa36ee0ddf4fdb7a0dbb3b92ca51019d912dc --- /dev/null +++ b/examples/t5/scripts/config_t5_tmp.json @@ -0,0 +1,23 @@ +{ + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 1.0, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "weight_decay": 0.1, + "betas": [ + 0.9, + 0.98 + ], + "eps": 1e-6 + } + }, + "activation_checkpointing": { + "partition_activations": false, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false + } \ No newline at end of file diff --git a/examples/t5/scripts/finetune_t5.sh b/examples/t5/scripts/finetune_t5.sh new file mode 100755 index 0000000000000000000000000000000000000000..365c6233b8fbd94add977deee25e2c4d8cb734ce --- /dev/null +++ b/examples/t5/scripts/finetune_t5.sh @@ -0,0 +1,54 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=2 +MP_SIZE=1 +source $1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +HOST_FILE_PATH="hostfile_single" + +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin" +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" + +config_json="$main_dir/config/config_t5_large.json" +gpt_options=" \ + --experiment-name finetune-t5-test \ + --tokenizer-type Fake \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + $MODEL_ARGS \ + --train-iters 200000 \ + --resume-dataloader \ + --train-data ${small_data} \ + --split 1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .1 \ + --checkpoint-activations \ + --save-interval 5000 \ + --eval-interval 1000 \ + --save /root/checkpoints \ + --fp16 +" + # --load pretrained/cogview/cogview-base + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} finetune_t5.py ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/t5/scripts/generate_t5.sh b/examples/t5/scripts/generate_t5.sh new file mode 100755 index 0000000000000000000000000000000000000000..f2432546348e684202911259b35a6697528d1491 --- /dev/null +++ b/examples/t5/scripts/generate_t5.sh @@ -0,0 +1,37 @@ +#!/bin/bash +CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/t5/t5-large + +source $1 +MPSIZE=1 +MAXSEQLEN=512 +MASTER_PORT=$(shuf -n 1 -i 10000-65535) + +#SAMPLING ARGS +TEMP=0.9 +#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p +TOPK=40 +TOPP=0 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/config_t5_tmp.json" + +python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_t5.py \ + --mode inference \ + --model-parallel-size $MPSIZE \ + $MODEL_ARGS \ + --num-beams 4 \ + --no-repeat-ngram-size 3 \ + --length-penalty 0.7 \ + --out-seq-length $MAXSEQLEN \ + --temperature $TEMP \ + --top_k $TOPK \ + --output-path samples_glm \ + --batch-size 2 \ + --out-seq-length 200 \ + --mode inference \ + --input-source ./input.txt \ + --checkpoint-activations \ + --sampling-strategy BeamSearchStrategy \ + --load $CHECKPOINT_PATH diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..41b9d2ca84320ab87a6e0eaa3c3b8912b755e014 --- /dev/null +++ b/examples/t5/test_t5.py @@ -0,0 +1,12 @@ +from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer +device = 'cuda:1' +tokenizer = T5Tokenizer.from_pretrained("t5-large") +model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt") +model = model.to(device) +model.eval() +input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device) +decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device) +breakpoint() +output = model(input_ids=input_ids, labels=decoder_input_ids) +output.loss.backward() +a = 1 \ No newline at end of file diff --git a/setup.py b/setup.py index 
636894d95dcc22a78c3c6e44697609ce9ece39e6..f106e9f7750ba16b476d2a96ad178ed3e613e4b1 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def _requirements():
 setup(
     name="SwissArmyTransformer",
-    version='0.1.2',
+    version='0.1.3',
     description="A transformer-based framework with finetuning as the first class citizen.",
     long_description=Path("README.md").read_text(),
     long_description_content_type="text/markdown",
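As a worked example of how the new sizing arguments fit together: examples/t5/config/model_t5_large.sh sets --hidden-size 1024, --num-attention-heads 16, --hidden-size-per-attention-head 64 and --inner-hidden-size 4096, and the SelfAttention/MLP changes in mpu/transformer.py turn those into the projection sizes below. This is a standalone sketch of the arithmetic, not code from the repository; a model-parallel world size of 1 is assumed.

# Sizing arithmetic from the SelfAttention/MLP hunks, using the t5-large values.
hidden_size = 1024                      # --hidden-size
num_attention_heads = 16                # --num-attention-heads
hidden_size_per_attention_head = 64     # --hidden-size-per-attention-head
mlp_inner_hidden_size = 4096            # --inner-hidden-size (used by MLP only)
world_size = 1                          # model-parallel partitions (assumed)

attn_inner_hidden_size = num_attention_heads * hidden_size_per_attention_head       # 1024
heads_per_partition = num_attention_heads // world_size                             # 16
hidden_size_per_partition = hidden_size_per_attention_head * heads_per_partition    # 1024

# query_key_value: hidden_size -> 3 * attn_inner_hidden_size   (1024 -> 3072)
# dense:           attn_inner_hidden_size -> hidden_size       (1024 -> 1024)
# dense_h_to_4h:   hidden_size -> mlp_inner_hidden_size        (1024 -> 4096),
#                  replacing the previous hard-coded 4 * hidden_size
print(attn_inner_hidden_size, hidden_size_per_partition, 3 * attn_inner_hidden_size)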