From 3d55652ba3b8e9d21f4cdeb60fef4fff5b783bcb Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 28 Oct 2021 09:12:20 +0000
Subject: [PATCH] Fix GLM inference mask and empty optimizer param groups

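GLM attends bidirectionally over the prompt and autoregressively only over
the generated span, but get_masks_and_position_ids_glm built a purely causal
(lower-triangular) mask, so earlier context tokens could not see later ones.
Set the context columns of the attention mask to 1 so the prefix stays fully
visible while generated positions remain causal.

Also:
- start the generated span at the sop index (bog) instead of bog+1 when
  splicing the output around the mask position;
- print the raw text with a trailing newline;
- in get_params_for_weight_decay_optimization, return a one-element tuple
  holding only the non-empty group when either the weight-decay or the
  no-weight-decay group is empty, so the optimizer never receives a group
  with an empty 'params' list.

A minimal sketch of the resulting prefix-causal mask (not part of the patch;
seq_len and context_length are made-up sizes):

    import torch

    seq_len, context_length = 5, 3            # hypothetical sizes
    mask = torch.ones(1, seq_len, seq_len)    # start from all-ones
    mask.tril_()                              # causal lower-triangular part
    mask[..., :context_length] = 1            # context columns fully visible
    mask.unsqueeze_(1)                        # add head dim -> (1, 1, L, L)
    print(mask[0, 0])
    # tensor([[1., 1., 1., 0., 0.],
    #         [1., 1., 1., 0., 0.],
    #         [1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 0.],
    #         [1., 1., 1., 1., 1.]])
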
---
 inference_glm.py               | 5 +++--
 training/deepspeed_training.py | 6 ++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

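Note: a minimal sketch of how the adjusted return value of
get_params_for_weight_decay_optimization would be consumed; the AdamW call
and the tensors below are illustrative assumptions, not code from this
repository.

    import torch

    # Hypothetical parameter groups, shaped like the ones built in
    # get_params_for_weight_decay_optimization.
    weight_decay_params = {'params': []}   # assume no decayed params were found
    no_weight_decay_params = {
        'params': [torch.nn.Parameter(torch.zeros(3))],
        'weight_decay': 0.0,
    }

    # Mirror the patched early returns: drop the empty group and hand the
    # optimizer a one-element tuple instead of a group with no parameters.
    if len(weight_decay_params['params']) == 0:
        param_groups = (no_weight_decay_params,)
    elif len(no_weight_decay_params['params']) == 0:
        param_groups = (weight_decay_params,)
    else:
        param_groups = (weight_decay_params, no_weight_decay_params)

    optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=0.01)
    print([len(g['params']) for g in optimizer.param_groups])  # -> [1]
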
diff --git a/inference_glm.py b/inference_glm.py
index 0e5e44d..a295e41 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -32,6 +32,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 
     attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
     attention_mask.tril_()
+    attention_mask[..., :context_length] = 1
     attention_mask.unsqueeze_(1)
 
     position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
@@ -71,7 +72,7 @@ def main(args):
         seq = [tokenizer.get_command('ENC').Id] + seq
         if not raw_text.endswith('MASK]'):
             seq = seq + [tokenizer.get_command('eos').Id]
-        print('raw text: ', raw_text)
+        print('raw text: {}\n'.format(raw_text))
         if len(seq) > args.max_sequence_length:
             raise ValueError('text too long.')
         
@@ -112,7 +113,7 @@ def main(args):
                 except ValueError:
                     unfinished = len(output)
                 bog = output.index(tokenizer.get_command('sop').Id)
-                output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
+                output_list[i] = output[:mask_position] + output[bog:unfinished] + output[mask_position+1:bog]
             
             # prepare the next auto-regressive generation
             if mp_idx < len(mask_positions) - 1: 
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 10538ac..e59c0bf 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -203,6 +203,12 @@ def get_params_for_weight_decay_optimization(module):
             no_weight_decay_params['params'].extend(
                 [p for n, p in list(module_._parameters.items())
                  if p is not None and n == 'bias' and p.requires_grad])
+
+    if len(weight_decay_params['params']) == 0:
+        return (no_weight_decay_params,)
+    elif len(no_weight_decay_params['params']) == 0:
+        return (weight_decay_params,)
+
     return weight_decay_params, no_weight_decay_params
 
 def get_optimizer_param_groups(model):
-- 
GitLab