Commit 2515985c authored by Zhengxiao Du

Implement no repeat ngram and min target length strategies

parent d59e2a62
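This commit adds two generation options, --no-repeat-ngram-size and --min-tgt-length, and wires the corresponding logits processors (MinLengthLogitsProcessor and NoRepeatNGramLogitsProcessor) into BeamSearchStrategy. The body of NoRepeatNGramLogitsProcessor is not part of this diff; the following is only a minimal sketch of how such a processor typically bans repeated n-grams (modelled on the common HuggingFace-style approach, not the exact code in this repository): for each sequence, collect every n-gram generated so far, then set the score of any token that would complete an already-seen n-gram to -inf.

import torch

def ban_repeated_ngrams(input_ids: torch.LongTensor, scores: torch.FloatTensor, ngram_size: int) -> torch.FloatTensor:
    # Illustrative sketch only; the function name and details are assumptions, not this repository's code.
    batch_size, cur_len = input_ids.shape
    if cur_len + 1 < ngram_size:
        return scores  # not enough generated tokens to complete an n-gram yet
    for i in range(batch_size):
        generated = input_ids[i].tolist()
        # Map each (n-1)-token prefix to the set of tokens that have followed it.
        seen = {}
        for j in range(cur_len - ngram_size + 1):
            prefix = tuple(generated[j:j + ngram_size - 1])
            seen.setdefault(prefix, set()).add(generated[j + ngram_size - 1])
        # Ban every token that would complete a previously generated n-gram.
        current_prefix = tuple(generated[cur_len - ngram_size + 1:])
        for token in seen.get(current_prefix, ()):
            scores[i, token] = -float("inf")
    return scores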
@@ -159,6 +159,8 @@ def add_text_generate_args(parser):
     group.add_argument("--top_k", type=int, default=0)
     group.add_argument("--num-beams", type=int, default=1)
     group.add_argument("--length-penalty", type=float, default=0.0)
+    group.add_argument("--no-repeat-ngram-size", type=int, default=0)
+    group.add_argument("--min-tgt-length", type=int, default=0)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
@@ -198,7 +198,6 @@ class BeamSearchScorer(BeamScorer):
             mems=None
     ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
         batch_size = len(self._beam_hyps)
-        breakpoint()
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
@@ -327,20 +326,23 @@ class MinLengthLogitsProcessor(LogitsProcessor):
             The id of the `end-of-sequence` token.
     """
-    def __init__(self, min_length: int, eos_token_id: int):
+    def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
-        if not isinstance(eos_token_id, int) or eos_token_id < 0:
-            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+        for eos_token_id in eos_token_ids:
+            if not isinstance(eos_token_id, int) or eos_token_id < 0:
+                raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
         self.min_length = min_length
-        self.eos_token_id = eos_token_id
+        self.eos_token_ids = eos_token_ids

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len < self.min_length:
-            scores[:, self.eos_token_id] = -float("inf")
+            for eos_token_id in self.eos_token_ids:
+                scores[:, eos_token_id] = -float("inf")
         return scores
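As a quick illustration of the change above (the token ids and vocabulary size below are made up for the example, and MinLengthLogitsProcessor is assumed to be in scope), the processor now masks every end token, not just a single one, until the minimum length is reached:

import torch

processor = MinLengthLogitsProcessor(min_length=5, eos_token_ids=[50000, 50007])  # hypothetical end-token ids
input_ids = torch.zeros((1, 3), dtype=torch.long)   # pretend only 3 tokens have been generated
scores = torch.zeros((1, 50304))                    # dummy scores over a 50304-token vocabulary
scores = processor(input_ids, scores)
assert scores[0, 50000] == -float("inf") and scores[0, 50007] == -float("inf")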
@@ -395,11 +397,21 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
 class BeamSearchStrategy:
-    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda'):
+    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0,
+                 min_tgt_length=0):
         self.num_beams = num_beams
         self.max_length = max_length
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.min_tgt_length = min_tgt_length
+        self.processors = LogitsProcessorList()
+        if min_tgt_length > 0:
+            processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens)
+            self.processors.append(processor)
+        if no_repeat_ngram_size > 0:
+            processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            self.processors.append(processor)
         self.beam_scorer = BeamSearchScorer(
             batch_size=1,
             max_length=max_length,
@@ -416,6 +428,7 @@ class BeamSearchStrategy:
     def forward(self, logits, tokens, mems):
         last_beam_num = tokens.size(0)
+        logits = self.processors(tokens, logits)
         next_token_scores = F.log_softmax(logits, dim=-1)
         next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
         vocab_size = next_token_scores.shape[-1]
@@ -448,7 +461,6 @@ class BeamSearchStrategy:
         return tokens, mems

     def finalize(self, tokens, mems):
-        # TODO check the eos token here
         tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
                                                          eos_token_id=self.end_tokens[0],
                                                          mems=mems)
@@ -133,7 +133,9 @@ def generate_samples(model, tokenizer, args):
             position = mask_position
         if args.num_beams > 1:
             strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
-                                          length_penalty=args.length_penalty, end_tokens=end_tokens)
+                                          length_penalty=args.length_penalty, end_tokens=end_tokens,
+                                          no_repeat_ngram_size=args.no_repeat_ngram_size,
+                                          min_tgt_length=args.min_tgt_length)
         else:
             strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
                                     end_tokens=end_tokens)
@@ -23,6 +23,7 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --model-parallel-size $MPSIZE \
        $MODEL_ARGS \
        --num-beams 4 \
+       --no-repeat-ngram-size 3 \
        --length-penalty 0.7 \
        --fp16 \
        --out-seq-length $MAXSEQLEN \