From 826b7d536da9e810f213353615d68ba355997dde Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Sun, 7 Nov 2021 19:40:00 +0800
Subject: [PATCH] Fix parallel_output and add argument

---
 SwissArmyTransformer/model/base_model.py      | 4 ++--
 SwissArmyTransformer/mpu/transformer.py       | 2 +-
 SwissArmyTransformer/tokenization/__init__.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 0b45605..b3fd95f 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,7 +45,7 @@ class BaseModel(torch.nn.Module):
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=True,
+                parallel_output=parallel_output,
                 hooks=self.hooks
             )
 
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 2cf3852..dbd8fe5 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -405,7 +405,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
 
         if branch_input is not None:
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index b54c607..5ee929a 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -55,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             )
     elif args.tokenizer_type.startswith('glm_'):
         kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                  "add_decoder_mask": False}
+                  "add_decoder_mask": args.block_mask_prob > 0.0}
         if args.tokenizer_type == "glm_GPT2BPETokenizer":
            from .glm import GPT2BPETokenizer
            get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
-- 
GitLab
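
Note (not part of the patch): a minimal sketch of how the new parallel_output argument might be used once this change is applied. Only the BaseModel(args, transformer=None, parallel_output=...) signature comes from the diff above; the args object and the training/inference split shown here are illustrative assumptions.

# Minimal usage sketch, assuming the patch above is applied and `args` has
# been produced by the project's usual argument parsing.
from SwissArmyTransformer.model.base_model import BaseModel

# Training: keep the default parallel_output=True, so logits stay split
# across model-parallel ranks (no gather is performed).
train_model = BaseModel(args)

# Inference/generation: pass parallel_output=False, so the transformer takes
# the `if not self.parallel_output:` branch and gathers the full-vocabulary
# logits on every model-parallel rank.
infer_model = BaseModel(args, parallel_output=False)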