diff --git a/arguments.py b/arguments.py
index 9d87a2344d6e8998deaeb0e64fa5ba4e1a20fb27..55d3d1996144134d20b16b55ee1812abc589d91b 100755
--- a/arguments.py
+++ b/arguments.py
@@ -159,6 +159,8 @@ def add_text_generate_args(parser):
     group.add_argument("--top_k", type=int, default=0)
     group.add_argument("--num-beams", type=int, default=1)
     group.add_argument("--length-penalty", type=float, default=0.0)
+    group.add_argument("--no-repeat-ngram-size", type=int, default=0)
+    group.add_argument("--min-tgt-length", type=int, default=0)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
index 6ce4fd959a266562b64a8522cd59ff7ca1accc8f..ec205e071590c1683d2e96bbac80ff57db31d2df 100644
--- a/generation/sampling_strategies/beam_search_strategy.py
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -198,7 +198,6 @@ class BeamSearchScorer(BeamScorer):
             mems=None
     ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
         batch_size = len(self._beam_hyps)
-        breakpoint()
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
@@ -327,20 +326,23 @@ class MinLengthLogitsProcessor(LogitsProcessor):
             The id of the `end-of-sequence` token.
     """

-    def __init__(self, min_length: int, eos_token_id: int):
+    def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
-
-        if not isinstance(eos_token_id, int) or eos_token_id < 0:
-            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+        for eos_token_id in eos_token_ids:
+            if not isinstance(eos_token_id, int) or eos_token_id < 0:
+                raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

         self.min_length = min_length
-        self.eos_token_id = eos_token_id
+        self.eos_token_ids = eos_token_ids

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len < self.min_length:
-            scores[:, self.eos_token_id] = -float("inf")
+            for eos_token_id in self.eos_token_ids:
+                scores[:, eos_token_id] = -float("inf")
         return scores

@@ -395,11 +397,21 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):


 class BeamSearchStrategy:
-    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda'):
+    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0,
+                 min_tgt_length=0):
         self.num_beams = num_beams
         self.max_length = max_length
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.min_tgt_length = min_tgt_length
+        self.processors = LogitsProcessorList()
+        if min_tgt_length > 0:
+            processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens)
+            self.processors.append(processor)
+        if no_repeat_ngram_size > 0:
+            processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            self.processors.append(processor)
         self.beam_scorer = BeamSearchScorer(
             batch_size=1,
             max_length=max_length,
@@ -416,6 +428,7 @@ class BeamSearchStrategy:

     def forward(self, logits, tokens, mems):
         last_beam_num = tokens.size(0)
+        logits = self.processors(tokens, logits)
         next_token_scores = F.log_softmax(logits, dim=-1)
         next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
         vocab_size = next_token_scores.shape[-1]
@@ -448,7 +461,6 @@ class BeamSearchStrategy:
         return tokens, mems

     def finalize(self, tokens, mems):
-        # TODO check the eos token here
         tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
                                                          eos_token_id=self.end_tokens[0],
                                                          mems=mems)
diff --git a/inference_glm.py b/inference_glm.py
index 26d247104888abc71da184169371e8fbd5942bec..792efd5e030fa1564a7f52310b261844eb113c03 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -133,7 +133,9 @@ def generate_samples(model, tokenizer, args):
             position = mask_position
             if args.num_beams > 1:
                 strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
-                                              length_penalty=args.length_penalty, end_tokens=end_tokens)
+                                              length_penalty=args.length_penalty, end_tokens=end_tokens,
+                                              no_repeat_ngram_size=args.no_repeat_ngram_size,
+                                              min_tgt_length=args.min_tgt_length)
             else:
                 strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
                                         end_tokens=end_tokens)
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index c3c99c99f01d8aaa10b973e0795ff8d1ce53867c..143d7b5baa08029075b6c360f1b252a445b5a686 100644
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -23,6 +23,7 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
       --model-parallel-size $MPSIZE \
       $MODEL_ARGS \
       --num-beams 4 \
+      --no-repeat-ngram-size 3 \
      --length-penalty 0.7 \
       --fp16 \
       --out-seq-length $MAXSEQLEN \
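
For reference, here is a minimal, self-contained sketch of the behavior the two new flags enable. It re-implements, in plain torch, the logits-processor pattern the patch wires into `BeamSearchStrategy.forward`: end tokens are masked until the output reaches the minimum target length, and any token that would complete an n-gram already present in the prefix is banned, with the processors applied before `log_softmax`. The helper names `min_length_processor` and `no_repeat_ngram_processor` are illustrative only and are not the repo's `MinLengthLogitsProcessor` / `NoRepeatNGramLogitsProcessor` classes.

```python
# Standalone sketch (not the repo's code) of the processor chain enabled by
# --min-tgt-length and --no-repeat-ngram-size.
import torch
import torch.nn.functional as F


def min_length_processor(input_ids, scores, min_length, eos_token_ids):
    # Mask every end token until the sequence reaches min_length,
    # analogous to MinLengthLogitsProcessor in the patch.
    if input_ids.shape[-1] < min_length:
        for eos in eos_token_ids:
            scores[:, eos] = -float("inf")
    return scores


def no_repeat_ngram_processor(input_ids, scores, ngram_size):
    # Ban any token that would complete an n-gram already seen in the prefix,
    # analogous to NoRepeatNGramLogitsProcessor.
    for batch_idx, seq in enumerate(input_ids.tolist()):
        if len(seq) < ngram_size:
            continue
        prev = tuple(seq[-(ngram_size - 1):]) if ngram_size > 1 else tuple()
        banned = set()
        for start in range(len(seq) - ngram_size + 1):
            ngram = tuple(seq[start:start + ngram_size])
            if ngram[:-1] == prev:
                banned.add(ngram[-1])
        if banned:
            scores[batch_idx, list(banned)] = -float("inf")
    return scores


# Toy example: 2 beams, vocab of 10, end token id 9, prefixes of length 4.
tokens = torch.tensor([[1, 2, 3, 2], [4, 5, 4, 5]])
logits = torch.randn(2, 10)
logits = min_length_processor(tokens, logits, min_length=8, eos_token_ids=[9])
logits = no_repeat_ngram_processor(tokens, logits, ngram_size=3)
scores = F.log_softmax(logits, dim=-1)  # same ordering as BeamSearchStrategy.forward
print(scores)
```

With `ngram_size=3`, the second beam (`4 5 4 5`) has token `4` banned, since emitting it would repeat the trigram `4 5 4`; both beams have the end token masked because they are shorter than the minimum target length.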