diff --git a/arguments.py b/arguments.py
index dbdeea61f4199fc940f6988fa0d4897f66d561a2..ebefb2fb78ef1af4c57ff506c29654be8391926d 100755
--- a/arguments.py
+++ b/arguments.py
@@ -250,6 +250,7 @@ def add_glm_args(parser):
     group.add_argument('--random-position', action='store_true',
                        help="Use random start position to cover all the position embeddings")
     group.add_argument('--cloze-eval', action='store_true', help='Evaluation dataset with cloze task')
+    group.add_argument('--old-checkpoint', action='store_true', help="Loading the checkpoint from the old library")
     return parser
 
 
diff --git a/config/model_glm_10B.sh b/config/model_glm_10B.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5ca0df0ed1a753f49baeae6818b0872268b35e91
--- /dev/null
+++ b/config/model_glm_10B.sh
@@ -0,0 +1,12 @@
+MODEL_TYPE="blocklm-10B"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --task-mask \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-model-type gpt2 \
+            --tokenizer-type GPT2BPETokenizer \
+            --old-checkpoint \
+            --load ${CHECKPOINT_PATH}/blocklm-10b-1024"
\ No newline at end of file
diff --git a/config/model_glm_roberta_large.sh b/config/model_glm_roberta_large.sh
index 6fb7216ad23e370ab5cccaeadd626d19e5d84364..5c4eaf200fb2d6b91550db9112beae7f56e1ba98 100644
--- a/config/model_glm_roberta_large.sh
+++ b/config/model_glm_roberta_large.sh
@@ -7,4 +7,5 @@ MODEL_ARGS="--block-lm \
             --max-sequence-length 513 \
             --tokenizer-model-type roberta \
             --tokenizer-type GPT2BPETokenizer \
+            --old-checkpoint \
             --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank"
\ No newline at end of file
diff --git a/inference_glm.py b/inference_glm.py
index 376773d3b5f13f3fcfc35fb51f0f4123851bb881..0456cc550df2065cc954d9509fd3f370fa1aa26d 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -21,7 +21,7 @@ from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
 from tokenization import get_tokenizer
 from generation.sampling_strategies import BaseStrategy
-from generation.autoregressive_sampling import filling_sequence
+from generation.autoregressive_sampling import update_mems
 from generation.utils import timed_name, save_multiple_images, generate_continually
 
 
@@ -80,19 +80,19 @@ def read_context(tokenizer, args, output=None):
     return terminate_runs, raw_text, context_tokens_tensor, context_length
 
 
-def get_batch(context_tokens, device, args):
+def get_batch(context_tokens, args):
     tokens = context_tokens
     tokens = tokens.view(1, -1).contiguous()
-    tokens = tokens.to(device)
+    tokens = tokens.to('cuda')
 
     # Get the masks and postition ids.
     if args.block_lm:
-        attention_mask = torch.ones(1, 1, tokens.size(1), tokens.size(1), device=device, dtype=torch.long)
+        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
         if args.fp16:
             attention_mask = attention_mask.half()
-        position_ids = torch.arange(tokens.size(1), device=device, dtype=torch.long)
+        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
         if not args.no_block_position:
-            block_position_ids = torch.zeros(tokens.size(1), device=device, dtype=torch.long)
+            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
             position_ids = torch.stack((position_ids, block_position_ids), dim=0)
         position_ids = position_ids.unsqueeze(0)
     else:
@@ -129,12 +129,8 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     return logits
 
 
-def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):
-    if not args.block_lm:
-        context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args)
-        tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long)
-    else:
-        tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
+def sample_sequence(model, tokenizer, context_tokens, context_length, args, mems=None, end_tokens=None):
+    tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
     counter = 0
     if mems is None:
         mems = []
@@ -152,24 +148,22 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
     beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
     last_beam_num = 1
     while counter < args.out_seq_length:
-        if counter == 0 and not args.block_lm:
-            next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems)
-        else:
-            if args.block_lm:
-                if args.no_block_position:
-                    position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
-                else:
-                    position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
-                    position_ids[:, 0] = context_length
-                    position_ids[:, 1] = counter + 1
-                attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
-                                                         dtype=torch.long)
+        if args.block_lm:
+            if args.no_block_position:
+                position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
             else:
-                position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
-                attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
-                                                         device=context_tokens.device, dtype=torch.float)
-            last_token = tokens[:, -1:]
-            next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems)
+                position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
+                position_ids[:, 0] = context_length
+                position_ids[:, 1] = counter + 1
+            attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
+                                                     dtype=torch.long)
+        else:
+            position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
+            attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
+                                                     device=context_tokens.device, dtype=torch.float)
+        last_token = tokens[:, -1:]
+        next_token_logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
+        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
         next_token_logits = next_token_logits[:, -1]
         if args.num_beams > 1:
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)
@@ -228,7 +222,7 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
     return torch.cat((context_tokens, tokens), dim=1), mems
 
 
-def generate_samples(model, tokenizer, args, device):
+def generate_samples(model, tokenizer, args):
     model.eval()
     output_path = "./samples"
     if not os.path.exists(output_path):
@@ -237,14 +231,13 @@ def generate_samples(model, tokenizer, args, device):
     with torch.no_grad(), open(output_path, "w") as output:
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
-
             terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
             if terminate_runs == 1:
                 return
             start_time = time.time()
             if args.block_lm:
                 mems = []
-                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
+                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
                 mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                 mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
                 end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
@@ -261,10 +254,8 @@ def generate_samples(model, tokenizer, args, device):
                         position = position_ids[0, mask_position].item()
                     else:
                         position = mask_position
-                    tokens, mems = sample_sequence(model, tokenizer, tokens, position,
-                                                   args, device, mems=mems, end_tokens=end_tokens)
-                else:
-                    tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)
+                    tokens, mems = sample_sequence(model, tokenizer, tokens, position, args, mems=mems,
+                                                   end_tokens=end_tokens)
             output_tokens_list = tokens.view(-1).contiguous()
             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -290,7 +281,7 @@ def main(args):
     load_checkpoint(model, args)
     set_random_seed(args.seed)
     model.eval()
-    generate_samples(model, tokenizer, args, torch.cuda.current_device())
+    generate_samples(model, tokenizer, args)
 
 
 if __name__ == "__main__":
diff --git a/mpu/transformer.py b/mpu/transformer.py
index 975167234ab05efb85dbdcabf0b8432313b09529..bad1a9ce47aa4dfdb167ff090cbdb3055ec44703 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,9 +52,9 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
 
-    if attention_mask.shape[-2] > 1: # if auto-regressive, skip
-        attention_scores = torch.mul(attention_scores, attention_mask) - \
-                           10000.0 * (1.0 - attention_mask)
+    # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
+    #     attention_scores = torch.mul(attention_scores, attention_mask) - \
+    #                        10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
 
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index af5a59a418c6c9ebe56e3e0990bcfef74394aebf..30e6d9eb1a367b61acb782012cd90889fb1ef2f6 100644
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-CHECKPOINT_PATH=./checkpoints
+CHECKPOINT_PATH=/dataset/fd5061f6/english_data/checkpoints
 
 source $1
 
diff --git a/training/model_io.py b/training/model_io.py
index df6b7524f6ac8334243af401f4147c9a34dfffae..8de74ea100da51f0212a53b64bf2ebd306cf786f 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -28,8 +28,12 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
         d += '_zero_dp_rank_{}'.format(dp_rank)
     return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
 
-def get_checkpoint_tracker_filename(checkpoints_path):
-    return os.path.join(checkpoints_path, 'latest')
+def get_checkpoint_tracker_filename(checkpoints_path, old_checkpoint=False):
+    if old_checkpoint:
+        return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
+    else:
+        return os.path.join(checkpoints_path, 'latest')
+
 
 
 def save_checkpoint(iteration, model, optimizer, lr_scheduler, args):
@@ -85,7 +89,7 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
 
 def get_checkpoint_iteration(args):
     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(args.load, old_checkpoint=args.old_checkpoint)
     if not os.path.isfile(tracker_filename):
         print_rank_0('WARNING: could not find the metadata file {} '.format(
             tracker_filename))
@@ -126,7 +130,12 @@ def load_checkpoint(model, args):
         module = model.module
     else: # inference without deepspeed
         module = model
-
+
+    # Process the checkpoint for GLM
+    if args.block_lm and args.old_checkpoint:
+        sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
+        del sd['module']['word_embeddings.weight']
+
     # only load module, other hyperparameters are just for recording.
     missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
     if len(unexpected_keys) > 0:
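
Note on the new memory handling in sample_sequence: each decoding step now feeds only the last generated token to the model and folds the returned per-layer states (mem_kvs) back into the cache with update_mems, imported from generation/autoregressive_sampling.py. That helper is not shown in this diff; the sketch below, under a hypothetical name, only illustrates the Transformer-XL-style update it is assumed to perform, namely appending the new states to the cached ones along the sequence dimension and keeping at most max_memory_length positions. It is not the repository's implementation.

import torch

def update_mems_sketch(mem_kvs, mems, max_memory_length):
    # Hypothetical stand-in for generation.autoregressive_sampling.update_mems
    # (an assumption for illustration, not the library's actual code).
    # mem_kvs: per-layer tensors for the newly processed token(s), shape [batch, new_len, hidden].
    # mems:    per-layer cached tensors, shape [batch, mem_len, hidden]; may be an empty list.
    if not mems:
        return [kv[:, -max_memory_length:] for kv in mem_kvs]
    # Append the new states and truncate each layer's cache to the last max_memory_length positions.
    return [torch.cat((mem, kv), dim=1)[:, -max_memory_length:]
            for mem, kv in zip(mems, mem_kvs)]

With the new config/model_glm_10B.sh, generation would presumably be launched the same way as the existing configs, e.g. bash scripts/generate_glm.sh config/model_glm_10B.sh, since scripts/generate_glm.sh simply sources the config file passed as its first argument.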