diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 0b45605ae61b38de1ddc003f114a0d722f23e7c2..b3fd95f3742ec327d98027d2d9e4ce05c30899c8 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,7 +45,7 @@ class BaseModel(torch.nn.Module):
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=True,
+                parallel_output=parallel_output,
                 hooks=self.hooks
             )
         
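The hunk above exposes parallel_output as a constructor argument of BaseModel and threads it through to BaseTransformer instead of hard-coding True. A minimal usage sketch, not part of the patch; `args` stands for the argparse namespace SwissArmyTransformer normally builds, and everything outside the hunk is illustrative:

    from SwissArmyTransformer.model.base_model import BaseModel

    # args: the usual SwissArmyTransformer argparse namespace (num_layers, hidden_size, ...)
    model = BaseModel(args, parallel_output=False)  # gather full-vocab logits on every rank
    model = BaseModel(args)                         # omitting it keeps the old default, True
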
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 2cf38521f49b7f8ee207863fb90926c06e622ba2..dbd8fe5fbcdc9871a962ff43920a2797e8d73bf0 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -405,7 +405,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
             
         if branch_input is not None:
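This hunk inverts the gather condition: with parallel_output=True the transformer now returns per-rank (vocabulary-sharded) logits, and only parallel_output=False triggers gather_from_model_parallel_region, matching the Megatron convention. A hedged sketch of what that means for a downstream loss, assuming mpu exposes the Megatron-style vocab_parallel_cross_entropy and that the model's first return value is the logits; names outside the hunk are assumptions:

    import torch.nn.functional as F
    from SwissArmyTransformer import mpu

    logits, *_ = model(input_ids, position_ids, attention_mask)
    if parallel_output:
        # logits hold only this rank's vocabulary shard; use the model-parallel loss
        losses = mpu.vocab_parallel_cross_entropy(logits.float(), labels)
    else:
        # logits were gathered to the full vocabulary, so plain cross-entropy works
        losses = F.cross_entropy(logits.view(-1, logits.size(-1)).float(), labels.view(-1))
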
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index b54c6072c82cab8e45fff74e0cc1303fedc48427..5ee929afa5d2e6445edfa3a948f80098f6ec6c1f 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -55,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             )
         elif args.tokenizer_type.startswith('glm_'):
             kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                      "add_decoder_mask": False}
+                      "add_decoder_mask": args.block_mask_prob > 0.0}
             if args.tokenizer_type == "glm_GPT2BPETokenizer":
                 from .glm import GPT2BPETokenizer
                 get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
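
The tokenizer hunk ties add_decoder_mask for the glm_* tokenizers to whether block masking is actually enabled (args.block_mask_prob > 0.0) instead of always passing False. A small sketch of the intended effect; the attribute values below are illustrative, not prescribed:

    from SwissArmyTransformer.tokenization import get_tokenizer

    args.tokenizer_type = 'glm_GPT2BPETokenizer'
    args.tokenizer_model_type = 'gpt2'   # illustrative model type
    args.task_mask = True
    args.block_mask_prob = 0.1           # > 0.0, so the decoder-mask token is added
    tokenizer = get_tokenizer(args)      # before the patch this was always add_decoder_mask=False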