Commit 69cbe5dd authored by Ming Ding

support beam search

parent 26029fec
@@ -11,7 +11,7 @@ import torch
import torch.nn.functional as F
class BeamSearchStrategy:
def __init__(self, num_beams, length_penalty=1., return_only_end=False,
def __init__(self, num_beams, length_penalty=1., consider_end=False,
end_tokens=[], invalid_slices=[], no_repeat_ngram_size=0, min_tgt_length=0):
self.num_beams = num_beams
self.length_penalty = length_penalty
@@ -19,7 +19,7 @@ class BeamSearchStrategy:
self.ngram = no_repeat_ngram_size
self.min_tgt_length = min_tgt_length
self.invalid_slices = invalid_slices
self.return_only_end = return_only_end
self.consider_end = consider_end
self._init_cache()
def _init_cache(self):
@@ -34,10 +34,10 @@
for i in range(len(self.end_beams), -1, -1):
if i == 0 or score < self.end_beams_penalized_scores[i-1]:
break
self.num_beams.insert(i, beam)
self.end_beams.insert(i, beam)
self.end_beams_penalized_scores.insert(i, score)
self.num_beams = self.num_beams[:self.num_beams]
self.end_beams = self.end_beams[:self.num_beams]
self.end_beams_penalized_scores = self.end_beams_penalized_scores[:self.num_beams]
def forward(self, logits, tokens, mems):
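
The hunk above fixes `_add_end_beams`: the old code inserted into and truncated `self.num_beams`, which is an int, instead of the `self.end_beams` list. A self-contained sketch of the corrected bookkeeping follows; the length-penalty formula is an assumption, since the hunk does not show it.

```python
# Sketch of the corrected end-beam bookkeeping: finished beams are kept in a list
# sorted by penalized score (descending) and truncated to num_beams entries.
def add_end_beam(end_beams, end_scores, beam, score, num_beams, length_penalty=1.0):
    penalized = score / (len(beam) ** length_penalty)   # assumed penalty form
    for i in range(len(end_beams), -1, -1):
        # walk from the tail towards the head until the right slot is found
        if i == 0 or penalized < end_scores[i - 1]:
            break
    end_beams.insert(i, beam)
    end_scores.insert(i, penalized)
    del end_beams[num_beams:]      # keep only the best num_beams finished candidates
    del end_scores[num_beams:]

# usage: keep at most 2 finished beams
beams, scores = [], []
add_end_beam(beams, scores, [1, 2, 3], -2.0, num_beams=2)
add_end_beam(beams, scores, [4, 5], -1.0, num_beams=2)
add_end_beam(beams, scores, [6, 7, 8, 9], -5.0, num_beams=2)
assert len(beams) == 2 and scores[0] >= scores[1]
```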
@@ -52,11 +52,14 @@
if self.ngram > 0 and seq_len > self.ngram:
for i in range(batch_size):
ngram_prefix = tokens[i, -(self.ngram-1):].tolist() # TODO ngram=1
for banned_index in self.cached_beam_ngram_bans.get(tuple(ngram_prefix), default=[]):
for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
logits[i, banned_index] = -65504
next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
next_token_scores = next_token_scores + self.cached_beam_scores[:, None].expand_as(next_token_scores)
prev_scores = self.cached_beam_scores
if isinstance(self.cached_beam_scores, torch.Tensor):
prev_scores = prev_scores[:, None].expand_as(next_token_scores)
next_token_scores = next_token_scores + prev_scores
next_token_scores = next_token_scores.view(batch_size * vocab_size)
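
Two details in this hunk are easy to miss: the n-gram ban cache is now looked up per beam (`cached_beam_ngram_bans[i]`), and `cached_beam_scores` is still the scalar `0` on the first decoding step, so it is only expanded when it is already a per-beam tensor. A small runnable sketch with dummy shapes:

```python
import torch
import torch.nn.functional as F

batch_size, vocab_size = 4, 100
logits = torch.randn(batch_size, vocab_size)

# (1) per-beam ban cache: maps an (n-1)-gram prefix to banned next tokens
cached_beam_ngram_bans = [dict() for _ in range(batch_size)]
cached_beam_ngram_bans[0][(11, 12)] = [13]
for i in range(batch_size):
    ngram_prefix = (11, 12)                            # would come from tokens[i, -(n-1):]
    for banned_index in cached_beam_ngram_bans[i].get(ngram_prefix, []):
        logits[i, banned_index] = -65504               # fp16-safe "minus infinity"

# (2) accumulate running log-probabilities
cached_beam_scores = 0                                 # still a scalar before the first step
next_token_scores = F.log_softmax(logits, dim=-1)      # [batch_size, vocab_size]
prev_scores = cached_beam_scores
if isinstance(cached_beam_scores, torch.Tensor):
    prev_scores = prev_scores[:, None].expand_as(next_token_scores)
next_token_scores = next_token_scores + prev_scores    # broadcasting handles the scalar case
next_token_scores = next_token_scores.view(batch_size * vocab_size)
```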
@@ -70,12 +73,14 @@
next_tokens = next_tokens % vocab_size
# select out end beams or continue beams
if mems.shape[1] < batch_size:
mems = mems.expand(-1, batch_size, -1, -1)
beam_continue = []
scores_continue = []
bans_continue = []
mems_contiue = []
for i in range(len(next_tokens)):
beam = torch.cat(tokens[next_indices[i]], next_tokens[i:i+1])
beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
if int(next_tokens[i]) in self.end_tokens:
self._add_end_beams(next_token_scores[i], beam)
elif len(beam_continue) < batch_size:
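
This hunk also fixes two smaller issues: `mems` covers only a single beam on the first step and has to be broadcast to the full beam count, and `torch.cat` expects a sequence of tensors rather than two positional tensor arguments. A toy illustration with dummy shapes (the mems layout is an assumption read off the `expand` call):

```python
import torch

# mems assumed to be [num_layers, num_beams, seq_len, hidden]; on the first
# step only one beam exists, so the beam dimension is broadcast to batch_size
mems = torch.zeros(2, 1, 5, 8)
batch_size = 4
if mems.shape[1] < batch_size:
    mems = mems.expand(-1, batch_size, -1, -1)

# torch.cat takes a tuple/list of tensors; the old call passed them positionally
tokens = torch.zeros(batch_size, 10, dtype=torch.long)   # [num_beams, seq_len]
next_tokens = torch.tensor([7, 8, 9, 10])
next_indices = torch.tensor([0, 0, 1, 2])
i = 0
beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i + 1]))
assert beam.shape == (tokens.shape[1] + 1,)
```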
@@ -98,10 +103,13 @@
# TODO is_done
return tokens, mems
def finalize(self, tokens):
if not self.return_only_end:
def finalize(self, tokens, mems):
if self.consider_end:
for i in range(tokens.shape[0]):
self._add_end_beams(self.cached_beam_scores[i], tokens[i])
ret = self.end_beams
mems = None
ret = self.end_beams
else:
ret = tokens
self._init_cache()
return ret
return ret, mems
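
The new `finalize(tokens, mems)` mirrors the `(tokens, mems)` interface used elsewhere: with `consider_end=True` the still-running beams are folded into the finished pool and `mems` is dropped, since the caches no longer correspond to the returned beams; otherwise the running beams pass through unchanged. A pared-down toy version (not the real class) showing only this behaviour:

```python
import torch

class TinyStrategy:
    """Toy stand-in illustrating the finalize semantics only."""
    def __init__(self, num_beams, consider_end):
        self.num_beams = num_beams
        self.consider_end = consider_end
        self.end_beams = []                        # finished beams collected during decoding
        self.cached_beam_scores = torch.zeros(num_beams)

    def _add_end_beams(self, score, beam):
        # trivially simplified; the real method keeps the list sorted by penalized score
        self.end_beams = (self.end_beams + [beam])[:self.num_beams]

    def finalize(self, tokens, mems):
        if self.consider_end:
            for i in range(tokens.shape[0]):       # fold running beams into the finished pool
                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
            mems = None                            # caches no longer match the returned beams
            ret = self.end_beams
        else:
            ret = tokens                           # plain pass-through of the running beams
        return ret, mems

tokens = torch.arange(8).view(2, 4)
beams, mems = TinyStrategy(num_beams=2, consider_end=True).finalize(tokens, torch.zeros(1))
print(len(beams), mems)   # 2 None
```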
@@ -53,7 +53,7 @@ def main(args):
print('raw text: ', raw_text)
text = query_template.format(raw_text)
seq = tokenizer.parse_query(text, img_size=args.img_size)
if len(seq) > 1088:
if len(seq) > 1088:
raise ValueError('text too long.')
# calibrate text length
txt_len = seq.index(tokenizer['[BASE]'])
@@ -67,11 +67,12 @@ def main(args):
assert args.batch_size < mbz or args.batch_size % mbz == 0
output_list = []
for tim in range(max(args.batch_size // mbz, 1)):
output0, mems = filling_sequence(model0, seq.clone(),
output0 = filling_sequence(model0, seq.clone(),
batch_size=min(args.batch_size, mbz),
strategy=strategy0,
log_attention_weights=log_attention_weights
)
)[0]
# mems are freed automatically to save as much CUDA memory as possible
imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0]
blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096]
len_tim = output0.shape[0]
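
On the caller side, `filling_sequence` now returns a tuple whose first element is the token tensor; callers that do not need the memory caches simply index `[0]`, so the caches can be garbage-collected right away. A toy stand-in (not the real SwissArmyTransformer API) showing the pattern:

```python
import torch

def filling_sequence_stub(seq, batch_size):
    """Stand-in with the assumed (tokens, mems) return convention."""
    tokens = seq.unsqueeze(0).expand(batch_size, -1)    # pretend decoded sequences
    mems = torch.zeros(2, batch_size, seq.shape[0], 8)  # pretend transformer caches
    return tokens, mems

seq = torch.arange(10)
output0 = filling_sequence_stub(seq, batch_size=4)[0]   # mems are dropped immediately
print(output0.shape)   # torch.Size([4, 10])
```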
......
#!/bin/bash
NUM_WORKERS=4
NUM_WORKERS=5
NUM_GPUS_PER_WORKER=8
MP_SIZE=1
@@ -8,7 +8,7 @@ OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
# HOST_FILE_PATH="hostfile_single"
CHECKPOINT_PATH=pretrained/cogview/cogview2-base
CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/cogview/cogview2-base
NLAYERS=48
NHIDDEN=2560
NATT=40
@@ -25,7 +25,7 @@ script_dir=$(dirname $script_path)
gpt_options=" \
--tokenizer-type cogview \
--img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
--img-tokenizer-path /dataset/fd5061f6/sat_pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
--mode inference \
--distributed-backend nccl \
--max-sequence-length 1089 \
@@ -42,7 +42,7 @@ gpt_options=" \
--input-source ./coco30k.txt \
--output-path coco_samples \
--batch-size 60 \
--max-inference-batch-size 12 \
--max-inference-batch-size 6 \
--with-id \
"
......
@@ -61,7 +61,12 @@ def main(args):
end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
# define function for each query
strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
if args.sampling_strategy == 'BaseStrategy':
strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
elif args.sampling_strategy == 'BeamSearchStrategy':
strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True, end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size, min_tgt_length=args.min_tgt_length)
else:
raise ValueError(f'unknown strategy {args.sampling_strategy}')
def process(raw_text):
if args.with_id:
@@ -77,15 +82,6 @@ def main(args):
print('raw text: {}\n'.format(raw_text))
if len(seq) > args.max_sequence_length:
raise ValueError('text too long.')
# find mask tokens positions
# mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
# mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
# mask_positions = []
# context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
# for token in mask_tokens:
# mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
# mask_positions.sort()
# generation
mbz = args.max_inference_batch_size
@@ -112,12 +108,12 @@ def main(args):
input_seq = torch.cuda.LongTensor(
seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
device=args.device)
output, _mems = filling_sequence(model, input_seq,
output = filling_sequence(model, input_seq,
batch_size=min(args.batch_size, mbz),
strategy=strategy,
log_attention_weights=None,
get_masks_and_position_ids=get_func
) # we don't use mems, fill back
)[0] # we don't use mems, fill back
if isinstance(output, torch.Tensor): # different strategies
output = list(output)
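
Under `BeamSearchStrategy`, `finalize` returns a Python list of finished beams (which may differ in length), while `BaseStrategy` keeps returning a single tensor; the `isinstance` check above normalizes both cases to a list. A toy illustration:

```python
import torch

output = torch.arange(12).view(3, 4)   # BaseStrategy-style output: one [batch, seq] tensor
if isinstance(output, torch.Tensor):   # different strategies
    output = list(output)              # -> list of 3 per-sample token tensors
print(len(output), output[0].shape)    # 3 torch.Size([4])
```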
@@ -159,7 +155,7 @@ def main(args):
if __name__ == "__main__":
py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy', help='type name of sampling strategy')
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
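
The new `--sampling-strategy` flag goes through the script-level parser first; unknown arguments are forwarded to the framework's `get_args`, and the two namespaces are merged. A minimal runnable illustration of that two-stage pattern, with `get_args` mocked:

```python
import argparse

py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy',
                       help='type name of sampling strategy')

def get_args_stub(args_list):               # stand-in for the framework's get_args
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=1)
    return parser.parse_args(args_list)

known, args_list = py_parser.parse_known_args(
    ['--sampling-strategy', 'BeamSearchStrategy', '--batch-size', '2'])
args = get_args_stub(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
print(args.sampling_strategy, args.batch_size)   # BeamSearchStrategy 2
```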
......
#!/bin/bash
CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
# MODEL_ARGS="--block-lm \
# --cloze-eval \
# --num-layers 24 \
# --hidden-size 1024 \
# --num-attention-heads 16 \
# --max-sequence-length 513 \
# --tokenizer-model-type roberta \
# --tokenizer-type glm_GPT2BPETokenizer \
# --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
#MODEL_TYPE="blocklm-10B"
#MODEL_ARGS="--block-lm \
# --cloze-eval \
# --task-mask \
# --num-layers 48 \
# --hidden-size 4096 \
# --num-attention-heads 64 \
# --max-sequence-length 1025 \
# --tokenizer-model-type gpt2 \
# --tokenizer-type glm_GPT2BPETokenizer \
# --old-checkpoint \
# --load ${CHECKPOINT_PATH}/glm-en-10b"
source $1
MPSIZE=1
MAXSEQLEN=512
@@ -52,6 +29,8 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT
--temperature $TEMP \
--top_k $TOPK \
--output-path samples_glm \
--batch-size 1 \
--batch-size 2 \
--out-seq-length 200 \
--mode inference
--mode inference \
--input-source ./input.txt \
--sampling-strategy BeamSearchStrategy