Commit 826b7d53 authored by Zhengxiao Du

Fix parallel_output and add argument

parent aff0493f
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,7 +45,7 @@ class BaseModel(torch.nn.Module):
             checkpoint_activations=args.checkpoint_activations,
             checkpoint_num_layers=args.checkpoint_num_layers,
             sandwich_ln=args.sandwich_ln,
-            parallel_output=True,
+            parallel_output=parallel_output,
             hooks=self.hooks
         )
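The two hunks above thread a caller-facing flag down to the transformer instead of hard-coding it at construction time. A minimal standalone sketch of the same pattern (the class and attribute names below are illustrative stand-ins, not the library's API):

class InnerTransformer:
    def __init__(self, parallel_output=True):
        # Before the fix, the caller always passed True here.
        self.parallel_output = parallel_output

class Model:
    def __init__(self, args, transformer=None, parallel_output=True):
        # The new constructor argument is forwarded verbatim, so a caller
        # can opt out of split (model-parallel) logits at construction time.
        self.transformer = transformer or InnerTransformer(parallel_output=parallel_output)

model = Model(args=None, parallel_output=False)
assert model.transformer.parallel_output is False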
@@ -405,7 +405,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
         if branch_input is not None:
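Under model parallelism the final logits are split along the vocabulary dimension, one shard per model-parallel rank: parallel_output=True keeps them split (e.g. for a model-parallel cross-entropy), while parallel_output=False gathers the full vocabulary onto every rank. The old condition gathered in exactly the wrong case. A single-process sketch of the corrected branch, with lists standing in for per-rank shards and concatenation standing in for the collective gather:

def finalize_logits(local_shard, all_shards, parallel_output):
    # In the real code this branch calls gather_from_model_parallel_region,
    # a collective over the model-parallel group; here the gather is
    # simulated by concatenating the per-rank shards.
    if not parallel_output:  # corrected condition: gather only when asked
        return [x for shard in all_shards for x in shard]
    return local_shard       # stay split across ranks

shards = [[0.1, 0.2], [0.3, 0.4]]  # vocab split across 2 "ranks"
assert finalize_logits(shards[0], shards, parallel_output=False) == [0.1, 0.2, 0.3, 0.4]
assert finalize_logits(shards[0], shards, parallel_output=True) == [0.1, 0.2]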
@@ -55,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
         )
     elif args.tokenizer_type.startswith('glm_'):
         kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                  "add_decoder_mask": False}
+                  "add_decoder_mask": args.block_mask_prob > 0.0}
         if args.tokenizer_type == "glm_GPT2BPETokenizer":
             from .glm import GPT2BPETokenizer
             get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
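The tokenizer hunk derives add_decoder_mask from whether block masking is enabled (args.block_mask_prob > 0.0) rather than hard-coding it to False. A small sketch of the derived flag, using a stand-in namespace for the parsed arguments (not the project's actual argument parser):

from types import SimpleNamespace

def glm_tokenizer_kwargs(args):
    return {"add_block_symbols": True,
            "add_task_mask": args.task_mask,
            # Decoder-mask tokens are added exactly when block masking is on.
            "add_decoder_mask": args.block_mask_prob > 0.0}

args = SimpleNamespace(task_mask=True, block_mask_prob=0.1)
assert glm_tokenizer_kwargs(args)["add_decoder_mask"] is True
assert glm_tokenizer_kwargs(SimpleNamespace(task_mask=True, block_mask_prob=0.0))["add_decoder_mask"] is False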