diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py index 5eafe851c4d6064d802532c2b822efe4cefc42a1..804bae08edbd705097b8fb2828443b87066e8a9c 100755 --- a/SwissArmyTransformer/arguments.py +++ b/SwissArmyTransformer/arguments.py @@ -300,25 +300,28 @@ def get_args(args_list=None): print('using world size: {} and model-parallel size: {} '.format( args.world_size, args.model_parallel_size)) - if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None: - with open(args.deepspeed_config) as file: - deepspeed_config = json.load(file) - if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]: - args.fp16 = True - else: - args.fp16 = False + if hasattr(args, "deepspeed") and args.deepspeed: if args.checkpoint_activations: args.deepspeed_activation_checkpointing = True - if "train_micro_batch_size_per_gpu" in deepspeed_config: - args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"] - if "gradient_accumulation_steps" in deepspeed_config: - args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"] else: - args.gradient_accumulation_steps = None - if "optimizer" in deepspeed_config: - optimizer_params_config = deepspeed_config["optimizer"].get("params", {}) - args.lr = optimizer_params_config.get("lr", args.lr) - args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay) + args.deepspeed_activation_checkpointing = False + if args.deepspeed_config is not None: + with open(args.deepspeed_config) as file: + deepspeed_config = json.load(file) + if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]: + args.fp16 = True + else: + args.fp16 = False + if "train_micro_batch_size_per_gpu" in deepspeed_config: + args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"] + if "gradient_accumulation_steps" in deepspeed_config: + args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"] + else: + args.gradient_accumulation_steps = None + if "optimizer" in deepspeed_config: + optimizer_params_config = deepspeed_config["optimizer"].get("params", {}) + args.lr = optimizer_params_config.get("lr", args.lr) + args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay) return args diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py index 0c6336c66eb887792e50872d9af768445c5ce030..341a800083d5a317dfc55cde1592915a8296c4ca 100644 --- a/SwissArmyTransformer/model/encoder_decoder_model.py +++ b/SwissArmyTransformer/model/encoder_decoder_model.py @@ -64,11 +64,22 @@ class EncoderDecoderModel(torch.nn.Module): return encoder_outputs def decode(self, input_ids, position_ids, attention_mask, encoder_outputs,cross_attention_mask=None, **kw_args): + if attention_mask is None: + batch_size, seq_length = input_ids.size()[:2] + seq_ids = torch.arange(seq_length, device=input_ids.device) + attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype) + attention_mask = attention_mask[:, None, :, :] # If no context, please explicitly pass ``encoder_outputs=None'' return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args) - def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids,dec_attention_mask, *, enc_attention_mask=None, 
cross_attention_mask=None, **kw_args): + def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args): # Please use self.decoder for auto-regressive generation. + batch_size, seq_length = enc_input_ids.size()[:2] + if enc_attention_mask is None: + enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device) + if cross_attention_mask is None: + cross_attention_mask = enc_attention_mask encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args) decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args) return encoder_outputs, decoder_outputs, *mems diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py index 944a133371fe591e8f20c6f47c9350040df519a5..7143950772587d88a45960668215a7c604ea1f36 100644 --- a/SwissArmyTransformer/model/t5_model.py +++ b/SwissArmyTransformer/model/t5_model.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from .mixins import BaseMixin from .encoder_decoder_model import EncoderDecoderModel from SwissArmyTransformer.mpu import get_model_parallel_world_size -from SwissArmyTransformer.mpu.transformer import standard_attention +from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim @@ -28,9 +28,11 @@ class T5LayerNorm(torch.nn.Module): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - # convert into float16 if necessary + # convert into float16 or bfloat16 if necessary if self.weight.dtype == torch.float16: hidden_states = hidden_states.to(torch.float16) + elif self.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) return self.weight * hidden_states @@ -111,7 +113,7 @@ class T5AttentionMixin(BaseMixin): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values - def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args): + def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args): attn_module = self.transformer.layers[layer_id].attention seq_length = hidden_states.size(1) memory_length = mems[layer_id].size(1) if mems else 0 @@ -126,7 +128,8 @@ class T5AttentionMixin(BaseMixin): key_layer = attn_module._transpose_for_scores(mixed_key_layer) value_layer = attn_module._transpose_for_scores(mixed_value_layer) - position_bias = self.compute_bias(seq_length, memory_length + seq_length) + if position_bias is None: + position_bias = self.compute_bias(seq_length, memory_length + seq_length) context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn, log_attention_weights=position_bias, scaling_attention_score=False) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -137,9 +140,10 @@ class T5AttentionMixin(BaseMixin): if attn_module.training: output = attn_module.output_dropout(output) - return output, None + return output, position_bias - def cross_attention_forward(self, 
hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args): + def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args, + **kw_args): attn_module = self.transformer.layers[layer_id].cross_attention mixed_query_layer = attn_module.query(hidden_states) mixed_x_layer = attn_module.key_value(encoder_outputs) @@ -151,7 +155,7 @@ class T5AttentionMixin(BaseMixin): key_layer = attn_module._transpose_for_scores(mixed_key_layer) value_layer = attn_module._transpose_for_scores(mixed_value_layer) - context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn, + context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn, scaling_attention_score=False) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,) @@ -180,8 +184,10 @@ class T5DecoderFinalMixin(BaseMixin): class T5Model(EncoderDecoderModel): def __init__(self, args, **kwargs): - super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False, - layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu) + self.init_method_std = args.init_method_std + super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False, + layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu, + init_method=self._init_weights) self.encoder.add_mixin( "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads) ) @@ -201,7 +207,61 @@ class T5Model(EncoderDecoderModel): ) del self.decoder.transformer.position_embeddings + def _init_weights(self, weight, module, name): + init_method_std = self.init_method_std + if isinstance(module, MLP): + if name == "dense_h_to_4h": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5)) + elif name == "dense_4h_to_h": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5)) + else: + raise NotImplementedError(name) + elif isinstance(module, SelfAttention): + if name == "query_key_value": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5)) + torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * ( + (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5)) + elif name == "dense": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5)) + else: + raise NotImplementedError(name) + elif isinstance(module, CrossAttention): + if name == "query": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * ( + (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5)) + elif name == "key_value": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5)) + elif name == "dense": + torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5)) + else: + raise NotImplementedError(name) + else: + raise NotImplementedError(module) + @classmethod def add_model_specific_args(cls, parser): super().add_model_specific_args(parser) parser.add_argument("--relative-attention-num-buckets", type=int, default=None) + parser.add_argument("--init-method-std", type=float, default=0.02) + + def encode(self, input_ids, attention_mask=None, **kw_args): + return super().encode(input_ids, None, attention_mask, **kw_args) + + def decode(self, input_ids, attention_mask=None, 
encoder_outputs=None, cross_attention_mask=None, **kw_args): + return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs, + cross_attention_mask=cross_attention_mask, **kw_args) + + def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, + dec_attention_mask=None, cross_attention_mask=None, **kw_args): + batch_size, seq_length = enc_input_ids.size()[:2] + if enc_attention_mask is None: + enc_attention_mask = torch.ones(1, 1, 1, seq_length, + dtype=self.encoder.transformer.word_embeddings.weight.dtype, + device=enc_input_ids.device) + if cross_attention_mask is None: + cross_attention_mask = enc_attention_mask + encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args) + decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask, + encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, + **kw_args) + return encoder_outputs, decoder_outputs, *mems diff --git a/SwissArmyTransformer/mpu/layers.py b/SwissArmyTransformer/mpu/layers.py index 4af3f33180aae46bba28efa213a9a941df5cd3ce..8d7cf370e75f2629ae96b1b2f442f4dfe6bae70f 100755 --- a/SwissArmyTransformer/mpu/layers.py +++ b/SwissArmyTransformer/mpu/layers.py @@ -37,7 +37,7 @@ from .utils import VocabUtility def _initialize_affine_weight(weight, output_size, input_size, per_partition_size, partition_dim, init_method, - stride=1, return_master_weight=False): + stride=1, return_master_weight=False, module=None, name=None): """Initialize affine weight for model parallel. Build the master weight on all processes and scatter @@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size, # If we only use 1 process for model parallelism, bypass scatter. world_size = get_model_parallel_world_size() if world_size == 1: - init_method(weight) + init_method(weight, module=module, name=name) if return_master_weight: return weight return None @@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size, master_weight = torch.empty(output_size, input_size, dtype=weight.dtype, requires_grad=False) - init_method(master_weight) + init_method(master_weight, module=module, name=name) # Split and copy per_partition_per_stride_size = divide(per_partition_size, stride) @@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module): """ def __init__(self, input_size, output_size, bias=True, gather_output=True, init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False): + keep_master_weight_for_test=False, module=None, name=None): super(ColumnParallelLinear, self).__init__() # Keep input parameters @@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module): self.master_weight = _initialize_affine_weight( self.weight, self.output_size, self.input_size, self.output_size_per_partition, 0, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) + stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name) def forward(self, input_): # Set up backprop all-reduce. 
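The hunks above change the parallel linear layers so they forward the owning module and parameter name into the initializer, i.e. init_method(weight, module=module, name=name); the stock unscaled_init_method/scaled_init_method in mpu/utils.py absorb the extra keywords via **kwargs, and T5Model._init_weights dispatches on them. A minimal sketch of a compatible custom initializer (illustrative only; the std values and the "dense" special case are invented, not part of this patch):

import torch

def layer_aware_init(weight, module=None, name=None):
    # _initialize_affine_weight now calls init_method(weight, module=module, name=name),
    # so an initializer can vary its statistics per owning module and per parameter name.
    std = 0.02
    if name == "dense":  # hypothetical special case for output projections
        std *= 0.5
    return torch.nn.init.normal_(weight, mean=0.0, std=std)

Passing such a function as BaseTransformer(..., init_method=layer_aware_init) routes it through every ColumnParallelLinear/RowParallelLinear the transformer builds, which is exactly how T5Model wires in its _init_weights above.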
@@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module): def __init__(self, input_size, output_size, bias=True, input_is_parallel=False, init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False): + keep_master_weight_for_test=False, module=None, name=None): super(RowParallelLinear, self).__init__() # Keep input parameters @@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module): self.master_weight = _initialize_affine_weight( self.weight, self.output_size, self.input_size, self.input_size_per_partition, 1, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) + stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name) def forward(self, input_): # Set up backprop all-reduce. diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index c553cc21daf4e96c7ce79be41dd08e1cb5cb1c47..8afb1f9dc66151584fbb77efcb89b328a61d1b98 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -22,11 +22,12 @@ import torch import torch.nn.functional as F from apex.normalization.fused_layer_norm import FusedLayerNorm +from SwissArmyTransformer import mpu from .initialize import get_model_parallel_world_size from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region -from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker +from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu from .utils import split_tensor_along_last_dim @@ -62,7 +63,10 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, attention_probs = F.softmax(attention_scores, dim=-1) if attention_dropout is not None: - with get_cuda_rng_tracker().fork(): + if mpu.get_cuda_rng_tracker is not None: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = attention_dropout(attention_probs) + else: attention_probs = attention_dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) @@ -82,31 +86,36 @@ class SelfAttention(torch.nn.Module): self.layer_id = layer_id # Per attention head and per partition values. world_size = get_model_parallel_world_size() + self.hidden_size = hidden_size if hidden_size_per_attention_head is None: self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) else: self.hidden_size_per_attention_head = hidden_size_per_attention_head self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) - inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition # Strided linear layer. 
self.query_key_value = ColumnParallelLinear( hidden_size, - 3 * inner_hidden_size, + 3 * self.inner_hidden_size, stride=3, gather_output=False, init_method=init_method, - bias=bias + bias=bias, + module=self, + name="query_key_value" ) self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) self.dense = RowParallelLinear( - inner_hidden_size, + self.inner_hidden_size, hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - bias=bias + bias=bias, + module=self, + name="dense" ) self.output_dropout = torch.nn.Dropout(output_dropout_prob) @@ -165,21 +174,22 @@ class CrossAttention(torch.nn.Module): self.layer_id = layer_id # Per attention head and per partition values. world_size = get_model_parallel_world_size() + self.hidden_size = hidden_size if hidden_size_per_attention_head is None: self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) else: self.hidden_size_per_attention_head = hidden_size_per_attention_head self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) - inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition # Strided linear layer. - self.query = ColumnParallelLinear(hidden_size, inner_hidden_size, + self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size, gather_output=False, - init_method=init_method, bias=bias) - self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size, + init_method=init_method, bias=bias, module=self, name="query") + self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size, stride=2, gather_output=False, - init_method=init_method, bias=bias) + init_method=init_method, bias=bias, module=self, name="key_value") # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. @@ -187,10 +197,10 @@ class CrossAttention(torch.nn.Module): # Output. self.dense = RowParallelLinear( - inner_hidden_size, + self.inner_hidden_size, hidden_size, input_is_parallel=True, - init_method=output_layer_init_method, bias=bias) + init_method=output_layer_init_method, bias=bias, module=self, name="dense") self.output_dropout = torch.nn.Dropout(output_dropout_prob) def _transpose_for_scores(self, tensor): @@ -245,22 +255,28 @@ class MLP(torch.nn.Module): output_layer_init_method = init_method self.hooks = hooks # Project to 4h. + self.hidden_size = hidden_size if inner_hidden_size is None: inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size self.dense_h_to_4h = ColumnParallelLinear( - hidden_size, - inner_hidden_size, + self.hidden_size, + self.inner_hidden_size, gather_output=False, init_method=init_method, - bias=bias + bias=bias, + module=self, + name="dense_h_to_4h" ) # Project back to h. 
self.dense_4h_to_h = RowParallelLinear( - inner_hidden_size, - hidden_size, + self.inner_hidden_size, + self.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - bias=bias + bias=bias, + module=self, + name="dense_4h_to_h" ) self.dropout = torch.nn.Dropout(output_dropout_prob) @@ -358,7 +374,7 @@ class BaseTransformerLayer(torch.nn.Module): hooks=hooks ) - def forward(self, hidden_states, mask, **kw_args): + def forward(self, hidden_states, mask, encoder_outputs=None, *args, **kw_args): ''' hidden_states: [batch, seq_len, hidden_size] mask: [(1, 1), seq_len, seq_len] @@ -424,6 +440,7 @@ class BaseTransformer(torch.nn.Module): use_bias=True, activation_func=gelu, layernorm=LayerNorm, + init_method=None, hooks={} ): super(BaseTransformer, self).__init__() @@ -446,8 +463,12 @@ class BaseTransformer(torch.nn.Module): torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) # create all layers - self.output_layer_init_method = scaled_init_method(init_method_std, num_layers) - self.init_method = unscaled_init_method(init_method_std) + if init_method is None: + self.output_layer_init_method = scaled_init_method(init_method_std, num_layers) + self.init_method = unscaled_init_method(init_method_std) + else: + self.output_layer_init_method = init_method + self.init_method = init_method def get_layer(layer_id): return BaseTransformerLayer( @@ -483,7 +504,7 @@ class BaseTransformer(torch.nn.Module): if attention_mask is None: attention_mask = torch.ones(1, 1, device=input_ids.device).type_as( next(self.parameters()) - ) # None means full attention + ) # None means full attention assert len(attention_mask.shape) == 2 or \ len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 @@ -575,6 +596,7 @@ class BaseTransformer(torch.nn.Module): l, num_layers = 0, len(self.layers) chunk_length = self.checkpoint_num_layers + output_this_layer = [] while l < num_layers: args = [hidden_states, attention_mask] # flatten kw_args and output_cross_layer @@ -601,6 +623,7 @@ class BaseTransformer(torch.nn.Module): output_per_layers.extend(output_per_layers_part) l += chunk_length else: + output_this_layer = [] for i, layer in enumerate(self.layers): args = [hidden_states, attention_mask] diff --git a/SwissArmyTransformer/mpu/utils.py b/SwissArmyTransformer/mpu/utils.py index c83f501889ccabb33cce33260f7d4fee9eadcab7..2f4f29f08227269634af831b91d2fe9b301b4e9e 100755 --- a/SwissArmyTransformer/mpu/utils.py +++ b/SwissArmyTransformer/mpu/utils.py @@ -75,7 +75,7 @@ def sqrt(x): def unscaled_init_method(sigma): """Init method based on N(0, sigma).""" - def init_(tensor): + def init_(tensor, **kwargs): return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) return init_ @@ -83,7 +83,7 @@ def unscaled_init_method(sigma): def scaled_init_method(sigma, num_layers): """Init method based on N(0, sigma/sqrt(2*num_layers).""" std = sigma / math.sqrt(2.0 * num_layers) - def init_(tensor): + def init_(tensor, **kwargs): return torch.nn.init.normal_(tensor, mean=0.0, std=std) return init_ diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py index 7e0485274deeae3f095456f73a0d9c842a026e46..2279ee051c622f17991f0337206933d2a05c3653 100644 --- a/SwissArmyTransformer/tokenization/__init__.py +++ b/SwissArmyTransformer/tokenization/__init__.py @@ -29,7 +29,8 @@ def _export_vocab_size_to_args(args, original_num_tokens): print_rank_0('> padded vocab (size: {}) with {} dummy ' 'tokens (new size: {})'.format( before, after - before, 
after)) - args.vocab_size = after + if not args.vocab_size: + args.vocab_size = after print_rank_0("prepare tokenizer done") return tokenizer @@ -63,6 +64,10 @@ def get_tokenizer(args=None, outer_tokenizer=None): elif args.tokenizer_type == "glm_ChineseSPTokenizer": from .glm import ChineseSPTokenizer get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs) + elif args.tokenizer_type.startswith('hf'): + from .hf_tokenizer import HFT5Tokenizer + if args.tokenizer_type == "hf_T5Tokenizer": + get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir) else: assert args.vocab_size > 0 get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size) diff --git a/SwissArmyTransformer/tokenization/glm/tokenization.py b/SwissArmyTransformer/tokenization/glm/tokenization.py index 67be818a0f1d42b0d35696c24577b5e5391ff3b4..9b9a8abc0202b46a95efc8a8338e90e0da9fd60f 100644 --- a/SwissArmyTransformer/tokenization/glm/tokenization.py +++ b/SwissArmyTransformer/tokenization/glm/tokenization.py @@ -312,11 +312,11 @@ class Tokenizer(object): tokenization.tokenization = [self.IdToToken(idx) for idx in tokenization.tokenization] return tokenization - def IdToToken(self, Id): + def IdToToken(self, idx): """convert Id to token accounting for command tokens""" - if isinstance(Id, CommandToken): - return Id.token - return self.tokens[Id] + if isinstance(idx, CommandToken): + return idx.token + return self.tokens[idx] def TokenToId(self, token): """convert token to Id accounting for command tokens""" @@ -324,16 +324,16 @@ class Tokenizer(object): return token.Id return self.vocab[token] - def DecodeIds(self, Ids): + def DecodeIds(self, ids): """ convert Ids to tokens accounting for command tokens, tokens are joined and returned as a string. """ rtn_strs = [] current_str = [] - if isinstance(Ids, Tokenization): - Ids = Ids.tokenization - for Id in Ids: + if isinstance(ids, Tokenization): + ids = ids.tokenization + for Id in ids: if isinstance(Id, CommandToken): rtn_strs.append(self._decode(current_str)) current_str = [] @@ -353,11 +353,11 @@ class Tokenizer(object): output = self.clean_up_tokenization(output) return output - def DecodeTokens(self, Tokens): + def DecodeTokens(self, tokens): """ convert tokens to a string accounting for command and type tokens. 
""" - Ids = [self.TokenToId(token) for token in Tokens] + Ids = [self.TokenToId(token) for token in tokens] return self.DecodeIds(Ids) diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e40adb47a4ac8cb876132dc5cefc27f824442baf --- /dev/null +++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py @@ -0,0 +1,80 @@ +from transformers import T5Tokenizer +from .glm.tokenization import Tokenization, CommandToken + + +PRETRAINED_VOCAB_FILES_MAP = { + "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small", + "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base", + "t5-large": "/mnt/t5", + "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b", + "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b" +} + +class HFTokenizer: + def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None): + if model_type_or_path in PRETRAINED_VOCAB_FILES_MAP: + model_type_or_path = PRETRAINED_VOCAB_FILES_MAP[model_type_or_path] + self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir) + self.num_tokens = len(self.text_tokenizer) + self._command_tokens = [] + self.command_name_map = {} + self.command_token_map = {} + self.command_id_map = {} + + def __len__(self): + return len(self.text_tokenizer) + + @property + def command_tokens(self): + return self._command_tokens + + @command_tokens.setter + def command_tokens(self, command_tokens): + self._command_tokens = command_tokens + self.command_name_map = {tok.name: tok for tok in self.command_tokens} + self.command_token_map = {tok.token: tok for tok in self.command_tokens} + self.command_id_map = {tok.Id: tok for tok in self.command_tokens} + + def get_command(self, name): + """get command token corresponding to `name`""" + return self.command_name_map[name] + + def EncodeAsIds(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + ids = self.text_tokenizer.encode(processed_text, add_special_tokens=False) + tokenization = Tokenization(ids, processed_text, text) + return tokenization + + def DecodeIds(self, ids): + if isinstance(ids, Tokenization): + ids = ids.tokenization + return self.text_tokenizer.decode(ids) + + def DecodeTokens(self, tokens): + return self.text_tokenizer.convert_tokens_to_string(tokens) + + def IdToToken(self, Id): + if isinstance(Id, CommandToken): + return Id.token + return self.text_tokenizer.convert_ids_to_tokens(Id) + + def TokenToId(self, token): + if isinstance(token, CommandToken): + return token.Id + return self.text_tokenizer.convert_tokens_to_ids(token) + + +class HFT5Tokenizer(HFTokenizer): + def __init__(self, model_type_or_path=None, cache_dir=None): + super().__init__(T5Tokenizer, model_type_or_path=model_type_or_path, cache_dir=cache_dir) + command_tokens = [ + CommandToken('eos', '</s>', self.TokenToId("</s>")), + CommandToken('pad', '<pad>', self.TokenToId("<pad>")), + CommandToken('sop', '<pad>', self.TokenToId("<pad>")) + ] + for i in range(100): + command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>'))) + self.command_tokens = command_tokens + diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index d9059ac4d1f282e3792b3a877af76d81c8531b7c..beacffc39d8dfce4609c25fcf40c37c1d78e6c6e 100644 --- 
a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -144,6 +144,8 @@ def get_model(args, model_cls): if args.fp16: model.half() + elif args.bf16: + model.bfloat16() model.cuda(torch.cuda.current_device()) return model @@ -546,7 +548,8 @@ def initialize_distributed(args): # Optional DeepSpeed Activation Checkpointing Features if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing: set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed - + else: + mpu.get_cuda_rng_tracker = None def set_random_seed(seed): """Set random seed for reproducability.""" diff --git a/examples/t5/config/config_t5_large.json b/examples/t5/config/config_t5_large.json new file mode 100644 index 0000000000000000000000000000000000000000..25d7bf7ac71485cda4fba8500944712a06b775ab --- /dev/null +++ b/examples/t5/config/config_t5_large.json @@ -0,0 +1,34 @@ +{ + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "contiguous_gradients": false, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 50000000, + "allgather_bucket_size": 500000000 + }, + "bfloat16": { + "enabled": true + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "weight_decay": 0.1, + "betas": [ + 0.9, + 0.98 + ], + "eps": 1e-6 + } + }, + "activation_checkpointing": { + "partition_activations": false, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/examples/t5/config/model_t5_large.sh b/examples/t5/config/model_t5_large.sh new file mode 100644 index 0000000000000000000000000000000000000000..0c20d559447a87030f5d07a701d7a125b4f79c97 --- /dev/null +++ b/examples/t5/config/model_t5_large.sh @@ -0,0 +1,15 @@ +MODEL_TYPE="t5-large" +MODEL_ARGS="--block-lm \ + --cloze-eval \ + --vocab-size 32128 \ + --num-layers 24 \ + --hidden-size 1024 \ + --inner-hidden-size 4096 \ + --num-attention-heads 16 \ + --hidden-size-per-attention-head 64 \ + --max-sequence-length 513 \ + --relative-attention-num-buckets 32 \ + --layernorm-epsilon 1e-6 \ + --tokenizer-type hf_T5Tokenizer \ + --tokenizer-model-type t5-large \ + --load ${CHECKPOINT_PATH}/glm-large-en-blank" \ No newline at end of file diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..b902f248e0c6ebe6dfccba3805d4112923eff6fa --- /dev/null +++ b/examples/t5/inference_t5.py @@ -0,0 +1,218 @@ +# -*- encoding: utf-8 -*- +''' +@File : inference_glm.py +@Time : 2021/10/22 19:41:58 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +from functools import partial +import os +import sys +import random +import time +from datetime import datetime +import torch +import torch.nn.functional as F +import argparse +import stat +from functools import partial + +from SwissArmyTransformer import mpu, get_args, get_tokenizer, load_checkpoint, initialize_distributed, set_random_seed + +from SwissArmyTransformer.model import T5Model +from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin +from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity +from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy +from SwissArmyTransformer.generation.utils import 
timed_name, generate_continually +from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer + + +def decoder_shift_right(input_ids, args): + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = args.decoder_start_token_id + return shifted_input_ids + + +def get_batch(data, args): + keys = ['text', 'loss_mask', 'target', 'attention_mask'] + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + # Unpack. + tokens = data_b['text'].long() + labels = data_b['target'].long() + decoder_tokens = decoder_shift_right(labels, args) + attention_mask = data_b['attention_mask'].long() + loss_mask = data_b['loss_mask'].float() + + # Convert + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + return tokens, decoder_tokens, labels, loss_mask, attention_mask + + +def get_masks_and_position_ids_glm(seq, mask_position, context_length): + tokens = seq.unsqueeze(0) + + attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device) + attention_mask.tril_() + attention_mask[..., :context_length] = 1 + attention_mask.unsqueeze_(1) + + position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long) + torch.arange(0, context_length, out=position_ids[0, :context_length]) + position_ids[0, context_length:] = mask_position + torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:]) + + position_ids = position_ids.unsqueeze(0) + return tokens, attention_mask, position_ids + + +def main(args): + args.do_train = False + initialize_distributed(args) + tokenizer = get_tokenizer(args) + # load_checkpoint(model, args) + set_random_seed(args.seed) + + # Model, optimizer, and learning rate. 
+ model_cls = T5Model + model, optimizer = setup_model_and_optimizer(args, model_cls=model_cls) + + missing_keys, unexpected_keys = model.module.load_state_dict( + torch.load("/dataset/fd5061f6/yanan/huggingface_models/t5-large/model_states.pt")["module"]) + optimizer.refresh_fp32_params() + model.eval() + input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization + input_ids = input_ids + [tokenizer.get_command("eos").Id] + input_ids = torch.LongTensor([input_ids]) + decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization + decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id] + decoder_input_ids = torch.LongTensor([decoder_input_ids]) + data = {'text': input_ids, 'loss_mask': input_ids.new_ones(input_ids.shape), 'target': decoder_input_ids, + 'attention_mask': input_ids.new_ones(input_ids.shape)} + tokens, decoder_tokens, labels, loss_mask, attention_mask = get_batch(data, args) + encoder_outputs, logits, *_ = model(enc_input_ids=tokens, dec_input_ids=decoder_tokens, + enc_attention_mask=attention_mask) + losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels) + loss_mask = loss_mask.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask) + if loss_mask.sum().item() > 0: + loss = loss / loss_mask.sum() + loss.backward() + + breakpoint() + + end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id] + # define function for each query + if args.sampling_strategy == 'BaseStrategy': + strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens) + elif args.sampling_strategy == 'BeamSearchStrategy': + strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True, + end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size, + min_tgt_length=args.min_tgt_length) + else: + raise ValueError(f'unknown strategy {args.sampling_strategy}') + + def process(raw_text): + if args.with_id: + query_id, raw_text = raw_text.split('\t') + # add MASK + generation_mask = '[gMASK]' if args.task_mask else '[MASK]' + if 'MASK]' not in raw_text: + raw_text += ' ' + generation_mask + seq = tokenizer.EncodeAsIds(raw_text).tokenization + seq = [tokenizer.get_command('ENC').Id] + seq + if not raw_text.endswith('MASK]'): + seq = seq + [tokenizer.get_command('eos').Id] + print('raw text: {}\n'.format(raw_text)) + if len(seq) > args.max_sequence_length: + raise ValueError('text too long.') + + # generation + mbz = args.max_inference_batch_size + assert args.batch_size < mbz or args.batch_size % mbz == 0 + output_list = [seq] + # continually detect the first mark position + while True: + seq = output_list[0] # TODO find the best one + # detect + mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK'] + mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens] + mask_position = len(seq) + for token in mask_tokens: + try: + mask_position = min(mask_position, seq.index(token)) + except ValueError: + pass + if mask_position == len(seq): + break + + get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq)) + output_list = [] + for tim in range(max(args.batch_size // mbz, 1)): + input_seq = torch.cuda.LongTensor( + seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1), + device=args.device) + output = filling_sequence(model, input_seq, + batch_size=min(args.batch_size, mbz), + 
+                                          strategy=strategy,
+                                          log_attention_weights=None,
+                                          get_masks_and_position_ids=get_func
+                                          )[0]  # we don't use mems, fill back
+                if isinstance(output, torch.Tensor):  # different strategies
+                    output = list(output)
+
+                output_list.extend(output)
+
+            # clip -1s and fill back generated things into seq
+            for i in range(len(output_list)):
+                output = output_list[i].tolist()
+                try:
+                    unfinished = output.index(-1)
+                except ValueError:
+                    unfinished = len(output)
+                if output[unfinished - 1] in end_tokens:
+                    unfinished -= 1
+                bog = output.index(tokenizer.get_command('sop').Id)
+                output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
+
+        # decoding
+        txts = []
+        for seq in output_list:
+            decode_tokens = tokenizer.DecodeIds(seq)
+            txts.append(decode_tokens)
+
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id + '.txt')
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.txt', args.output_path)
+        print(txts[0])  # print the first.
+        with open(full_path, 'w') as fout:
+            for txt in txts:
+                fout.write(txt + '\n')
+        os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
+
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy',
+                           help='type name of sampling strategy')
+    T5Model.add_model_specific_args(py_parser)
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    with torch.no_grad():
+        main(args)
diff --git a/examples/t5/scripts/generate_t5.sh b/examples/t5/scripts/generate_t5.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c4bc602d8babfca6a51b29c99e9ccfdd1d209558
--- /dev/null
+++ b/examples/t5/scripts/generate_t5.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
+
+source $1
+MPSIZE=1
+MAXSEQLEN=512
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=0.9
+#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
+TOPK=40
+TOPP=0
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+config_json="$script_dir/../config/config_t5_large.json"
+
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_t5.py \
+       --deepspeed \
+       --deepspeed-config ${config_json} \
+       --mode inference \
+       --model-parallel-size $MPSIZE \
+       $MODEL_ARGS \
+       --num-beams 4 \
+       --no-repeat-ngram-size 3 \
+       --length-penalty 0.7 \
+       --out-seq-length $MAXSEQLEN \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --output-path samples_glm \
+       --batch-size 2 \
+       --out-seq-length 200 \
+       --mode inference \
+       --input-source ./input.txt \
+       --checkpoint-activations \
+       --sampling-strategy BeamSearchStrategy
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..692ed918f39ddebbcbc23cbf194e0999e2a44aa6
--- /dev/null
+++ b/examples/t5/test_t5.py
@@ -0,0 +1,10 @@
+from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
+model = model.to('cuda')
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to('cuda')
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to('cuda')
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+breakpoint()
\ No newline at end of file
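A minimal end-to-end sketch of how the new pieces fit together: the hf_T5Tokenizer command tokens, the position-id-free T5Model.forward, and the masks that are now built automatically. It assumes args was produced by get_args with the flags from examples/t5/config/model_t5_large.sh and that initialize_distributed(args) has already been called, as in examples/t5/inference_t5.py; checkpoint loading is omitted.

import torch
from SwissArmyTransformer import get_tokenizer
from SwissArmyTransformer.model import T5Model

tokenizer = get_tokenizer(args)          # args.tokenizer_type == "hf_T5Tokenizer"
model = T5Model(args).cuda().eval()

enc = tokenizer.EncodeAsIds('The <extra_id_0> walks in <extra_id_1> park').tokenization
enc = torch.cuda.LongTensor([enc + [tokenizer.get_command('eos').Id]])
dec = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization
dec = torch.cuda.LongTensor([[tokenizer.get_command('sop').Id] + dec])

with torch.no_grad():
    # No position ids or masks are passed: T5Model.encode/decode feed None positions,
    # forward() builds an all-ones encoder mask, and EncoderDecoderModel.decode
    # builds the causal decoder mask on the fly.
    encoder_outputs, logits, *mems = model(enc_input_ids=enc, dec_input_ids=dec)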