diff --git a/inference_glm.py b/inference_glm.py
index a295e416771a1feedb63d6a86e9fcd8524f16b30..66dabaeb1258064a74802742e819d083cddb5c2a 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -88,10 +88,24 @@ def main(args):
 
     # generation
     mbz = args.max_inference_batch_size
     assert args.batch_size < mbz or args.batch_size % mbz == 0
-    output_list = []
-    # call for each position
-    for mp_idx, mask_position in enumerate(mask_positions):
+    output_list = [seq]
+    # continually detect the first mask position
+    while True:
+        seq = output_list[0] # TODO find the best one
+        # detect
+        mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+        mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+        mask_position = len(seq)
+        for token in mask_tokens:
+            try:
+                mask_position = min(mask_position, seq.index(token))
+            except ValueError:
+                pass
+        if mask_position == len(seq):
+            break
+        get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
+        output_list = []
         for tim in range(max(args.batch_size // mbz, 1)):
             input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device)
             output, _mems = filling_sequence(model, input_seq,
@@ -112,14 +126,10 @@
                 unfinished = output.index(-1)
             except ValueError:
                 unfinished = len(output)
+            if output[unfinished-1] in end_tokens:
+                unfinished -= 1
             bog = output.index(tokenizer.get_command('sop').Id)
-            output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
-
-        # prepare the next auto-regressive generation
-        if mp_idx < len(mask_positions) - 1:
-            # TODO, here to select the best for this time, inverse prompting?
-            seq = output_list[0]
-            output_list = []
+            output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
 
     # decoding
     txts = []
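
A minimal, standalone sketch of the detection loop this patch introduces: repeatedly find the earliest remaining mask token, fill it, and stop when none are left. The integer token ids and the two-token "generated" span below are hypothetical stand-ins for tokenizer.get_command(...).Id and the model output; the actual patch also reassembles the result around the 'sop' token instead of splicing in place.

```python
# Hypothetical token ids standing in for tokenizer.get_command('MASK'/'sMASK'/'gMASK').Id
MASK, SMASK, GMASK = 50003, 50004, 50005
mask_tokens = [MASK, SMASK, GMASK]

def first_mask_position(seq):
    """Index of the earliest mask token, or len(seq) if no masks remain."""
    pos = len(seq)
    for token in mask_tokens:
        try:
            pos = min(pos, seq.index(token))
        except ValueError:
            pass  # this mask type does not occur in seq
    return pos

seq = [101, GMASK, 102, MASK, 103]
while True:
    pos = first_mask_position(seq)
    if pos == len(seq):
        break  # no masks left; generation is finished
    # Toy "generated" tokens; the real code inserts output[bog+1:unfinished] here.
    seq = seq[:pos] + [7, 8] + seq[pos + 1:]
print(seq)  # [101, 7, 8, 102, 7, 8, 103]
```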