From 3d55652ba3b8e9d21f4cdeb60fef4fff5b783bcb Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 28 Oct 2021 09:12:20 +0000
Subject: [PATCH] fix glm mask bug

---
 inference_glm.py               | 5 +++--
 training/deepspeed_training.py | 6 ++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/inference_glm.py b/inference_glm.py
index 0e5e44d..a295e41 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -32,6 +32,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 
     attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
     attention_mask.tril_()
+    attention_mask[..., :context_length] = 1
     attention_mask.unsqueeze_(1)
 
     position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
@@ -71,7 +72,7 @@ def main(args):
         seq = [tokenizer.get_command('ENC').Id] + seq
         if not raw_text.endswith('MASK]'):
             seq = seq + [tokenizer.get_command('eos').Id]
-        print('raw text: ', raw_text)
+        print('raw text: {}\n'.format(raw_text))
         if len(seq) > args.max_sequence_length:
             raise ValueError('text too long.')
 
@@ -112,7 +113,7 @@ def main(args):
             except ValueError:
                 unfinished = len(output)
             bog = output.index(tokenizer.get_command('sop').Id)
-            output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
+            output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
 
             # prepare the next auto-regressive generation
             if mp_idx < len(mask_positions) - 1:
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 10538ac..e59c0bf 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -203,6 +203,12 @@ def get_params_for_weight_decay_optimization(module):
             no_weight_decay_params['params'].extend(
                 [p for n, p in list(module_._parameters.items())
                  if p is not None and n == 'bias' and p.requires_grad])
+
+    if len(weight_decay_params['params']) == 0:
+        return tuple(no_weight_decay_params)
+    elif len(no_weight_decay_params['params']) == 0:
+        return tuple(weight_decay_params)
+
     return weight_decay_params, no_weight_decay_params
 
 def get_optimizer_param_groups(model):
--
GitLab
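
Note (not part of the patch): a minimal, self-contained sketch of the mask construction that the first hunk fixes, using a hypothetical helper name build_glm_attention_mask. The added line attention_mask[..., :context_length] = 1 re-opens the first context_length columns of the causal mask, so every position attends bidirectionally to the prompt while generated tokens remain causal.

import torch

def build_glm_attention_mask(seq_len, context_length):
    # Start from a causal (lower-triangular) mask over the whole sequence.
    attention_mask = torch.ones((1, seq_len, seq_len))
    attention_mask.tril_()
    # The fix: make the prompt (first `context_length` positions) fully visible.
    attention_mask[..., :context_length] = 1
    # Add a head dimension -> shape (1, 1, seq_len, seq_len).
    attention_mask.unsqueeze_(1)
    return attention_mask

if __name__ == '__main__':
    # With seq_len=6 and context_length=3, columns 0-2 are all ones and
    # columns 3-5 stay lower-triangular.
    print(build_glm_attention_mask(6, 3)[0, 0])

The change in training/deepspeed_training.py is independent of the mask fix: when either parameter group is empty, only the non-empty group is returned, apparently so that an empty group is never handed to the optimizer.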