diff --git a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
index b30926294ce26e4a890b51b89896b82de88b285a..9b996fb645dfbe485b65aa0f588e89ef65f4449a 100644
--- a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
+++ b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
@@ -11,7 +11,7 @@ import torch
 import torch.nn.functional as F
 
 class BeamSearchStrategy:
-    def __init__(self, num_beams, length_penalty=1., return_only_end=False,
+    def __init__(self, num_beams, length_penalty=1., consider_end=False,
                  end_tokens=[], invalid_slices=[], no_repeat_ngram_size=0, min_tgt_length=0):
         self.num_beams = num_beams
         self.length_penalty = length_penalty
@@ -19,7 +19,7 @@ class BeamSearchStrategy:
         self.ngram = no_repeat_ngram_size
         self.min_tgt_length = min_tgt_length
         self.invalid_slices = invalid_slices
-        self.return_only_end = return_only_end
+        self.consider_end = consider_end
         self._init_cache()
 
     def _init_cache(self):
@@ -34,10 +34,10 @@ class BeamSearchStrategy:
         for i in range(len(self.end_beams), -1, -1):
             if i == 0 or score < self.end_beams_penalized_scores[i-1]:
                 break
-        self.num_beams.insert(i, beam)
+        self.end_beams.insert(i, beam)
         self.end_beams_penalized_scores.insert(i, score)
-        self.num_beams = self.num_beams[:self.num_beams]
+        self.end_beams = self.end_beams[:self.num_beams]
         self.end_beams_penalized_scores = self.end_beams_penalized_scores[:self.num_beams]
 
     def forward(self, logits, tokens, mems):
@@ -52,11 +52,14 @@ class BeamSearchStrategy:
         if self.ngram > 0 and seq_len > self.ngram:
             for i in range(batch_size):
                 ngram_prefix = tokens[i, -(self.ngram-1):].tolist() # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans.get(tuple(ngram_prefix), default=[]):
+                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
                     logits[i, banned_index] = -65504
 
         next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
-        next_token_scores = next_token_scores + self.cached_beam_scores[:, None].expand_as(next_token_scores)
+        prev_scores = self.cached_beam_scores
+        if isinstance(self.cached_beam_scores, torch.Tensor):
+            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        next_token_scores = next_token_scores + prev_scores
 
         next_token_scores = next_token_scores.view(batch_size * vocab_size)
@@ -70,12 +73,14 @@ class BeamSearchStrategy:
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
+        if mems.shape[1] < batch_size:
+            mems = mems.expand(-1, batch_size, -1, -1)
         beam_continue = []
         scores_continue = []
         bans_continue = []
         mems_contiue = []
         for i in range(len(next_tokens)):
-            beam = torch.cat(tokens[next_indices[i]], next_tokens[i:i+1])
+            beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
             if int(next_tokens[i]) in self.end_tokens:
                 self._add_end_beams(next_token_scores[i], beam)
             elif len(beam_continue) < batch_size:
@@ -98,10 +103,13 @@ class BeamSearchStrategy:
         # TODO is_done
         return tokens, mems
 
-    def finalize(self, tokens):
-        if not self.return_only_end:
+    def finalize(self, tokens, mems):
+        if self.consider_end:
             for i in range(tokens.shape[0]):
                 self._add_end_beams(self.cached_beam_scores[i], tokens[i])
-        ret = self.end_beams
+            mems = None
+            ret = self.end_beams
+        else:
+            ret = tokens
         self._init_cache()
-        return ret
+        return ret, mems
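Note (illustrative, not part of the patch): the hunks above rename return_only_end to consider_end, fix the bookkeeping in _add_end_beams (it previously mutated self.num_beams instead of self.end_beams), and make finalize take and return mems. The following self-contained sketch shows the sorted-insert logic the fix restores; the function and variable names are placeholders that mirror the patched method, and the length-penalty scoring is omitted.

def add_end_beam(end_beams, end_scores, beam, score, num_beams):
    # insert (beam, score) so that end_scores stays sorted in descending order
    for i in range(len(end_beams), -1, -1):
        if i == 0 or score < end_scores[i - 1]:
            break
    end_beams.insert(i, beam)            # the patch fixes this from self.num_beams.insert(...)
    end_scores.insert(i, score)
    # keep only the best num_beams finished beams (was self.num_beams[:self.num_beams])
    return end_beams[:num_beams], end_scores[:num_beams]

beams, scores = [], []
for beam, score in [('b1', -1.0), ('b2', -0.5), ('b3', -2.0)]:
    beams, scores = add_end_beam(beams, scores, beam, score, num_beams=2)
print(beams, scores)  # ['b2', 'b1'] [-0.5, -1.0]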
diff --git a/examples/cogview2/inference_cogview2.py b/examples/cogview2/inference_cogview2.py
index 823d4918cde2b1317edf5869c90abbfa2c034f87..9a423d3d7daa08c3a65c88a8cb38d10f952c1e03 100644
--- a/examples/cogview2/inference_cogview2.py
+++ b/examples/cogview2/inference_cogview2.py
@@ -53,7 +53,7 @@ def main(args):
         print('raw text: ', raw_text)
         text = query_template.format(raw_text)
         seq = tokenizer.parse_query(text, img_size=args.img_size)
-        if len(seq) > 1088:
+        if len(seq) > 1088:
             raise ValueError('text too long.')
         # calibrate text length
         txt_len = seq.index(tokenizer['[BASE]'])
@@ -67,11 +67,12 @@ def main(args):
         assert args.batch_size < mbz or args.batch_size % mbz == 0
         output_list = []
         for tim in range(max(args.batch_size // mbz, 1)):
-            output0, mems = filling_sequence(model0, seq.clone(),
+            output0 = filling_sequence(model0, seq.clone(),
                         batch_size=min(args.batch_size, mbz),
                         strategy=strategy0,
                         log_attention_weights=log_attention_weights
-                        )
+                        )[0]
+            # auto del mems to save CUDA memory as possible
             imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0]
             blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096]
             len_tim = output0.shape[0]
diff --git a/examples/cogview2/scripts/large_scale_text2image_cogview2.sh b/examples/cogview2/scripts/large_scale_text2image_cogview2.sh
index fb1eb174a23ceaa394b0181ffe0f09764fc64e32..91aa80eed30b14e0abf8394227ba001df906e8c8 100755
--- a/examples/cogview2/scripts/large_scale_text2image_cogview2.sh
+++ b/examples/cogview2/scripts/large_scale_text2image_cogview2.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-NUM_WORKERS=4
+NUM_WORKERS=5
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
@@ -8,7 +8,7 @@ OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
 HOST_FILE_PATH="hostfile"
 # HOST_FILE_PATH="hostfile_single"
 
-CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/cogview/cogview2-base
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -25,7 +25,7 @@ script_dir=$(dirname $script_path)
 
 gpt_options=" \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --img-tokenizer-path /dataset/fd5061f6/sat_pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
@@ -42,7 +42,7 @@ gpt_options=" \
        --input-source ./coco30k.txt \
        --output-path coco_samples \
        --batch-size 60 \
-       --max-inference-batch-size 12 \
+       --max-inference-batch-size 6 \
        --with-id \
 "
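Note (illustrative, not part of the patch): inference_cogview2.py now indexes the result of filling_sequence with [0] instead of unpacking (output0, mems), so the cached mems tensor is never bound in the caller and its CUDA memory can be reclaimed between micro-batches. A minimal sketch of that pattern, assuming filling_sequence returns a (tokens, mems) pair as the updated finalize suggests; fake_filling_sequence is a stand-in defined only for this demo.

import torch

def fake_filling_sequence():
    # stand-ins for the real return values: generated token ids and cached activations
    tokens = torch.zeros(2, 5, dtype=torch.long)
    mems = torch.zeros(4, 2, 8, 16)
    return tokens, mems

output0 = fake_filling_sequence()[0]   # mems is dropped as soon as the call returns
# the old pattern kept mems alive until it went out of scope:
# output0, mems = fake_filling_sequence()
print(output0.shape)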
diff --git a/examples/glm/inference_glm.py b/examples/glm/inference_glm.py
index bfc0d6957ce13b7024afba795db7f9d12bd212ab..a1f5b832d5e04dcda0c3a94e787c38b07c0a08d1 100644
--- a/examples/glm/inference_glm.py
+++ b/examples/glm/inference_glm.py
@@ -61,7 +61,12 @@ def main(args):
     end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
     # define function for each query
-    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
+    if args.sampling_strategy == 'BaseStrategy':
+        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
+    elif args.sampling_strategy == 'BeamSearchStrategy':
+        strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True, end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size, min_tgt_length=args.min_tgt_length)
+    else:
+        raise ValueError(f'unknown strategy {args.sampling_strategy}')
 
     def process(raw_text):
         if args.with_id:
@@ -77,15 +82,6 @@ def main(args):
         print('raw text: {}\n'.format(raw_text))
         if len(seq) > args.max_sequence_length:
             raise ValueError('text too long.')
-
-        # find mask tokens positions
-        # mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
-        # mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-        # mask_positions = []
-        # context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
-        # for token in mask_tokens:
-        #     mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
-        # mask_positions.sort()
 
         # generation
         mbz = args.max_inference_batch_size
@@ -112,12 +108,12 @@ def main(args):
             input_seq = torch.cuda.LongTensor(
                 seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
                 device=args.device)
-            output, _mems = filling_sequence(model, input_seq,
+            output = filling_sequence(model, input_seq,
                         batch_size=min(args.batch_size, mbz),
                         strategy=strategy,
                         log_attention_weights=None,
                         get_masks_and_position_ids=get_func
-                        ) # we don't use mems, fill back
+                        )[0] # we don't use mems, fill back
             if isinstance(output, torch.Tensor): # different strategies
                 output = list(output)
@@ -159,7 +155,7 @@ def main(args):
 
 if __name__ == "__main__":
     py_parser = argparse.ArgumentParser(add_help=False)
-
+    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy', help='type name of sampling strategy')
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
diff --git a/examples/glm/scripts/generate_glm.sh b/examples/glm/scripts/generate_glm.sh
index eea0ed2742e8762c9294a406f64dcc670540552f..9e1c1631005dbe9ea681e56b28b871d570faf07a 100755
--- a/examples/glm/scripts/generate_glm.sh
+++ b/examples/glm/scripts/generate_glm.sh
@@ -1,29 +1,6 @@
 #!/bin/bash
 
 CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
-# MODEL_ARGS="--block-lm \
-#            --cloze-eval \
-#            --num-layers 24 \
-#            --hidden-size 1024 \
-#            --num-attention-heads 16 \
-#            --max-sequence-length 513 \
-#            --tokenizer-model-type roberta \
-#            --tokenizer-type glm_GPT2BPETokenizer \
-#            --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
-
-#MODEL_TYPE="blocklm-10B"
-#MODEL_ARGS="--block-lm \
-#            --cloze-eval \
-#            --task-mask \
-#            --num-layers 48 \
-#            --hidden-size 4096 \
-#            --num-attention-heads 64 \
-#            --max-sequence-length 1025 \
-#            --tokenizer-model-type gpt2 \
-#            --tokenizer-type glm_GPT2BPETokenizer \
-#            --old-checkpoint \
-#            --load ${CHECKPOINT_PATH}/glm-en-10b"
-
 source $1
 MPSIZE=1
 MAXSEQLEN=512
@@ -52,6 +29,8 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --temperature $TEMP \
        --top_k $TOPK \
        --output-path samples_glm \
-       --batch-size 1 \
+       --batch-size 2 \
        --out-seq-length 200 \
-       --mode inference
+       --mode inference \
+       --input-source ./input.txt \
+       --sampling-strategy BeamSearchStrategy
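Note (illustrative, not part of the patch): generate_glm.sh now passes --sampling-strategy BeamSearchStrategy and --input-source ./input.txt, and inference_glm.py picks them up through its private py_parser with parse_known_args, so every other flag still flows through to get_args. A minimal, runnable sketch of that parsing split; the example command-line values below are placeholders.

import argparse

py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy',
                       help='type name of sampling strategy')
# simulate part of the command line built by generate_glm.sh
known, rest = py_parser.parse_known_args(
    ['--sampling-strategy', 'BeamSearchStrategy', '--batch-size', '2', '--input-source', './input.txt'])
print(known.sampling_strategy)  # 'BeamSearchStrategy'
print(rest)                     # remaining args, forwarded to get_args() in inference_glm.py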