Skip to content
Snippets Groups Projects
Commit a61c570f authored by Ming Ding's avatar Ming Ding
Browse files

Merge remote-tracking branch 'origin/dev' into dev

parents a374aeb7 04752ae3
No related branches found
No related tags found
No related merge requests found
...@@ -8,5 +8,4 @@ MODEL_ARGS="--block-lm \ ...@@ -8,5 +8,4 @@ MODEL_ARGS="--block-lm \
--max-sequence-length 1025 \ --max-sequence-length 1025 \
--tokenizer-model-type gpt2 \ --tokenizer-model-type gpt2 \
--tokenizer-type GPT2BPETokenizer \ --tokenizer-type GPT2BPETokenizer \
--old-checkpoint \
--load ${CHECKPOINT_PATH}/blocklm-10b-1024" --load ${CHECKPOINT_PATH}/blocklm-10b-1024"
\ No newline at end of file
MODEL_TYPE="blocklm-10B-chinese"
MODEL_ARGS="--block-lm \
--cloze-eval \
--task-mask \
--num-layers 48 \
--hidden-size 4096 \
--num-attention-heads 64 \
--max-sequence-length 1025 \
--tokenizer-type glm_ChineseSPTokenizer \
--tokenizer-model-type glm-10b \
--load /dataset/fd5061f6/english_data/checkpoints/blocklm-10b-chinese07-08-15-28"
\ No newline at end of file
MODEL_TYPE="blocklm-large-chinese"
MODEL_ARGS="--block-lm \
--cloze-eval \
--task-mask \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--max-sequence-length 1025 \
--tokenizer-type glm_ChineseSPTokenizer \
--tokenizer-model-type glm-large \
--load /dataset/fd5061f6/english_data/checkpoints/blocklm-large-chinese"
\ No newline at end of file
...@@ -7,5 +7,4 @@ MODEL_ARGS="--block-lm \ ...@@ -7,5 +7,4 @@ MODEL_ARGS="--block-lm \
--max-sequence-length 513 \ --max-sequence-length 513 \
--tokenizer-model-type roberta \ --tokenizer-model-type roberta \
--tokenizer-type GPT2BPETokenizer \ --tokenizer-type GPT2BPETokenizer \
--old-checkpoint \
--load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank" --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank"
\ No newline at end of file
...@@ -27,6 +27,7 @@ from generation.autoregressive_sampling import filling_sequence ...@@ -27,6 +27,7 @@ from generation.autoregressive_sampling import filling_sequence
from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
from generation.utils import timed_name, generate_continually from generation.utils import timed_name, generate_continually
def get_masks_and_position_ids_glm(seq, mask_position, context_length): def get_masks_and_position_ids_glm(seq, mask_position, context_length):
tokens = seq.unsqueeze(0) tokens = seq.unsqueeze(0)
...@@ -45,6 +46,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length): ...@@ -45,6 +46,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
def main(args): def main(args):
args.do_train = False
initialize_distributed(args) initialize_distributed(args)
tokenizer = prepare_tokenizer(args) tokenizer = prepare_tokenizer(args)
# build model # build model
...@@ -77,14 +79,14 @@ def main(args): ...@@ -77,14 +79,14 @@ def main(args):
raise ValueError('text too long.') raise ValueError('text too long.')
# find mask tokens positions # find mask tokens positions
mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK'] # mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens] # mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
mask_positions = [] # mask_positions = []
context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device) # context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
for token in mask_tokens: # for token in mask_tokens:
mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist() # mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
mask_positions.sort() # mask_positions.sort()
# generation # generation
mbz = args.max_inference_batch_size mbz = args.max_inference_batch_size
assert args.batch_size < mbz or args.batch_size % mbz == 0 assert args.batch_size < mbz or args.batch_size % mbz == 0
...@@ -107,7 +109,9 @@ def main(args): ...@@ -107,7 +109,9 @@ def main(args):
get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq)) get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
output_list = [] output_list = []
for tim in range(max(args.batch_size // mbz, 1)): for tim in range(max(args.batch_size // mbz, 1)):
input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device) input_seq = torch.cuda.LongTensor(
seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
device=args.device)
output, _mems = filling_sequence(model, input_seq, output, _mems = filling_sequence(model, input_seq,
batch_size=min(args.batch_size, mbz), batch_size=min(args.batch_size, mbz),
strategy=strategy, strategy=strategy,
...@@ -116,7 +120,7 @@ def main(args): ...@@ -116,7 +120,7 @@ def main(args):
) # we don't use mems, fill back ) # we don't use mems, fill back
if isinstance(output, torch.Tensor): # different strategies if isinstance(output, torch.Tensor): # different strategies
output = list(output) output = list(output)
output_list.extend(output) output_list.extend(output)
# clip -1s and fill back generated things into seq # clip -1s and fill back generated things into seq
...@@ -126,10 +130,10 @@ def main(args): ...@@ -126,10 +130,10 @@ def main(args):
unfinished = output.index(-1) unfinished = output.index(-1)
except ValueError: except ValueError:
unfinished = len(output) unfinished = len(output)
if output[unfinished-1] in end_tokens: if output[unfinished - 1] in end_tokens:
unfinished -= 1 unfinished -= 1
bog = output.index(tokenizer.get_command('sop').Id) bog = output.index(tokenizer.get_command('sop').Id)
output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog] output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
# decoding # decoding
txts = [] txts = []
...@@ -147,11 +151,12 @@ def main(args): ...@@ -147,11 +151,12 @@ def main(args):
with open(full_path, 'w') as fout: with open(full_path, 'w') as fout:
for txt in txts: for txt in txts:
fout.write(txt + '\n') fout.write(txt + '\n')
os.chmod(full_path, stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU) os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
os.makedirs(args.output_path, exist_ok=True) os.makedirs(args.output_path, exist_ok=True)
generate_continually(process, args.input_source) generate_continually(process, args.input_source)
if __name__ == "__main__": if __name__ == "__main__":
py_parser = argparse.ArgumentParser(add_help=False) py_parser = argparse.ArgumentParser(add_help=False)
...@@ -160,4 +165,4 @@ if __name__ == "__main__": ...@@ -160,4 +165,4 @@ if __name__ == "__main__":
args = argparse.Namespace(**vars(args), **vars(known)) args = argparse.Namespace(**vars(args), **vars(known))
with torch.no_grad(): with torch.no_grad():
main(args) main(args)
\ No newline at end of file
#!/bin/bash #!/bin/bash
CHECKPOINT_PATH=pretrained/glm #CHECKPOINT_PATH=/workspace/dm/SwissArmyTransformer/pretrained/glm
# MODEL_ARGS="--block-lm \ # MODEL_ARGS="--block-lm \
# --cloze-eval \ # --cloze-eval \
...@@ -11,19 +11,20 @@ CHECKPOINT_PATH=pretrained/glm ...@@ -11,19 +11,20 @@ CHECKPOINT_PATH=pretrained/glm
# --tokenizer-type glm_GPT2BPETokenizer \ # --tokenizer-type glm_GPT2BPETokenizer \
# --load ${CHECKPOINT_PATH}/glm-roberta-large-blank" # --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
MODEL_TYPE="blocklm-10B" #MODEL_TYPE="blocklm-10B"
MODEL_ARGS="--block-lm \ #MODEL_ARGS="--block-lm \
--cloze-eval \ # --cloze-eval \
--task-mask \ # --task-mask \
--num-layers 48 \ # --num-layers 48 \
--hidden-size 4096 \ # --hidden-size 4096 \
--num-attention-heads 64 \ # --num-attention-heads 64 \
--max-sequence-length 1025 \ # --max-sequence-length 1025 \
--tokenizer-model-type gpt2 \ # --tokenizer-model-type gpt2 \
--tokenizer-type glm_GPT2BPETokenizer \ # --tokenizer-type glm_GPT2BPETokenizer \
--old-checkpoint \ # --old-checkpoint \
--load ${CHECKPOINT_PATH}/glm-en-10b" # --load ${CHECKPOINT_PATH}/glm-en-10b"
source $1
MPSIZE=1 MPSIZE=1
MAXSEQLEN=512 MAXSEQLEN=512
MASTER_PORT=$(shuf -n 1 -i 10000-65535) MASTER_PORT=$(shuf -n 1 -i 10000-65535)
...@@ -52,5 +53,5 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE ...@@ -52,5 +53,5 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
--top_k $TOPK \ --top_k $TOPK \
--output-path glm_text \ --output-path glm_text \
--batch-size 1 \ --batch-size 1 \
--out-seq-length 100 \ --out-seq-length 512 \
--mode inference --mode inference
...@@ -34,7 +34,7 @@ def get_tokenizer(args=None): ...@@ -34,7 +34,7 @@ def get_tokenizer(args=None):
get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs) get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
elif args.tokenizer_type == "glm_ChineseSPTokenizer": elif args.tokenizer_type == "glm_ChineseSPTokenizer":
from .text import ChineseSPTokenizer from .text import ChineseSPTokenizer
get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs) get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
else: else:
assert args.vocab_size > 0 assert args.vocab_size > 0
get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size) get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
......
...@@ -1157,29 +1157,57 @@ class GPT2BPETokenizer(Tokenizer): ...@@ -1157,29 +1157,57 @@ class GPT2BPETokenizer(Tokenizer):
class ChineseSPTokenizer(Tokenizer): class ChineseSPTokenizer(Tokenizer):
def __init__(self, add_block_symbols=False, **kwargs): def __init__(self, model_type_or_path, add_block_symbols=False, add_task_mask=False, add_decoder_mask=False,
**kwargs):
self.text_tokenizer = sp_tokenizer.from_pretrained() self.text_tokenizer = sp_tokenizer.from_pretrained()
self.num_command_tokens = 2 self.num_command_tokens = 0
self.num_text_tokens = self.text_tokenizer.sp.vocab_size() self.num_text_tokens = self.text_tokenizer.sp.vocab_size()
self.num_tokens = self.num_text_tokens + 1 self.num_tokens = self.num_text_tokens
self.num_type_tokens = 2 self.num_type_tokens = 2
self._command_tokens = [ self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.num_text_tokens), CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
CommandToken('eos', '<|endoftext|>', self.num_text_tokens), CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
CommandToken('sep', '[SEP]', self.num_text_tokens + 1),
CommandToken('ENC', '[CLS]', self.num_text_tokens + 2),
CommandToken('MASK', '[MASK]', self.num_text_tokens + 3, lstrip=True),
CommandToken('unk', '[UNK]', self.num_text_tokens + 4)
] ]
self.num_tokens += 5
self.num_command_tokens += 6
if add_block_symbols: if add_block_symbols:
self._command_tokens.extend([ self._command_tokens.extend([
CommandToken('sop', '<|startofpiece|>', self.num_text_tokens + 1), CommandToken('sop', '<|startofpiece|>', self.num_tokens + 1),
CommandToken('eop', '<|endofpiece|>', self.num_text_tokens + 2) CommandToken('eop', '<|endofpiece|>', self.num_tokens + 2)
]) ])
self.num_tokens += 2 if model_type_or_path == 'glm-large':
self.num_tokens += 3
else:
self.num_tokens += 2
self.num_command_tokens += 2 self.num_command_tokens += 2
if add_task_mask:
if model_type_or_path == 'glm-large':
self._command_tokens.extend([
CommandToken('sMASK', '[sMASK]', self.num_tokens, lstrip=True),
CommandToken('gMASK', '[gMASK]', self.num_tokens + 1, lstrip=True)
])
else:
self._command_tokens.extend([
CommandToken('gMASK', '[gMASK]', self.num_tokens, lstrip=True),
CommandToken('sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True)
])
self.num_tokens += 2
self.num_command_tokens += 2
if add_decoder_mask:
self._command_tokens.extend([
CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)
])
self.num_tokens += 1
self.num_command_tokens += 1
self.command_name_map = {tok.name: tok for tok in self._command_tokens} self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens} self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens} self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
self.type_tokens = [ self.type_tokens = [
TypeToken('str0', '<str0>', 0), TypeToken('str0', '<str0>', 0),
TypeToken('str1', '<str1>', 1), TypeToken('str1', '<str1>', 1),
...@@ -1224,7 +1252,7 @@ class ChineseSPTokenizer(Tokenizer): ...@@ -1224,7 +1252,7 @@ class ChineseSPTokenizer(Tokenizer):
elif Id in self.type_id_map: elif Id in self.type_id_map:
return self.type_id_map[Id].token return self.type_id_map[Id].token
else: else:
return self.text_tokenizer.convert_id_to_token(Id) return self.text_tokenizer.convert_id_to_token(int(Id))
def TokenToId(self, token, type_token=False): def TokenToId(self, token, type_token=False):
if isinstance(token, (TypeToken, CommandToken)): if isinstance(token, (TypeToken, CommandToken)):
...@@ -1238,13 +1266,22 @@ class ChineseSPTokenizer(Tokenizer): ...@@ -1238,13 +1266,22 @@ class ChineseSPTokenizer(Tokenizer):
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization): if isinstance(Ids, Tokenization):
Ids = Ids.tokenization Ids = Ids.tokenization
try: Ids = list(map(int, Ids))
first_eos = Ids.index(self.get_command('eos').Id) pieces = []
eos_count = len(Ids) - first_eos last = 0
Ids = Ids[:first_eos] for i, token_id in enumerate(Ids):
except ValueError: if token_id in self.command_id_map:
eos_count = 0 pieces.append(Ids[last: i])
return " ".join((self.text_tokenizer.decode(Ids), *(['<|endoftext|>'] * eos_count))) pieces.append(token_id)
last = i + 1
pieces.append(Ids[last:])
text = ""
for piece in pieces:
if isinstance(piece, int):
text += self.command_id_map[piece].token
elif piece:
text += self.text_tokenizer.decode(piece)
return text
def DecodeTokens(self, Tokens, type_token=False): def DecodeTokens(self, Tokens, type_token=False):
if type_token: if type_token:
......
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
''' '''
@File : model_io.py @File : model_io.py
...@@ -18,6 +17,7 @@ import numpy as np ...@@ -18,6 +17,7 @@ import numpy as np
import mpu import mpu
from .utils import print_rank_0 from .utils import print_rank_0
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False): def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
if release: if release:
d = 'release' d = 'release'
...@@ -28,6 +28,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False): ...@@ -28,6 +28,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
d += '_zero_dp_rank_{}'.format(dp_rank) d += '_zero_dp_rank_{}'.format(dp_rank)
return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())) return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
def get_checkpoint_tracker_filename(checkpoints_path, old_checkpoint=False): def get_checkpoint_tracker_filename(checkpoints_path, old_checkpoint=False):
if old_checkpoint: if old_checkpoint:
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
...@@ -70,9 +71,9 @@ def save_ds_checkpoint(iteration, model, lr_scheduler, args): ...@@ -70,9 +71,9 @@ def save_ds_checkpoint(iteration, model, lr_scheduler, args):
sd['cuda_rng_state'] = torch.cuda.get_rng_state() sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd) save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True): def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
# Ensure tag is a string # Ensure tag is a string
tag = str(tag) tag = str(tag)
...@@ -112,6 +113,7 @@ def get_checkpoint_iteration(args): ...@@ -112,6 +113,7 @@ def get_checkpoint_iteration(args):
return iteration, release, True return iteration, release, True
def load_checkpoint(model, args): def load_checkpoint(model, args):
"""Load a model checkpoint.""" """Load a model checkpoint."""
...@@ -131,10 +133,17 @@ def load_checkpoint(model, args): ...@@ -131,10 +133,17 @@ def load_checkpoint(model, args):
else: # inference without deepspeed else: # inference without deepspeed
module = model module = model
# sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
# del sd['module']['word_embeddings.weight']
# sd['module']['mixins.block_position_embedding.block_position_embeddings.weight'] = sd['module'][
# 'transformer.block_position_embeddings.weight']
# del sd['module']['transformer.block_position_embeddings.weight']
# only load module, other hyperparameters are just for recording. # only load module, other hyperparameters are just for recording.
missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False) missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
print_rank_0(f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.') print_rank_0(
f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
if len(missing_keys) > 0: if len(missing_keys) > 0:
if not args.do_train: if not args.do_train:
raise ValueError(f'Missing keys for inference: {missing_keys}.') raise ValueError(f'Missing keys for inference: {missing_keys}.')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment