diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py index 4e98560be2b9e38d9bd9140c4b09552701042616..c9e1c9017782b073546fa12423690b71888a51b0 100644 --- a/SwissArmyTransformer/model/base_model.py +++ b/SwissArmyTransformer/model/base_model.py @@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module): # ... class BaseModel(torch.nn.Module): - def __init__(self, args, transformer=None): + def __init__(self, args, transformer=None, parallel_output=True): super(BaseModel, self).__init__() self.mixins = torch.nn.ModuleDict() self.collect_hooks_() @@ -45,36 +45,41 @@ class BaseModel(torch.nn.Module): checkpoint_activations=args.checkpoint_activations, checkpoint_num_layers=args.checkpoint_num_layers, sandwich_ln=args.sandwich_ln, - parallel_output=True, + parallel_output=parallel_output, hooks=self.hooks ) - + def reinit(self): # will be called when loading model # if some mixins are loaded, overrides this function - for m in self.mixins.values(): + for m in self.mixins.values(): m.reinit(self.transformer) - + def add_mixin(self, name, new_mixin, reinit=False): assert name not in self.mixins assert isinstance(new_mixin, BaseMixin) - + self.mixins[name] = new_mixin # will auto-register parameters object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr - + if reinit: new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins self.collect_hooks_() - + + def del_mixin(self, name): + assert name in self.mixins + del self.mixins[name] + self.collect_hooks_() + def get_mixin(self, name): return self.mixins[name] - + def forward(self, *args, **kwargs): # update hooks as the current model (overrided forwards) # Attention! the transformer might be shared by multiple models self.transformer.hooks.clear() self.transformer.hooks.update(self.hooks) return self.transformer(*args, **kwargs) - + def collect_hooks_(self): names = ['word_embedding_forward', 'position_embedding_forward', 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward', @@ -97,6 +102,6 @@ class BaseModel(torch.nn.Module): self.hooks = hooks self.hook_origins = hook_origins return hooks - + def disable_untrainable_params(self): pass \ No newline at end of file diff --git a/SwissArmyTransformer/model/glm_model.py b/SwissArmyTransformer/model/glm_model.py index 076ee44c58f526ecf39343f0b127c0c2920e0a17..4ad68bf243bed66f7005dc1cf09d4a858450e851 100644 --- a/SwissArmyTransformer/model/glm_model.py +++ b/SwissArmyTransformer/model/glm_model.py @@ -19,8 +19,8 @@ class BlockPositionEmbeddingMixin(BaseMixin): return position_embeddings + block_position_embeddings class GLMModel(BaseModel): - def __init__(self, args, transformer=None): - super().__init__(args, transformer=transformer) + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) self.add_mixin('block_position_embedding', BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size) ) diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index c25500d59cb87f5f80640785eac5d3004587cd97..c2e97cd5233a4d9edbcc1ce8cab394d5ea8f9a37 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -406,7 +406,7 @@ class BaseTransformer(torch.nn.Module): if branch_input is None and 'branch_final_forward' in self.hooks: branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args) - if self.parallel_output: + if not self.parallel_output: logits_parallel = gather_from_model_parallel_region(logits_parallel) if branch_input is not None: diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py index fba8453129a24c8003e710e14c598c7d99340c9f..fa93f3fd575bb4049c0b3bbbddd7d66e196d1e2b 100644 --- a/SwissArmyTransformer/tokenization/__init__.py +++ b/SwissArmyTransformer/tokenization/__init__.py @@ -56,7 +56,7 @@ def get_tokenizer(args=None, outer_tokenizer=None): ) elif args.tokenizer_type.startswith('glm'): kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask, - "add_decoder_mask": False} + "add_decoder_mask": args.block_mask_prob > 0.0} if args.tokenizer_type == "glm_GPT2BPETokenizer": from .glm import GPT2BPETokenizer get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs) diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index a6a6953449a3340148625b7c2688160d7d3ce3b0..6dc457301432a4e2c2c24f377be349df452e6324 100644 --- a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -328,8 +328,11 @@ def train(model, optimizer, lr_scheduler, def train_step(data_iterator, model, optimizer, lr_scheduler, - args, timers, hooks={}): + args, timers, hooks=None, single_step=False): """Single training step.""" + if hooks is None: + hooks = {} + lm_loss_total, metrics_total, count = 0.0, {}, 0 forward_step = hooks['forward_step'] while True: @@ -354,6 +357,13 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!') return lm_loss.detach(), 1, metrics + # Accumulate the statistics + lm_loss_total += lm_loss_reduced + for name in metrics: + if name not in metrics_total: + metrics_total[name] = 0.0 + metrics_total[name] += metrics[name] + count += 1 # Calculate gradients, reduce across processes, and clip. timers('backward').start() backward_step(optimizer, model, lm_loss, args, timers) @@ -374,9 +384,11 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, else: raise ValueError('Currently, we only support training with deepspeed.') timers('optimizer').stop() - if complete: + if complete or single_step: break - return lm_loss_reduced, skipped_iter, metrics + lm_loss_total /= count + metrics_total = {key: value / count for key, value in metrics_total.items()} + return lm_loss_total, skipped_iter, metrics_total def backward_step(optimizer, model, loss, args, timers): """Backward step.""" @@ -500,9 +512,9 @@ def initialize_distributed(args): torch.cuda.set_device(args.device) # Call the init process init_method = 'tcp://' - master_ip = os.getenv('MASTER_ADDR', 'localhost') - master_port = os.getenv('MASTER_PORT', '6000') - init_method += master_ip + ':' + master_port + args.master_ip = os.getenv('MASTER_ADDR', 'localhost') + args.master_port = os.getenv('MASTER_PORT', '6000') + init_method += args.master_ip + ':' + args.master_port torch.distributed.init_process_group( backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, @@ -513,7 +525,7 @@ def initialize_distributed(args): # Optional DeepSpeed Activation Checkpointing Features if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing: - set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed + set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed def set_random_seed(seed): diff --git a/SwissArmyTransformer/training/utils.py b/SwissArmyTransformer/training/utils.py index 066efacd0df40093763db50e25c28370c5f4b60e..79197626ab8d557395023788eecb01ead10de277 100755 --- a/SwissArmyTransformer/training/utils.py +++ b/SwissArmyTransformer/training/utils.py @@ -126,8 +126,8 @@ def report_memory(name): torch.cuda.memory_allocated() / mega_bytes) string += ' | max allocated: {}'.format( torch.cuda.max_memory_allocated() / mega_bytes) - string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) + string += ' | cached: {}'.format(torch.cuda.memory_reserved() / mega_bytes) string += ' | max cached: {}'.format( - torch.cuda.memory_reserved() / mega_bytes) + torch.cuda.max_memory_reserved() / mega_bytes) print_rank_0(string)