diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index 52393e6fb527830aa12eb0e2e63f657f01f98db9..a47795b82cdf8623818f585a463b7645750189ab 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -36,7 +36,7 @@ from .utils import print_rank_0
 from .utils import get_sample_writer
 
 from SwissArmyTransformer import mpu
-from SwissArmyTransformer.data_utils import make_loaders 
+from SwissArmyTransformer.data_utils import make_loaders
 from SwissArmyTransformer.tokenization import get_tokenizer
 
 
@@ -46,23 +46,23 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
         'forward_step': forward_step_function,
         'init_function': init_function,
         'create_dataset_function': create_dataset_function
-        }
-    
+    }
+
     torch.backends.cuda.matmul.allow_tf32 = False
-    torch.backends.cudnn.enabled = False # Disable CuDNN.
-    timers = Timers() # Timer.
-    
+    torch.backends.cudnn.enabled = False  # Disable CuDNN.
+    timers = Timers()  # Timer.
+
     # Experiment Name
-    if args.load and args.mode == 'pretrain': # continue training
+    if args.load and args.mode == 'pretrain':  # continue training
         args.experiment_name = os.path.basename(os.path.normpath(args.load))
     else:
         args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
-        
+
     # Pytorch distributed. must before seed
     initialize_distributed(args)
-    set_random_seed(args.seed) # Random seeds for reproducability.
+    set_random_seed(args.seed)  # Random seeds for reproducibility.
     # init tokenizer
-    get_tokenizer(args) # set args.vocab_size.
+    get_tokenizer(args)  # set args.vocab_size.
     # Data stuff.
     train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'])
 
@@ -80,7 +80,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     if args.save:
         args.save = os.path.join(args.save, args.experiment_name)
     torch.distributed.barrier()
-    
+
     # initialize lr scheduler
     lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
 
@@ -108,7 +108,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
         val_data_iterator = iter(val_data)
     else:
         val_data_iterator = None
-        
+
     # init hook before training
     if hooks['init_function'] is not None:
         hooks['init_function'](args, model, optimizer)
@@ -120,27 +120,29 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
             with ExitStack() as stack:
                 def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
                     save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
+
                 iteration, skipped = train(model, optimizer,
-                    lr_scheduler,
-                    train_data_iterator,
-                    val_data_iterator,
-                    timers, args, summary_writer=summary_writer,
-                    hooks=hooks
-                    )
+                                           lr_scheduler,
+                                           train_data_iterator,
+                                           val_data_iterator,
+                                           timers, args, summary_writer=summary_writer,
+                                           hooks=hooks
+                                           )
         if args.do_valid:
             prefix = 'the end of training for val data'
             val_loss = evaluate_and_print_results(prefix, val_data_iterator,
-                model, args, timers, False)
+                                                  model, args, timers, False)
 
     # final save
-    if args.save and iteration != 0: # TODO save
+    if args.save and iteration != 0:  # TODO save
         save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
 
     # final testing
     if args.do_test and test_data is not None:
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, iter(test_data),
-            model, args, timers, True)
+                                   model, args, timers, True)
+
 
 def get_model(args, model_cls):
     """Build the model."""
@@ -159,12 +161,13 @@ def get_model(args, model_cls):
 
     return model
 
+
 def setup_model_and_optimizer(args, model_cls):
     """Setup model and optimizer."""
 
     model = get_model(args, model_cls)
-    
-    model.disable_untrainable_params() # mark trainable params
+
+    model.disable_untrainable_params()  # mark trainable params
 
     param_groups = get_optimizer_param_groups(model)
 
@@ -187,7 +190,6 @@ def setup_model_and_optimizer(args, model_cls):
 
 
 def get_params_for_weight_decay_optimization(module):
-    
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module_ in module.modules():
@@ -210,11 +212,12 @@ def get_params_for_weight_decay_optimization(module):
 
     return weight_decay_params, no_weight_decay_params
 
+
 def get_optimizer_param_groups(model):
     # Build parameter groups (weight decay and non-decay).
     if hasattr(model, 'module'):
         model = model.module
-    param_groups = get_params_for_weight_decay_optimization(model) # TODO move to here
+    param_groups = get_params_for_weight_decay_optimization(model)  # TODO: move here
     # Add model parallel attribute if it is not set.
     for param_group in param_groups:
         for param in param_group['params']:
@@ -222,7 +225,8 @@ def get_optimizer_param_groups(model):
                 param.model_parallel = False
     return param_groups
 
-def get_learning_rate_scheduler(optimizer, iteration, args, 
+
+def get_learning_rate_scheduler(optimizer, iteration, args,
                                 auto_warmup_steps=100, auto_warmup_rate=0.05):
     """Build the learning rate scheduler."""
 
@@ -232,7 +236,7 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
     else:
         num_iters = args.train_iters
     num_iters = max(1, num_iters)
-    init_step = max(iteration-auto_warmup_steps, 0)
+    init_step = max(iteration - auto_warmup_steps, 0)
     if args.mode == 'pretrain' and iteration == 0:
         auto_warmup_steps = 0
     # If init_step <= current_steps <= init_step + auto_warmup_steps,
@@ -240,22 +244,22 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
     # This overrides other rules.
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(optimizer,
-        start_lr=args.lr,
-        warmup_iter=warmup_iter,
-        num_iters=num_iters,
-        decay_style=args.lr_decay_style,
-        last_iter=init_step,
-        decay_ratio=args.lr_decay_ratio,
-        auto_warmup_steps=auto_warmup_steps,
-        auto_warmup_rate=auto_warmup_rate
-        )
+                               start_lr=args.lr,
+                               warmup_iter=warmup_iter,
+                               num_iters=num_iters,
+                               decay_style=args.lr_decay_style,
+                               last_iter=init_step,
+                               decay_ratio=args.lr_decay_ratio,
+                               auto_warmup_steps=auto_warmup_steps,
+                               auto_warmup_rate=auto_warmup_rate
+                               )
 
     return lr_scheduler
 
 
 def train(model, optimizer, lr_scheduler,
-        train_data_iterator, val_data_iterator, timers, args, 
-        summary_writer=None, hooks={}):
+          train_data_iterator, val_data_iterator, timers, args,
+          summary_writer=None, hooks={}):
     """Train the model."""
     # Turn on training mode which enables dropout.
     model.train()
@@ -272,10 +276,10 @@ def train(model, optimizer, lr_scheduler,
     while args.iteration < args.train_iters:
 
         lm_loss, skipped_iter, metrics = train_step(train_data_iterator,
-                                        model,
-                                        optimizer,
-                                        lr_scheduler,
-                                        args, timers, hooks=hooks)
+                                                    model,
+                                                    optimizer,
+                                                    lr_scheduler,
+                                                    args, timers, hooks=hooks)
         skipped_iters += skipped_iter
         args.iteration += 1
 
@@ -295,8 +299,8 @@ def train(model, optimizer, lr_scheduler,
 
             elapsed_time = timers('interval time').elapsed()
             report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
-                                    elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
-                                    avg_metrics)
+                                     elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
+                                     avg_metrics)
             total_lm_loss = 0.0
             total_metrics = defaultdict(float)
             if report_memory_flag:
@@ -304,8 +308,8 @@ def train(model, optimizer, lr_scheduler,
                 report_memory_flag = False
 
             timers.log(['forward', 'backward', 'allreduce', 'optimizer',
-                            'batch generator', 'data loader'],
-                        normalizer=args.log_interval)
+                        'batch generator', 'data loader'],
+                       normalizer=args.log_interval)
         # Checkpointing
         if args.save and args.save_interval and args.iteration % args.save_interval == 0:
             save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)
@@ -314,7 +318,8 @@ def train(model, optimizer, lr_scheduler,
         if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
             prefix = 'iteration {}'.format(args.iteration)
             evaluate_and_print_results(
-                prefix, val_data_iterator, model, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)
+                prefix, val_data_iterator, model, args, timers, False, step=args.iteration,
+                summary_writer=summary_writer, hooks=hooks)
 
         if args.exit_interval and args.iteration % args.exit_interval == 0:
             torch.distributed.barrier()
@@ -328,17 +333,17 @@ def train(model, optimizer, lr_scheduler,
 
 
 def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, hooks=None, single_step=False):
+               args, timers, hooks=None, single_step=False, **kwargs):
     """Single training step."""
     if hooks is None:
         hooks = {}
     lm_loss_total, metrics_total, count = 0.0, {}, 0
     forward_step = hooks['forward_step']
-    
+
     while True:
         # Forward model for one step.
         timers('forward').start()
-        lm_loss, metrics = forward_step(data_iterator, model, args, timers)
+        lm_loss, metrics = forward_step(data_iterator, model, args, timers, **kwargs)
         timers('forward').stop()
 
         # Check nan or inf in forward, preventing it from interfering loss scaler,
@@ -390,6 +395,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
     metrics_total = {key: value / count for key, value in metrics_total.items()}
     return lm_loss_total, skipped_iter, metrics_total
 
+
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
 
@@ -406,6 +412,7 @@ def backward_step(optimizer, model, loss, args, timers):
 
     return
 
+
 def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
     """Evaluation."""
     forward_step = hooks['forward_step']
@@ -436,8 +443,9 @@ def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
     total_lm_loss /= args.eval_iters
     return total_lm_loss
 
+
 def evaluate_and_print_results(prefix, data_iterator, model,
-                            args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
+                               args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
     """Helper function to evaluate and dump results on screen."""
     # import line_profiler
     # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
@@ -452,6 +460,7 @@ def evaluate_and_print_results(prefix, data_iterator, model,
 
     return lm_loss
 
+
 def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, avg_metrics):
     log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time)
@@ -481,7 +490,7 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
     if summary_writer is not None:
         summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
         summary_writer.add_scalar(f'Train/valid_loss', loss, step)
-        
+
 
 '''
     Optional DeepSpeed Activation Checkpointing features
@@ -498,6 +507,7 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
     This must be done before all the calls to mpu.model_parallel_cuda_manual_seed
     '''
 
+
 def set_deepspeed_activation_checkpointing(args):
     deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
     mpu.checkpoint = deepspeed.checkpointing.checkpoint
@@ -535,14 +545,10 @@ def set_random_seed(seed):
         np.random.seed(seed)
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
+        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.enabled = False
         torch.backends.cuda.matmul.allow_tf32 = False
         if hasattr(mpu, 'model_parallel_cuda_manual_seed'):
             mpu.model_parallel_cuda_manual_seed(seed)
-        
-    
-
-
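
Note: the only behavioural change in this patch is the new `**kwargs` parameter on `train_step`, which is passed straight through to the user-supplied `forward_step` hook; everything else is PEP 8 whitespace and indentation cleanup. Below is a minimal, standalone sketch of the calling pattern this enables. It is not SwissArmyTransformer code: `my_forward_step`, `train_step_like`, `loss_scale`, and the toy stand-ins are illustrative assumptions that only mirror the signatures visible in the diff.

```python
# Standalone sketch of the **kwargs forwarding added in this patch.
# Only the signatures mirror the diff; everything else is a toy stand-in.

def my_forward_step(data_iterator, model, args, timers, **kwargs):
    """Hook keeps the old positional signature; extras arrive via kwargs."""
    scale = kwargs.get('loss_scale', 1.0)  # hypothetical extra argument
    batch = next(data_iterator)
    lm_loss = model(batch) * scale
    return lm_loss, {}  # (loss, metrics dict), matching the hook contract


def train_step_like(data_iterator, model, optimizer, lr_scheduler,
                    args, timers, hooks=None, single_step=False, **kwargs):
    """Mirrors only the new train_step signature; forwards **kwargs to the hook."""
    hooks = hooks or {}
    forward_step = hooks['forward_step']
    lm_loss, metrics = forward_step(data_iterator, model, args, timers, **kwargs)
    return lm_loss, 0, metrics


if __name__ == '__main__':
    # Toy stand-ins so the sketch runs without torch or DeepSpeed.
    data = iter([2.0])
    model = lambda x: 0.5 * x
    loss, skipped, metrics = train_step_like(
        data, model, optimizer=None, lr_scheduler=None, args=None, timers=None,
        hooks={'forward_step': my_forward_step},
        loss_scale=0.1)  # extra keyword argument reaches my_forward_step
    print(loss)  # 0.1
```

Everything the real `train_step` does besides the forwarding (timers, NaN checks, loss scaling, gradient accumulation) is deliberately omitted from this sketch.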