diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index 52393e6fb527830aa12eb0e2e63f657f01f98db9..a47795b82cdf8623818f585a463b7645750189ab 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -36,7 +36,7 @@ from .utils import print_rank_0
 from .utils import get_sample_writer
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.data_utils import make_loaders 
+from SwissArmyTransformer.data_utils import make_loaders
 from SwissArmyTransformer.tokenization import get_tokenizer
 
 
@@ -46,23 +46,23 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
         'forward_step': forward_step_function,
         'init_function': init_function,
         'create_dataset_function': create_dataset_function
-    } 
-    
+    }
+
     torch.backends.cuda.matmul.allow_tf32 = False
-    torch.backends.cudnn.enabled = False # Disable CuDNN.
-    timers = Timers() # Timer.
-    
+    torch.backends.cudnn.enabled = False  # Disable CuDNN.
+    timers = Timers()  # Timer.
+
     # Experiment Name
-    if args.load and args.mode == 'pretrain': # continue training
+    if args.load and args.mode == 'pretrain':  # continue training
         args.experiment_name = os.path.basename(os.path.normpath(args.load))
     else:
         args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
-    
+
     # Pytorch distributed. must before seed
     initialize_distributed(args)
-    set_random_seed(args.seed) # Random seeds for reproducability.
+    set_random_seed(args.seed)  # Random seeds for reproducability.
     # init tokenizer
-    get_tokenizer(args) # set args.vocab_size.
+    get_tokenizer(args)  # set args.vocab_size.
     # Data stuff.
     train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'])
 
@@ -80,7 +80,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     if args.save:
         args.save = os.path.join(args.save, args.experiment_name)
     torch.distributed.barrier()
-    
+
     # initialize lr scheduler
     lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
 
@@ -108,7 +108,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     if args.do_valid:
         val_data_iterator = iter(val_data)
     else:
         val_data_iterator = None
-    
+
     # init hook before training
     if hooks['init_function'] is not None:
         hooks['init_function'](args, model, optimizer)
@@ -120,27 +120,29 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     with ExitStack() as stack:
         def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
             save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
+
         iteration, skipped = train(model, optimizer,
-                   lr_scheduler,
-                   train_data_iterator,
-                   val_data_iterator,
-                   timers, args, summary_writer=summary_writer,
-                   hooks=hooks
-                   )
+                                   lr_scheduler,
+                                   train_data_iterator,
+                                   val_data_iterator,
+                                   timers, args, summary_writer=summary_writer,
+                                   hooks=hooks
+                                   )
 
     if args.do_valid:
         prefix = 'the end of training for val data'
         val_loss = evaluate_and_print_results(prefix, val_data_iterator,
-                    model, args, timers, False)
+                                              model, args, timers, False)
     # final save
-    if args.save and iteration != 0: # TODO save
+    if args.save and iteration != 0:  # TODO save
         save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
 
     # final testing
     if args.do_test and test_data is not None:
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, iter(test_data),
-                    model, args, timers, True)
+                                   model, args, timers, True)
+
 
 def get_model(args, model_cls):
     """Build the model."""
@@ -159,12 +161,13 @@ def get_model(args, model_cls):
 
     return model
 
+
 def setup_model_and_optimizer(args, model_cls):
     """Setup model and optimizer."""
 
     model = get_model(args, model_cls)
-    
-    model.disable_untrainable_params() # mark trainable params
+
+    model.disable_untrainable_params()  # mark trainable params
 
     param_groups = get_optimizer_param_groups(model)
 
@@ -187,7 +190,6 @@ def setup_model_and_optimizer(args, model_cls):
 
 
 def get_params_for_weight_decay_optimization(module):
-
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module_ in module.modules():
@@ -210,11 +212,12 @@ def get_params_for_weight_decay_optimization(module):
 
     return weight_decay_params, no_weight_decay_params
 
+
 def get_optimizer_param_groups(model):
     # Build parameter groups (weight decay and non-decay).
     if hasattr(model, 'module'):
         model = model.module
-    param_groups = get_params_for_weight_decay_optimization(model) # TODO move to here
+    param_groups = get_params_for_weight_decay_optimization(model)  # TODO move to here
     # Add model parallel attribute if it is not set.
     for param_group in param_groups:
         for param in param_group['params']:
@@ -222,7 +225,8 @@ def get_optimizer_param_groups(model):
                 param.model_parallel = False
     return param_groups
 
-def get_learning_rate_scheduler(optimizer, iteration, args, 
+
+def get_learning_rate_scheduler(optimizer, iteration, args,
                                 auto_warmup_steps=100, auto_warmup_rate=0.05):
     """Build the learning rate scheduler."""
 
@@ -232,7 +236,7 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
     else:
         num_iters = args.train_iters
     num_iters = max(1, num_iters)
-    init_step = max(iteration-auto_warmup_steps, 0)
+    init_step = max(iteration - auto_warmup_steps, 0)
     if args.mode == 'pretrain' and iteration == 0:
         auto_warmup_steps = 0
     # If init_step <= current_steps <= init_step + auto_warmup_steps,
@@ -240,22 +244,22 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
     # This overrides other rules.
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(optimizer,
-                start_lr=args.lr,
-                warmup_iter=warmup_iter,
-                num_iters=num_iters,
-                decay_style=args.lr_decay_style,
-                last_iter=init_step,
-                decay_ratio=args.lr_decay_ratio,
-                auto_warmup_steps=auto_warmup_steps,
-                auto_warmup_rate=auto_warmup_rate
-                )
+                               start_lr=args.lr,
+                               warmup_iter=warmup_iter,
+                               num_iters=num_iters,
+                               decay_style=args.lr_decay_style,
+                               last_iter=init_step,
+                               decay_ratio=args.lr_decay_ratio,
+                               auto_warmup_steps=auto_warmup_steps,
+                               auto_warmup_rate=auto_warmup_rate
+                               )
 
     return lr_scheduler
 
 
 def train(model, optimizer, lr_scheduler,
-        train_data_iterator, val_data_iterator, timers, args,
-        summary_writer=None, hooks={}):
+          train_data_iterator, val_data_iterator, timers, args,
+          summary_writer=None, hooks={}):
     """Train the model."""
     # Turn on training mode which enables dropout.
     model.train()
@@ -272,10 +276,10 @@ def train(model, optimizer, lr_scheduler,
     while args.iteration < args.train_iters:
 
         lm_loss, skipped_iter, metrics = train_step(train_data_iterator,
-                                        model,
-                                        optimizer,
-                                        lr_scheduler,
-                                        args, timers, hooks=hooks)
+                                                    model,
+                                                    optimizer,
+                                                    lr_scheduler,
+                                                    args, timers, hooks=hooks)
         skipped_iters += skipped_iter
         args.iteration += 1
 
@@ -295,8 +299,8 @@ def train(model, optimizer, lr_scheduler,
 
             elapsed_time = timers('interval time').elapsed()
             report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
-                elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
-                avg_metrics)
+                                     elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
+                                     avg_metrics)
             total_lm_loss = 0.0
             total_metrics = defaultdict(float)
             if report_memory_flag:
@@ -304,8 +308,8 @@ def train(model, optimizer, lr_scheduler,
 
                 report_memory_flag = False
             timers.log(['forward', 'backward', 'allreduce', 'optimizer',
-                    'batch generator', 'data loader'],
-                    normalizer=args.log_interval)
+                        'batch generator', 'data loader'],
+                       normalizer=args.log_interval)
             # Checkpointing
             if args.save and args.save_interval and args.iteration % args.save_interval == 0:
                 save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)
@@ -314,7 +318,8 @@ def train(model, optimizer, lr_scheduler,
         if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
             prefix = 'iteration {}'.format(args.iteration)
             evaluate_and_print_results(
-                prefix, val_data_iterator, model, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)
+                prefix, val_data_iterator, model, args, timers, False, step=args.iteration,
+                summary_writer=summary_writer, hooks=hooks)
 
         if args.exit_interval and args.iteration % args.exit_interval == 0:
             torch.distributed.barrier()
@@ -328,17 +333,17 @@ def train(model, optimizer, lr_scheduler,
 
 
 def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, hooks=None, single_step=False):
+               args, timers, hooks=None, single_step=False, **kwargs):
     """Single training step."""
 
     if hooks is None:
         hooks = {}
     lm_loss_total, metrics_total, count = 0.0, {}, 0
     forward_step = hooks['forward_step']
-    
+
     while True:
         # Forward model for one step.
         timers('forward').start()
-        lm_loss, metrics = forward_step(data_iterator, model, args, timers)
+        lm_loss, metrics = forward_step(data_iterator, model, args, timers, **kwargs)
         timers('forward').stop()
 
         # Check nan or inf in forward, preventing it from interfering loss scaler,
@@ -390,6 +395,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     metrics_total = {key: value / count for key, value in metrics_total.items()}
     return lm_loss_total, skipped_iter, metrics_total
 
+
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
 
@@ -406,6 +412,7 @@ def backward_step(optimizer, model, loss, args, timers):
 
     return
 
+
 def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
     """Evaluation."""
     forward_step = hooks['forward_step']
@@ -436,8 +443,9 @@ def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
     total_lm_loss /= args.eval_iters
     return total_lm_loss
 
+
 def evaluate_and_print_results(prefix, data_iterator, model,
-                            args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
+                               args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
     """Helper function to evaluate and dump results on screen."""
     # import line_profiler
     # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
@@ -452,6 +460,7 @@ def evaluate_and_print_results(prefix, data_iterator, model,
 
     return lm_loss
 
+
 def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, avg_metrics):
     log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time)
@@ -481,7 +490,7 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
     if summary_writer is not None:
         summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
         summary_writer.add_scalar(f'Train/valid_loss', loss, step)
-    
+
 
 '''
 Optional DeepSpeed Activation Checkpointing features
@@ -498,6 +507,7 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
 This must be done before all the calls to mpu.model_parallel_cuda_manual_seed
 '''
 
+
 def set_deepspeed_activation_checkpointing(args):
     deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
     mpu.checkpoint = deepspeed.checkpointing.checkpoint
@@ -535,14 +545,10 @@ def set_random_seed(seed):
         np.random.seed(seed)
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
+        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.enabled = False
         torch.backends.cuda.matmul.allow_tf32 = False
         if hasattr(mpu, 'model_parallel_cuda_manual_seed'):
             mpu.model_parallel_cuda_manual_seed(seed)
-
-
-
-
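Note: apart from whitespace, indentation, and blank-line cleanup, the only behavioral change in this diff is that train_step() now accepts **kwargs and forwards them to the forward_step hook. train() itself still calls train_step() without extra keyword arguments, so the pass-through only matters when train_step() is invoked directly. Below is a minimal sketch (not part of the library) of a forward_step hook compatible with the new call site; the hook name, the (tokens, labels) batch format, and the model call convention are illustrative assumptions.

import torch.nn.functional as F

def my_forward_step(data_iterator, model, args, timers, **kwargs):
    # Must return (loss, metrics); train_step() unpacks exactly these two values.
    timers('batch generator').start()
    tokens, labels = next(data_iterator)  # assumed batch format, dataset-specific
    timers('batch generator').stop()

    logits, *_ = model(tokens)  # assumed model output convention
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

    # Extra keyword arguments passed as train_step(..., **kwargs) arrive here unchanged.
    return loss, {}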