From d039667dc7d6e294619be567e9901000a034317b Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Wed, 20 Oct 2021 20:00:00 +0800
Subject: [PATCH] Fix attention mask in GLM

---
 inference_glm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/inference_glm.py b/inference_glm.py
index b88c2a7..376773d 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -162,7 +162,8 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
             position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
             position_ids[:, 0] = context_length
             position_ids[:, 1] = counter + 1
-            attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
+            attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
+                                                     dtype=torch.long)
         else:
             position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
             attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
--
GitLab
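
Note (not part of the patch): a minimal sketch of what the change does, using made-up values for `context_length` and `counter`. The old code built a zero mask of shape `[1]`, so the step effectively attended to nothing; the fixed code builds a ones mask of shape `(1, context_length + counter)` covering the prompt plus all tokens generated so far.

```python
import torch

# Hypothetical values, for illustration only.
context_length = 5   # number of prompt tokens
counter = 3          # tokens generated so far
context_tokens = torch.ones(context_length, dtype=torch.long)

# Before the fix: a zero tensor of shape [1] (nothing is attended to).
old_mask = context_tokens.new_zeros([1], dtype=torch.long)

# After the fix: a ones mask over prompt + generated tokens,
# shape (1, context_length + counter).
new_mask = context_tokens.new_ones(1, context_length + counter, dtype=torch.long)

print(old_mask.shape)  # torch.Size([1])
print(new_mask.shape)  # torch.Size([1, 8])
```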