Commit 3d55652b authored by Ming Ding

fix glm mask bug

parent f8d2632a
@@ -32,6 +32,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
     attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
     attention_mask.tril_()
     attention_mask[..., :context_length] = 1
+    attention_mask.unsqueeze_(1)
     position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
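
Below is a minimal, self-contained sketch of the mask this hunk now builds; the helper name and the (1, 1, seq, seq) broadcasting shape are illustrative assumptions, not code from this repository. The added unsqueeze_(1) gives the mask a head dimension so it can broadcast over per-head attention scores.

import torch

def build_glm_mask(seq_len, context_length, device='cpu'):
    # Start from an all-ones (1, seq, seq) mask, as in the diff above.
    mask = torch.ones((1, seq_len, seq_len), device=device)
    mask.tril_()                    # causal lower triangle for generated tokens
    mask[..., :context_length] = 1  # fully visible bidirectional context prefix
    mask.unsqueeze_(1)              # the fix: -> (1, 1, seq_len, seq_len)
    return mask

mask = build_glm_mask(seq_len=8, context_length=5)
print(mask.shape)  # torch.Size([1, 1, 8, 8])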
@@ -71,7 +72,7 @@ def main(args):
        seq = [tokenizer.get_command('ENC').Id] + seq
        if not raw_text.endswith('MASK]'):
            seq = seq + [tokenizer.get_command('eos').Id]
-        print('raw text: ', raw_text)
+        print('raw text: {}\n'.format(raw_text))
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
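
As a toy illustration of the prompt assembly above, with hypothetical token ids standing in for the tokenizer.get_command('ENC') and ('eos') lookups: the eos token is only appended when the prompt does not end with a [MASK] placeholder.

ENC_ID, EOS_ID = 1, 2                      # hypothetical command ids
raw_text = 'The capital of France is [MASK]'
seq = [101, 102, 103, 104, 105, 99]        # hypothetical encoded prompt
seq = [ENC_ID] + seq
if not raw_text.endswith('MASK]'):
    seq = seq + [EOS_ID]
print(seq)  # no eos appended, because the prompt ends with [MASK]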
@@ -112,7 +113,7 @@ def main(args):
            except ValueError:
                unfinished = len(output)
            bog = output.index(tokenizer.get_command('sop').Id)
-            output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
+            output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
            # prepare the next auto-regressive generation
            if mp_idx < len(mask_positions) - 1:
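
A toy illustration of the splice in the changed line, using made-up token ids: the generated span, which now starts at the 'sop' index bog, is inserted at the mask position, and the untouched prompt tail between the mask and bog is re-appended for the next round of blank filling.

output = [10, 11, 99, 12, 13, 50, 20, 21, 22]  # 99 = a [MASK] slot, 50 = sop, 20.. = generated
mask_position = 2
unfinished = len(output)                       # the except-ValueError case above
bog = output.index(50)
spliced = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
print(spliced)  # [10, 11, 50, 20, 21, 22, 12, 13]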
@@ -203,6 +203,12 @@ def get_params_for_weight_decay_optimization(module):
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n == 'bias' and p.requires_grad])
+    if len(weight_decay_params['params']) == 0:
+        return tuple(no_weight_decay_params)
+    elif len(no_weight_decay_params['params']) == 0:
+        return tuple(weight_decay_params)
     return weight_decay_params, no_weight_decay_params

 def get_optimizer_param_groups(model):
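
A minimal, self-contained sketch of how such parameter groups are typically fed to a PyTorch optimizer; the bias-only split, the 0.01 decay value, and the AdamW choice are illustrative assumptions rather than this repository's code. The early returns mirror the new guard above: an empty group is dropped instead of being handed to the optimizer.

import torch
import torch.nn as nn

def simple_param_groups(model, weight_decay=0.01):
    # Biases go into a zero-decay group; everything else keeps the default decay.
    decay = {'params': [], 'weight_decay': weight_decay}
    no_decay = {'params': [], 'weight_decay': 0.0}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (no_decay if name.endswith('bias') else decay)['params'].append(p)
    # Drop an empty group so the optimizer never receives a group with no params.
    if not decay['params']:
        return (no_decay,)
    if not no_decay['params']:
        return (decay,)
    return decay, no_decay

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(simple_param_groups(model), lr=1e-4)
print([len(g['params']) for g in optimizer.param_groups])  # [1, 1]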