diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 4e98560be2b9e38d9bd9140c4b09552701042616..c9e1c9017782b073546fa12423690b71888a51b0 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,36 +45,41 @@ class BaseModel(torch.nn.Module):
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=True,
+                parallel_output=parallel_output,
                 hooks=self.hooks
             )
-        
+
     def reinit(self): # will be called when loading model
         # if some mixins are loaded, overrides this function
-        for m in self.mixins.values(): 
+        for m in self.mixins.values():
             m.reinit(self.transformer)
-            
+
     def add_mixin(self, name, new_mixin, reinit=False):
         assert name not in self.mixins
         assert isinstance(new_mixin, BaseMixin)
-        
+
         self.mixins[name] = new_mixin # will auto-register parameters
         object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
-        
+
         if reinit:
             new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
         self.collect_hooks_()
-        
+
+    def del_mixin(self, name):  # remove the mixin previously registered under `name`
+        assert name in self.mixins  # mirrors add_mixin's assert-based name check
+        del self.mixins[name]  # drops the submodule from the ModuleDict
+        self.collect_hooks_()  # rebuild self.hooks / self.hook_origins; presumably drops the removed mixin's hooks - confirm
+
     def get_mixin(self, name):
         return self.mixins[name]
-    
+
     def forward(self, *args, **kwargs):
         # update hooks as the current model (overrided forwards)
         # Attention! the transformer might be shared by multiple models
         self.transformer.hooks.clear()
         self.transformer.hooks.update(self.hooks)
         return self.transformer(*args, **kwargs)
-        
+
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
@@ -97,6 +102,6 @@ class BaseModel(torch.nn.Module):
         self.hooks = hooks
         self.hook_origins = hook_origins
         return hooks
-    
+
     def disable_untrainable_params(self):
         pass
\ No newline at end of file
diff --git a/SwissArmyTransformer/model/glm_model.py b/SwissArmyTransformer/model/glm_model.py
index 076ee44c58f526ecf39343f0b127c0c2920e0a17..4ad68bf243bed66f7005dc1cf09d4a858450e851 100644
--- a/SwissArmyTransformer/model/glm_model.py
+++ b/SwissArmyTransformer/model/glm_model.py
@@ -19,8 +19,8 @@ class BlockPositionEmbeddingMixin(BaseMixin):
         return position_embeddings + block_position_embeddings
 
 class GLMModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        super().__init__(args, transformer=transformer)
+    def __init__(self, args, transformer=None, parallel_output=True):
+        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
         self.add_mixin('block_position_embedding', 
             BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
         )
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index c25500d59cb87f5f80640785eac5d3004587cd97..c2e97cd5233a4d9edbcc1ce8cab394d5ea8f9a37 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -406,7 +406,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
             
         if branch_input is not None:
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index fba8453129a24c8003e710e14c598c7d99340c9f..fa93f3fd575bb4049c0b3bbbddd7d66e196d1e2b 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -56,7 +56,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             )
         elif args.tokenizer_type.startswith('glm'):
             kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                      "add_decoder_mask": False}
+                      "add_decoder_mask": args.block_mask_prob > 0.0}
             if args.tokenizer_type == "glm_GPT2BPETokenizer":
                 from .glm import GPT2BPETokenizer
                 get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index a6a6953449a3340148625b7c2688160d7d3ce3b0..6dc457301432a4e2c2c24f377be349df452e6324 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -328,8 +328,11 @@ def train(model, optimizer, lr_scheduler,
 
 
 def train_step(data_iterator, model, optimizer, lr_scheduler,
-               args, timers, hooks={}):
+               args, timers, hooks=None, single_step=False):
     """Single training step."""
+    if hooks is None:
+        hooks = {}
+    lm_loss_total, metrics_total, count = 0.0, {}, 0
     forward_step = hooks['forward_step']
     
     while True:
@@ -354,6 +357,13 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
             print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
             return lm_loss.detach(), 1, metrics
 
+        # Accumulate loss and metrics over sub-steps; they are averaged by `count` before returning
+        lm_loss_total += lm_loss_reduced
+        for name in metrics:
+            if name not in metrics_total:
+                metrics_total[name] = 0.0
+            metrics_total[name] += metrics[name]
+        count += 1
         # Calculate gradients, reduce across processes, and clip.
         timers('backward').start()
         backward_step(optimizer, model, lm_loss, args, timers)
@@ -374,9 +384,11 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         else:
             raise ValueError('Currently, we only support training with deepspeed.')
         timers('optimizer').stop()
-        if complete:
+        if complete or single_step:
             break
-    return lm_loss_reduced, skipped_iter, metrics
+    lm_loss_total /= count
+    metrics_total = {key: value / count for key, value in metrics_total.items()}
+    return lm_loss_total, skipped_iter, metrics_total
 
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
@@ -500,9 +512,9 @@ def initialize_distributed(args):
     torch.cuda.set_device(args.device)
     # Call the init process
     init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
+    args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
+    args.master_port = os.getenv('MASTER_PORT', '6000')
+    init_method += args.master_ip + ':' + args.master_port
     torch.distributed.init_process_group(
         backend=args.distributed_backend,
         world_size=args.world_size, rank=args.rank,
@@ -513,7 +525,7 @@ def initialize_distributed(args):
 
     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
-        set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed
+        set_deepspeed_activation_checkpointing(args)  # TODO manual model-parallel seed
 
 
 def set_random_seed(seed):
diff --git a/SwissArmyTransformer/training/utils.py b/SwissArmyTransformer/training/utils.py
index 066efacd0df40093763db50e25c28370c5f4b60e..79197626ab8d557395023788eecb01ead10de277 100755
--- a/SwissArmyTransformer/training/utils.py
+++ b/SwissArmyTransformer/training/utils.py
@@ -126,8 +126,8 @@ def report_memory(name):
         torch.cuda.memory_allocated() / mega_bytes)
     string += ' | max allocated: {}'.format(
         torch.cuda.max_memory_allocated() / mega_bytes)
-    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
+    string += ' | cached: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
     string += ' | max cached: {}'.format(
-        torch.cuda.memory_reserved() / mega_bytes)
+        torch.cuda.max_memory_reserved() / mega_bytes)
     print_rank_0(string)