diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 531ec54b0062fc616d46a42652a8d486a4659e33..7143950772587d88a45960668215a7c604ea1f36 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
-from SwissArmyTransformer.mpu.transformer import standard_attention
+from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
 
@@ -28,9 +28,11 @@ class T5LayerNorm(torch.nn.Module):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
 
-        # convert into float16 if necessary
+        # convert into float16 or bfloat16 if necessary
         if self.weight.dtype == torch.float16:
             hidden_states = hidden_states.to(torch.float16)
+        elif self.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.bfloat16)
         return self.weight * hidden_states
 
@@ -111,7 +113,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, *args, layer_id=None, mems=None, **kw_args):
+    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
         attn_module = self.transformer.layers[layer_id].attention
         seq_length = hidden_states.size(1)
         memory_length = mems[layer_id].size(1) if mems else 0
@@ -126,7 +128,8 @@ class T5AttentionMixin(BaseMixin):
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
 
-        position_bias = self.compute_bias(seq_length, memory_length + seq_length)
+        if position_bias is None:
+            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
         context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
                                            log_attention_weights=position_bias, scaling_attention_score=False)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
@@ -137,9 +140,10 @@ class T5AttentionMixin(BaseMixin):
         if attn_module.training:
             output = attn_module.output_dropout(output)
 
-        return output, None
+        return output, position_bias
 
-    def cross_attention_forward(self, hidden_states, cross_mask, encoder_outputs, layer_id=None, *args, **kw_args):
+    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
+                                **kw_args):
         attn_module = self.transformer.layers[layer_id].cross_attention
         mixed_query_layer = attn_module.query(hidden_states)
         mixed_x_layer = attn_module.key_value(encoder_outputs)
@@ -151,7 +155,7 @@ class T5AttentionMixin(BaseMixin):
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
 
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_mask, dropout_fn,
+        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
                                            scaling_attention_score=False)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
@@ -180,8 +184,10 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
+        self.init_method_std = args.init_method_std
         super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
-                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu,
+                         init_method=self._init_weights)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -201,10 +207,42 @@ class T5Model(EncoderDecoderModel):
         )
         del self.decoder.transformer.position_embeddings
 
+    def _init_weights(self, weight, module, name):
+        init_method_std = self.init_method_std
+        if isinstance(module, MLP):
+            if name == "dense_h_to_4h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense_4h_to_h":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, SelfAttention):
+            if name == "query_key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+                torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        elif isinstance(module, CrossAttention):
+            if name == "query":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (
+                        (module.hidden_size * module.hidden_size_per_attention_head) ** -0.5))
+            elif name == "key_value":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
+            elif name == "dense":
+                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
+            else:
+                raise NotImplementedError(name)
+        else:
+            raise NotImplementedError(module)
+
     @classmethod
     def add_model_specific_args(cls, parser):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
+        parser.add_argument("--init-method-std", type=float, default=0.02)
 
     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
diff --git a/SwissArmyTransformer/mpu/layers.py b/SwissArmyTransformer/mpu/layers.py
index 4af3f33180aae46bba28efa213a9a941df5cd3ce..8d7cf370e75f2629ae96b1b2f442f4dfe6bae70f 100755
--- a/SwissArmyTransformer/mpu/layers.py
+++ b/SwissArmyTransformer/mpu/layers.py
@@ -37,7 +37,7 @@ from .utils import VocabUtility
 
 def _initialize_affine_weight(weight, output_size, input_size,
                               per_partition_size, partition_dim, init_method,
-                              stride=1, return_master_weight=False):
+                              stride=1, return_master_weight=False, module=None, name=None):
     """Initialize affine weight for model parallel.
 
     Build the master weight on all processes and scatter
@@ -45,7 +45,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
     if world_size == 1:
-        init_method(weight)
+        init_method(weight, module=module, name=name)
         if return_master_weight:
             return weight
         return None
@@ -54,7 +54,7 @@ def _initialize_affine_weight(weight, output_size, input_size,
     master_weight = torch.empty(output_size, input_size,
                                 dtype=weight.dtype,
                                 requires_grad=False)
-    init_method(master_weight)
+    init_method(master_weight, module=module, name=name)
 
     # Split and copy
     per_partition_per_stride_size = divide(per_partition_size, stride)
@@ -200,7 +200,7 @@ class ColumnParallelLinear(torch.nn.Module):
    """
    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
        super(ColumnParallelLinear, self).__init__()
 
        # Keep input parameters
@@ -230,7 +230,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.output_size_per_partition, 0, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
@@ -274,7 +274,7 @@ class RowParallelLinear(torch.nn.Module):
    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, module=None, name=None):
        super(RowParallelLinear, self).__init__()
 
        # Keep input parameters
@@ -303,7 +303,7 @@ class RowParallelLinear(torch.nn.Module):
         self.master_weight = _initialize_affine_weight(
             self.weight, self.output_size, self.input_size,
             self.input_size_per_partition, 1, init_method,
-            stride=stride, return_master_weight=keep_master_weight_for_test)
+            stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name)
 
     def forward(self, input_):
         # Set up backprop all-reduce.
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 855e093ada69e1168b565f78f15eeaee313326e5..cab65a4623060d3f362a6edc31d8d99a51e734ee 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -22,11 +22,12 @@ import torch
 import torch.nn.functional as F
 from apex.normalization.fused_layer_norm import FusedLayerNorm
 
+from SwissArmyTransformer import mpu
 from .initialize import get_model_parallel_world_size
 from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
 from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
-from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
+from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
 
 from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
 from .utils import split_tensor_along_last_dim
@@ -62,7 +63,10 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     attention_probs = F.softmax(attention_scores, dim=-1)
 
     if attention_dropout is not None:
-        with get_cuda_rng_tracker().fork():
+        if mpu.get_cuda_rng_tracker is not None:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = attention_dropout(attention_probs)
+        else:
             attention_probs = attention_dropout(attention_probs)
 
     context_layer = torch.matmul(attention_probs, value_layer)
@@ -82,31 +86,36 @@ class SelfAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
         if hidden_size_per_attention_head is None:
             self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
         else:
             self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
         self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
             hidden_size,
-            3 * inner_hidden_size,
+            3 * self.inner_hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="query_key_value"
         )
         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
         self.dense = RowParallelLinear(
-            inner_hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense"
         )
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
@@ -122,7 +131,7 @@ class SelfAttention(torch.nn.Module):
 
     def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, *args, **kw_args, layer_id=self.layer_id)
         else:
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
@@ -144,7 +153,7 @@ class SelfAttention(torch.nn.Module):
         if self.training:
             output = self.output_dropout(output)
 
-        return output, None
+        return output
 
 
 class CrossAttention(torch.nn.Module):
@@ -160,21 +169,22 @@ class CrossAttention(torch.nn.Module):
         self.layer_id = layer_id
         # Per attention head and per partition values.
         world_size = get_model_parallel_world_size()
+        self.hidden_size = hidden_size
         if hidden_size_per_attention_head is None:
             self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
         else:
             self.hidden_size_per_attention_head = hidden_size_per_attention_head
         self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
-        inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
         self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
         # Strided linear layer.
-        self.query = ColumnParallelLinear(hidden_size, inner_hidden_size,
+        self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size,
                                           gather_output=False,
-                                          init_method=init_method, bias=bias)
-        self.key_value = ColumnParallelLinear(hidden_size, 2 * inner_hidden_size,
+                                          init_method=init_method, bias=bias, module=self, name="query")
+        self.key_value = ColumnParallelLinear(hidden_size, 2 * self.inner_hidden_size,
                                               stride=2,
                                               gather_output=False,
-                                              init_method=init_method, bias=bias)
+                                              init_method=init_method, bias=bias, module=self, name="key_value")
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
@@ -182,10 +192,10 @@ class CrossAttention(torch.nn.Module):
 
         # Output.
         self.dense = RowParallelLinear(
-            inner_hidden_size,
+            self.inner_hidden_size,
             hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method, bias=bias)
+            init_method=output_layer_init_method, bias=bias, module=self, name="dense")
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
     def _transpose_for_scores(self, tensor):
@@ -198,7 +208,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -240,22 +250,28 @@ class MLP(torch.nn.Module):
             output_layer_init_method = init_method
         self.hooks = hooks
         # Project to 4h.
+        self.hidden_size = hidden_size
         if inner_hidden_size is None:
             inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
-            hidden_size,
-            inner_hidden_size,
+            self.hidden_size,
+            self.inner_hidden_size,
             gather_output=False,
             init_method=init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense_h_to_4h"
         )
         # Project back to h.
         self.dense_4h_to_h = RowParallelLinear(
-            inner_hidden_size,
-            hidden_size,
+            self.inner_hidden_size,
+            self.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            bias=bias
+            bias=bias,
+            module=self,
+            name="dense_4h_to_h"
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
@@ -353,7 +369,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, **kw_args):
+    def forward(self, hidden_states, mask, encoder_outputs=None, *args, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -362,7 +378,7 @@
         # Layer norm at the begining of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
+        attention_output, *output_this_layer = self.attention(layernorm_output1, mask, *args, **kw_args)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -374,15 +390,14 @@
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
         # only for Encoder-Decoder, omit this for BERT-like or GPT-like models
-        if self.is_decoder and \
-                'encoder_outputs' in kw_args and kw_args['encoder_outputs'] is not None:
-            assert 'cross_attention_mask' in kw_args
-            # Cross attention
-            attention_output = self.cross_attention(layernorm_output, **kw_args)
-            # Residual connection.
-            layernorm_input = layernorm_input + attention_output
-            # Layer norm post the cross attention
-            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+        if self.is_decoder and encoder_outputs is not None:
+            assert 'cross_attention_mask' in kw_args
+            # Cross attention
+            attention_output = self.cross_attention(layernorm_output, encoder_outputs=encoder_outputs, **kw_args)
+            # Residual connection.
+            layernorm_input = layernorm_input + attention_output
+            # Layer norm post the cross attention
+            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
 
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
@@ -394,7 +409,7 @@
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, output_this_layer # temporally, output_this_layer is only from attention
+        return output, *output_this_layer # temporally, output_this_layer is only from attention
 
 
 class BaseTransformer(torch.nn.Module):
@@ -419,6 +434,7 @@ class BaseTransformer(torch.nn.Module):
                  use_bias=True,
                  activation_func=gelu,
                  layernorm=LayerNorm,
+                 init_method=None,
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
@@ -441,8 +457,12 @@ class BaseTransformer(torch.nn.Module):
             torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
         # create all layers
-        self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
-        self.init_method = unscaled_init_method(init_method_std)
+        if init_method is None:
+            self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
+            self.init_method = unscaled_init_method(init_method_std)
+        else:
+            self.output_layer_init_method = init_method
+            self.init_method = init_method
 
         def get_layer(layer_id):
             return BaseTransformerLayer(
@@ -470,7 +490,7 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None,
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
                 output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
@@ -478,7 +498,7 @@ class BaseTransformer(torch.nn.Module):
         if attention_mask is None:
             attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
                 next(self.parameters())
-            ) # None means full attention
+            )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
         assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
@@ -513,23 +533,19 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]
+                    x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
+                    output_this_layer = inputs[3:]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
-                            x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                        if 'layer_forward' in self.hooks:
+                            x_, *output_this_layer = self.hooks['layer_forward'](
+                                x_, mask, encoder_outputs_, *output_this_layer, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            x_, *output_this_layer = layer(x_, mask, encoder_outputs_, *output_this_layer, **kw_args)
                         output_per_layers_part.append(output_this_layer)
-                    return x_, output_per_layers_part
+                    return x_, output_per_layers_part, *output_this_layer
+
                 return custom_forward
 
             # prevent to lose requires_grad in checkpointing.
@@ -539,25 +555,24 @@ class BaseTransformer(torch.nn.Module):
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
+            output_this_layer = []
             while l < num_layers:
-                args = [hidden_states, attention_mask]
-                if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
-                else:
-                    hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
+                args = [hidden_states, attention_mask, encoder_outputs]
+                hidden_states, output_per_layers_part, *output_this_layer = checkpoint(custom(l, l + chunk_length),
+                                                                                       *args, *output_this_layer)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
+            output_this_layer = []
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask]
-                if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
-                elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
+                args = [hidden_states, attention_mask, encoder_outputs]
+                if 'layer_forward' in self.hooks: # customized layer_forward
+                    hidden_states, *output_this_layer = self.hooks['layer_forward'](*args, *output_this_layer,
+                                                                                    layer_id=torch.tensor(i), **kw_args)
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_args)
+                    hidden_states, *output_this_layer = layer(*args, *output_this_layer, **kw_args)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
                 output_per_layers.append(output_this_layer)
diff --git a/SwissArmyTransformer/mpu/utils.py b/SwissArmyTransformer/mpu/utils.py
index c83f501889ccabb33cce33260f7d4fee9eadcab7..2f4f29f08227269634af831b91d2fe9b301b4e9e 100755
--- a/SwissArmyTransformer/mpu/utils.py
+++ b/SwissArmyTransformer/mpu/utils.py
@@ -75,7 +75,7 @@ def sqrt(x):
 
 def unscaled_init_method(sigma):
     """Init method based on N(0, sigma)."""
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
 
     return init_
@@ -83,7 +83,7 @@ def unscaled_init_method(sigma):
 def scaled_init_method(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)
-    def init_(tensor):
+    def init_(tensor, **kwargs):
         return torch.nn.init.normal_(tensor, mean=0.0, std=std)
 
     return init_
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index ffdf5c354cfe6985d1aff9415079ab1c02dcacd7..2279ee051c622f17991f0337206933d2a05c3653 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -67,7 +67,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
     elif args.tokenizer_type.startswith('hf'):
         from .hf_tokenizer import HFT5Tokenizer
         if args.tokenizer_type == "hf_T5Tokenizer":
-            get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type)
+            get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type, cache_dir=args.cache_dir)
     else:
         assert args.vocab_size > 0
         get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
index d67197662cf200acbda1299bce1e9199c1e34862..e40adb47a4ac8cb876132dc5cefc27f824442baf 100644
--- a/SwissArmyTransformer/tokenization/hf_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -2,8 +2,18 @@ from transformers import T5Tokenizer
 from .glm.tokenization import Tokenization, CommandToken
 
+PRETRAINED_VOCAB_FILES_MAP = {
+    "t5-small": "/dataset/fd5061f6/yanan/huggingface_models/t5-small",
+    "t5-base": "/dataset/fd5061f6/yanan/huggingface_models/t5-base",
+    "t5-large": "/mnt/t5",
+    "t5-3b": "/dataset/fd5061f6/yanan/huggingface_models/t5-3b",
+    "t5-11b": "/dataset/fd5061f6/yanan/huggingface_models/t5-11b"
+}
+
 
 class HFTokenizer:
     def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None):
+        if model_type_or_path in PRETRAINED_VOCAB_FILES_MAP:
+            model_type_or_path = PRETRAINED_VOCAB_FILES_MAP[model_type_or_path]
         self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir)
         self.num_tokens = len(self.text_tokenizer)
         self._command_tokens = []
@@ -11,6 +21,9 @@ class HFTokenizer:
         self.command_token_map = {}
         self.command_id_map = {}
 
+    def __len__(self):
+        return len(self.text_tokenizer)
+
     @property
     def command_tokens(self):
         return self._command_tokens
@@ -59,7 +72,9 @@ class HFT5Tokenizer(HFTokenizer):
         command_tokens = [
             CommandToken('eos', '</s>', self.TokenToId("</s>")),
             CommandToken('pad', '<pad>', self.TokenToId("<pad>")),
+            CommandToken('sop', '<pad>', self.TokenToId("<pad>"))
         ]
         for i in range(100):
             command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>')))
         self.command_tokens = command_tokens
+
diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index d9059ac4d1f282e3792b3a877af76d81c8531b7c..beacffc39d8dfce4609c25fcf40c37c1d78e6c6e 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -144,6 +144,8 @@ def get_model(args, model_cls):
 
     if args.fp16:
         model.half()
+    elif args.bf16:
+        model.bfloat16()
     model.cuda(torch.cuda.current_device())
 
     return model
@@ -546,7 +548,8 @@ def initialize_distributed(args):
     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
         set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed
-
+    else:
+        mpu.get_cuda_rng_tracker = None
 
 def set_random_seed(seed):
     """Set random seed for reproducability."""
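
Below is a minimal, self-contained sketch of the init-method calling convention these hunks converge on: _initialize_affine_weight and the parallel linear layers now forward module= and name= to the init callable, which is why unscaled_init_method/scaled_init_method gain **kwargs and why T5Model._init_weights can pick a std per module type and parameter name. ToyMLP, t5_style_init, and the hard-coded std value are illustrative assumptions for this sketch, not code from the repository.

import torch

# Default-style init (mirrors the updated unscaled_init_method): extra keyword
# arguments such as module= and name= are accepted and ignored.
def unscaled_init(sigma):
    def init_(tensor, **kwargs):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
    return init_

# Hypothetical stand-in for an MLP-like module; only the attributes the
# dispatcher below reads are defined.
class ToyMLP(torch.nn.Module):
    def __init__(self, hidden_size, inner_hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.inner_hidden_size = inner_hidden_size

# T5-style init: choose the std from the owning module and parameter name,
# in the spirit of the T5Model._init_weights dispatch in this patch.
def t5_style_init(weight, module=None, name=None):
    std = 0.02  # illustrative; the patch reads this from --init-method-std
    if isinstance(module, ToyMLP):
        if name == "dense_h_to_4h":
            torch.nn.init.normal_(weight, mean=0, std=std * module.hidden_size ** -0.5)
        elif name == "dense_4h_to_h":
            torch.nn.init.normal_(weight, mean=0, std=std * module.inner_hidden_size ** -0.5)
        else:
            raise NotImplementedError(name)
    else:
        raise NotImplementedError(module)

# Both callables satisfy the interface used by _initialize_affine_weight after
# this change: init_method(weight, module=module, name=name).
mlp = ToyMLP(hidden_size=512, inner_hidden_size=2048)
w = torch.empty(2048, 512)
t5_style_init(w, module=mlp, name="dense_h_to_4h")
unscaled_init(0.02)(w, module=mlp, name="dense_h_to_4h")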