From 826b7d536da9e810f213353615d68ba355997dde Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Sun, 7 Nov 2021 19:40:00 +0800
Subject: [PATCH] Fix parallel_output and add argument

---
 SwissArmyTransformer/model/base_model.py      | 4 ++--
 SwissArmyTransformer/mpu/transformer.py       | 2 +-
 SwissArmyTransformer/tokenization/__init__.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 0b45605..b3fd95f 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,7 +45,7 @@ class BaseModel(torch.nn.Module):
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=True,
+                parallel_output=parallel_output,
                 hooks=self.hooks
             )
         
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 2cf3852..dbd8fe5 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -405,7 +405,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
             
         if branch_input is not None:
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index b54c607..5ee929a 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -55,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             )
         elif args.tokenizer_type.startswith('glm_'):
             kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                      "add_decoder_mask": False}
+                      "add_decoder_mask": args.block_mask_prob > 0.0}
             if args.tokenizer_type == "glm_GPT2BPETokenizer":
                 from .glm import GPT2BPETokenizer
                 get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
-- 
GitLab