diff --git a/inference_glm.py b/inference_glm.py
index b88c2a73099f02ba4881558d0870afc72d40f95c..376773d3b5f13f3fcfc35fb51f0f4123851bb881 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -162,7 +162,10 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
                     position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
                     position_ids[:, 0] = context_length
                     position_ids[:, 1] = counter + 1
-                attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
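+                # all-ones attention mask spanning context_length + counter positions
+                # (presumably the prompt plus every token generated so far)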
+                attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
+                                                         dtype=torch.long)
             else:
                 position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
                 attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,