Commit 3d55652b authored by Ming Ding

fix glm mask bug

parent f8d2632a
@@ -32,6 +32,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
     attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., :context_length] = 1
+    attention_mask.unsqueeze_(1)
     position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
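
For context, a minimal standalone sketch of the mask this function builds after the fix; the sizes are illustrative assumptions, not values from the repository. The mask is lower-triangular (causal) except that the first context_length columns are fully visible, and the added unsqueeze_(1) gives it an explicit head axis so it lines up with [batch, num_heads, query, key] attention scores.

import torch

seq_len, context_length = 6, 3               # illustrative sizes only

attention_mask = torch.ones((1, seq_len, seq_len))
attention_mask.tril_()                       # token i may attend to tokens <= i
attention_mask[..., :context_length] = 1     # the prefix (context) is visible to every token
attention_mask.unsqueeze_(1)                 # [1, seq, seq] -> [1, 1, seq, seq]: head axis added by this commit

print(attention_mask.shape)                  # torch.Size([1, 1, 6, 6])
print(attention_mask[0, 0])                  # ones in the first 3 columns, lower triangle elsewhere
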
@@ -71,7 +72,7 @@ def main(args):
         seq = [tokenizer.get_command('ENC').Id] + seq
         if not raw_text.endswith('MASK]'):
             seq = seq + [tokenizer.get_command('eos').Id]
-        print('raw text: ', raw_text)
+        print('raw text: {}\n'.format(raw_text))
         if len(seq) > args.max_sequence_length:
             raise ValueError('text too long.')
@@ -112,7 +113,7 @@ def main(args):
                 except ValueError:
                     unfinished = len(output)
                 bog = output.index(tokenizer.get_command('sop').Id)
-                output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
+                output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
             # prepare the next auto-regressive generation
             if mp_idx < len(mask_positions) - 1:
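
A rough, self-contained illustration of the splice on the changed line; the token ids are made up, only the slicing mirrors the fixed code. output holds the original sequence with its [MASK] slot, followed by a <sop> token and the freshly generated span, so the assignment moves the generated span (starting at bog, the index of <sop>) into the [MASK] position and keeps the remaining original tokens after it.

# Illustrative tokens only: [MASK] at index 2, <sop> at index 6, padding (-1) afterwards.
output = ['tok0', 'tok1', '[MASK]', 'tok3', 'tok4', '<eos>', '<sop>', 'gen0', 'gen1', -1]
mask_position = 2
bog = output.index('<sop>')                  # beginning of the generated span
try:
    unfinished = output.index(-1)            # first padding slot
except ValueError:
    unfinished = len(output)

# Same splice as the fixed line above.
spliced = output[:mask_position] + output[bog:unfinished] + output[mask_position + 1:bog]
print(spliced)
# ['tok0', 'tok1', '<sop>', 'gen0', 'gen1', 'tok3', 'tok4', '<eos>']
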
@@ -203,6 +203,12 @@ def get_params_for_weight_decay_optimization(module):
             no_weight_decay_params['params'].extend(
                 [p for n, p in list(module_._parameters.items())
                  if p is not None and n == 'bias' and p.requires_grad])
+    if len(weight_decay_params['params']) == 0:
+        return tuple(no_weight_decay_params)
+    elif len(no_weight_decay_params['params']) == 0:
+        return tuple(weight_decay_params)
     return weight_decay_params, no_weight_decay_params
 
 def get_optimizer_param_groups(model):
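
A condensed sketch of the grouping pattern this hunk completes, so the new early returns are easier to read in isolation. The LayerNorm special case and the helper name are assumptions modelled on the usual Megatron-style grouping, not copied from this file; note also that tuple(d) on a dict yields its keys, so the sketch returns one-element tuple literals to keep the result a tuple of parameter-group dicts.

import torch

def params_for_weight_decay(module):         # hypothetical helper name for this sketch
    """Split parameters into weight-decay and no-weight-decay groups (sketch)."""
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, torch.nn.LayerNorm):
            # LayerNorm weights and biases are exempt from weight decay.
            no_weight_decay_params['params'].extend(
                p for p in module_.parameters(recurse=False) if p.requires_grad)
        else:
            for n, p in module_.named_parameters(recurse=False):
                if not p.requires_grad:
                    continue
                group = no_weight_decay_params if n == 'bias' else weight_decay_params
                group['params'].append(p)
    # Return only the non-empty group(s) so the caller can always iterate the result.
    if len(weight_decay_params['params']) == 0:
        return (no_weight_decay_params,)
    elif len(no_weight_decay_params['params']) == 0:
        return (weight_decay_params,)
    return weight_decay_params, no_weight_decay_params

# Usage sketch: the groups feed straight into an optimizer.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
optimizer = torch.optim.AdamW(params_for_weight_decay(model), lr=1e-3, weight_decay=0.01)
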