From 2fa0908e2a2a0e250143d77a2f6aa5b2d468fa75 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 28 Oct 2021 09:35:10 +0000
Subject: [PATCH] fix multiple fill back bug

---
 inference_glm.py | 30 ++++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/inference_glm.py b/inference_glm.py
index a295e41..66dabae 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -88,10 +88,24 @@ def main(args):
         # generation
         mbz = args.max_inference_batch_size
         assert args.batch_size < mbz or args.batch_size % mbz == 0
 
-        output_list = []
-        # call for each position
-        for mp_idx, mask_position in enumerate(mask_positions):
+        output_list = [seq]
+        # continually detect the first mark position
+        while True:
+            seq = output_list[0] # TODO find the best one
+            # detect
+            mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+            mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+            mask_position = len(seq)
+            for token in mask_tokens:
+                try:
+                    mask_position = min(mask_position, seq.index(token))
+                except ValueError:
+                    pass
+            if mask_position == len(seq):
+                break
+            get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
+            output_list = []
             for tim in range(max(args.batch_size // mbz, 1)):
                 input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device)
                 output, _mems = filling_sequence(model, input_seq,
@@ -112,14 +126,10 @@ def main(args):
                 unfinished = output.index(-1)
             except ValueError:
                 unfinished = len(output)
+            if output[unfinished-1] in end_tokens:
+                unfinished -= 1
             bog = output.index(tokenizer.get_command('sop').Id)
-            output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
-
-            # prepare the next auto-regressive generation
-            if mp_idx < len(mask_positions) - 1:
-                # TODO, here to select the best for this time, inverse prompting?
-                seq = output_list[0]
-                output_list = []
+            output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
 
     # decoding
     txts = []
--
GitLab
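
What the patch changes, in brief: instead of walking a precomputed mask_positions list, generation now loops, rescanning output_list[0] for the earliest remaining MASK/sMASK/gMASK token until none is left; when splicing the generated span back, it starts at bog+1 so the sop token is not copied into the sequence, and it clips a trailing end token. Below is a minimal runnable sketch of that fill-back loop. It assumes toy integer token ids and a stub generate_for_mask() in place of filling_sequence() and the GLM model; none of these names come from the repository.

    MASK, SOP, EOS = 100, 101, 102    # hypothetical token ids
    MASK_TOKENS = [MASK]              # with args.task_mask this would hold MASK/sMASK/gMASK ids

    def first_mask_position(seq):
        """Index of the earliest remaining mask token, or len(seq) if none."""
        pos = len(seq)
        for token in MASK_TOKENS:
            try:
                pos = min(pos, seq.index(token))
            except ValueError:
                pass
        return pos

    def generate_for_mask(seq):
        """Stub standing in for filling_sequence(); mimics its output layout:
        [context..., SOP, generated tokens..., EOS, -1 padding...]."""
        return seq + [SOP, 7, 8, 9, EOS, -1, -1]

    def fill_back(seq):
        while True:
            mask_position = first_mask_position(seq)
            if mask_position == len(seq):      # no mask left: done
                return seq
            output = generate_for_mask(seq)
            try:
                unfinished = output.index(-1)  # first padding slot
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] == EOS:
                unfinished -= 1                # clip the trailing end token
            bog = output.index(SOP)            # beginning of generation
            # Splice the generation into the mask slot; bog+1 (not bog) skips
            # the sop token itself, the fill-back off-by-one this commit fixes.
            seq = (output[:mask_position]
                   + output[bog + 1:unfinished]
                   + output[mask_position + 1:bog])

    print(fill_back([1, 2, MASK, 3, 4]))       # -> [1, 2, 7, 8, 9, 3, 4]

Because each pass rescans the spliced result, text generated for one mask may itself contain further mask tokens and still be filled on a later pass.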