Commit 826b7d53 authored by Zhengxiao Du

Fix parallel_output and add argument

parent aff0493f
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
     # ...
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, args, transformer=None):
+    def __init__(self, args, transformer=None, parallel_output=True):
         super(BaseModel, self).__init__()
         self.mixins = torch.nn.ModuleDict()
         self.collect_hooks_()
@@ -45,7 +45,7 @@ class BaseModel(torch.nn.Module):
                 checkpoint_activations=args.checkpoint_activations,
                 checkpoint_num_layers=args.checkpoint_num_layers,
                 sandwich_ln=args.sandwich_ln,
-                parallel_output=True,
+                parallel_output=parallel_output,
                 hooks=self.hooks
             )
 
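With this change, output gathering becomes a constructor option instead of being hard-coded. A minimal usage sketch follows; the import path and the subclass are assumptions for illustration, but the constructor signature matches the diff above.

    from SwissArmyTransformer.model import BaseModel  # assumed import path

    class MyModel(BaseModel):
        def __init__(self, args):
            # parallel_output=True (the default) keeps logits split across
            # model-parallel ranks; pass False to receive full-vocabulary logits.
            super().__init__(args, parallel_output=False)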
@@ -405,7 +405,7 @@ class BaseTransformer(torch.nn.Module):
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
-        if self.parallel_output:
+        if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
 
         if branch_input is not None:
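The corrected condition gathers the vocab-parallel logits only when parallel_output is False, so the flag now selects between per-rank logit shards (True, the default) and full-vocabulary logits (False). Below is a toy sketch of that behaviour; it is not SAT code, and gather_from_model_parallel_region is replaced by a plain concatenation across pretend ranks.

    import torch

    def finalize_logits(local_shard, all_shards, parallel_output=True):
        # Mirrors the fixed check: gather only when parallel_output is False.
        if not parallel_output:
            # Stand-in for gather_from_model_parallel_region: concatenate the
            # per-rank vocabulary shards into full-vocabulary logits.
            return torch.cat(all_shards, dim=-1)
        return local_shard  # keep only this rank's slice of the vocabulary

    # 4 pretend ranks, each holding a 100-token slice of the vocabulary
    shards = [torch.randn(2, 8, 100) for _ in range(4)]
    full = finalize_logits(shards[0], shards, parallel_output=False)  # shape (2, 8, 400)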
@@ -55,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
         )
     elif args.tokenizer_type.startswith('glm_'):
         kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
-                  "add_decoder_mask": False}
+                  "add_decoder_mask": args.block_mask_prob > 0.0}
         if args.tokenizer_type == "glm_GPT2BPETokenizer":
            from .glm import GPT2BPETokenizer
            get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
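For the glm_ tokenizers, add_decoder_mask is now derived from args.block_mask_prob instead of being hard-coded to False. A small sketch of the kwargs this produces; the attribute names come from the diff, while the values are illustrative only.

    from argparse import Namespace

    args = Namespace(task_mask=True, block_mask_prob=0.1)  # illustrative values

    kwargs = {"add_block_symbols": True,
              "add_task_mask": args.task_mask,
              "add_decoder_mask": args.block_mask_prob > 0.0}
    print(kwargs["add_decoder_mask"])  # True, because block_mask_prob > 0.0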