diff --git a/inference_glm.py b/inference_glm.py
index b88c2a73099f02ba4881558d0870afc72d40f95c..376773d3b5f13f3fcfc35fb51f0f4123851bb881 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -162,7 +162,8 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
                 position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
                 position_ids[:, 0] = context_length
                 position_ids[:, 1] = counter + 1
-            attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
+            attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
+                                                     dtype=torch.long)
         else:
             position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
             attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
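
A minimal standalone sketch of what the changed branch now does, using a toy prompt tensor and a stand-in decoding loop (the variable names context_tokens, context_length, and counter follow the hunk; everything else is illustrative, not the surrounding sample_sequence code): the block-LM path builds an all-ones attention mask covering the prompt plus the tokens generated so far, instead of the previous single-element zero placeholder.

import torch

# Hypothetical stand-ins for the decoding state inside sample_sequence.
context_tokens = torch.tensor([[101, 7592, 2088, 102]])  # toy prompt ids
context_length = context_tokens.size(1)

for counter in range(3):  # stand-in for the generation loop
    # New behaviour: ones over the prompt plus tokens generated so far.
    attention_mask = context_tokens.new_ones(
        1, context_length + counter, device=context_tokens.device, dtype=torch.long
    )
    # Old behaviour, for comparison: a single-element zero mask.
    # attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
    print(counter, attention_mask.shape)  # the mask length grows with counter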