diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 0b45605ae61b38de1ddc003f114a0d722f23e7c2..b3fd95f3742ec327d98027d2d9e4ce05c30899c8 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,7 +45,7 @@ class BaseModel(torch.nn.Module):
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=True,
+                parallel_output=parallel_output,
                 hooks=self.hooks
             )
 
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 2cf38521f49b7f8ee207863fb90926c06e622ba2..dbd8fe5fbcdc9871a962ff43920a2797e8d73bf0 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -405,7 +405,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
 
         if branch_input is not None:
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index b54c6072c82cab8e45fff74e0cc1303fedc48427..5ee929afa5d2e6445edfa3a948f80098f6ec6c1f 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -55,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
         )
     elif args.tokenizer_type.startswith('glm_'):
         kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                  "add_decoder_mask": False}
+                  "add_decoder_mask": args.block_mask_prob > 0.0}
         if args.tokenizer_type == "glm_GPT2BPETokenizer":
             from .glm import GPT2BPETokenizer
             get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
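
Note on the mpu/transformer.py hunk: parallel_output=True means the logits stay sharded along the vocabulary dimension across model-parallel ranks (for a parallel cross-entropy), so the gather must fire only when the caller asks for non-parallel, full-vocabulary output -- the old condition was inverted. Below is a minimal single-process sketch of that semantics; finalize_logits and the torch.cat stand-in are illustrative names, not SwissArmyTransformer APIs.

import torch

def finalize_logits(shards, parallel_output):
    # Stand-in for BaseTransformer's final step: each "rank" holds one shard
    # of the vocabulary dimension; torch.cat plays the role of
    # gather_from_model_parallel_region in this single-process sketch.
    if not parallel_output:
        return torch.cat(shards, dim=-1)  # full-vocab logits for the caller
    return shards[0]                      # stay sharded for parallel loss

# two "ranks", each holding half of a 10-token vocabulary
shard0, shard1 = torch.randn(2, 4, 5), torch.randn(2, 4, 5)
full = finalize_logits([shard0, shard1], parallel_output=False)
assert full.shape == (2, 4, 10)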