diff --git a/inference_glm.py b/inference_glm.py
index a295e416771a1feedb63d6a86e9fcd8524f16b30..66dabaeb1258064a74802742e819d083cddb5c2a 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -88,10 +88,24 @@ def main(args):
         # generation
         mbz = args.max_inference_batch_size
         assert args.batch_size < mbz or args.batch_size % mbz == 0
-        output_list = []
-        # call for each position
-        for mp_idx, mask_position in enumerate(mask_positions):
+        output_list = [seq]
+        # repeatedly detect the first remaining mask position and fill it in
+        while True:
+            seq = output_list[0] # TODO find the best one
+            # detect the earliest mask token still present in seq
+            mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+            mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+            mask_position = len(seq)
+            for token in mask_tokens:
+                try:
+                    mask_position = min(mask_position, seq.index(token))
+                except ValueError:
+                    pass
+            if mask_position == len(seq):
+                break
+
             get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
+            output_list = []  # cleared here; refilled below with this round's generations
             for tim in range(max(args.batch_size // mbz, 1)):
                 input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device)
                 output, _mems = filling_sequence(model, input_seq,
@@ -112,14 +126,10 @@ def main(args):
                     unfinished = output.index(-1)
                 except ValueError:
                     unfinished = len(output)
+                if output[unfinished-1] in end_tokens:  # don't keep the end token in the filled-in span
+                    unfinished -= 1
                 bog = output.index(tokenizer.get_command('sop').Id)
-                output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
-            
-            # prepare the next auto-regressive generation
-            if mp_idx < len(mask_positions) - 1: 
-                # TODO, here to select the best for this time, inverse prompting?
-                seq = output_list[0]
-                output_list = []
+                output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
 
         # decoding
         txts = []