From d039667dc7d6e294619be567e9901000a034317b Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Wed, 20 Oct 2021 20:00:00 +0800
Subject: [PATCH] Fix attention mask in GLM

---
 inference_glm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/inference_glm.py b/inference_glm.py
index b88c2a7..376773d 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -162,7 +162,8 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
                     position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
                     position_ids[:, 0] = context_length
                     position_ids[:, 1] = counter + 1
-                attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
+                attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
+                                                         dtype=torch.long)
             else:
                 position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
                 attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
-- 
GitLab