From d59e2a62262d85b88e7f520d3caddb28df903ae1 Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Sat, 23 Oct 2021 15:27:47 +0800
Subject: [PATCH] Fix beam search strategy

---
 arguments.py                               |  1 +
 generation/glm_sampling.py                 | 32 +----
 generation/sampling_strategies/__init__.py |  3 +-
 .../sampling_strategies/base_strategy.py   | 57 ++++++++++++++++---
 .../beam_search_strategy.py                | 30 +---
 inference_cogview.py                       |  4 +-
 inference_cogview2.py                      |  4 +-
 inference_glm.py                           | 11 +++-
 scripts/generate_glm.sh                    |  2 +
 9 files changed, 69 insertions(+), 75 deletions(-)

diff --git a/arguments.py b/arguments.py
index ebefb2f..9d87a23 100755
--- a/arguments.py
+++ b/arguments.py
@@ -158,6 +158,7 @@ def add_text_generate_args(parser):
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
     group.add_argument("--num-beams", type=int, default=1)
+    group.add_argument("--length-penalty", type=float, default=0.0)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
diff --git a/generation/glm_sampling.py b/generation/glm_sampling.py
index 704b54d..f024ce3 100644
--- a/generation/glm_sampling.py
+++ b/generation/glm_sampling.py
@@ -5,34 +5,6 @@ from .autoregressive_sampling import update_mems
 from .sampling_strategies.beam_search_strategy import BeamSearchScorer
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
 def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=None, end_tokens=None, device='cuda'):
     tokens = torch.full((1, 1), tokenizer.get_command('sop').Id, device=device, dtype=torch.long)
     counter = 0
@@ -40,7 +12,7 @@ def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=N
         mems = []
     # if end_tokens is None:
    #     end_tokens = [tokenizer.get_command('eos').Id]
-    while counter < args.out_seq_length:
+    while counter < args.out_seq_length - 1:
         last_beam_num = tokens.size(0)
         if args.block_lm:
             if args.no_block_position:
@@ -73,5 +45,5 @@ def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=N
         # prev = prev.view(1, 1)
         # tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
         counter += 1
-    strategy.finalize(tokens, mems)
+    tokens, mems = strategy.finalize(tokens, mems)
     return tokens, mems
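
Note on the contract this hunk relies on: filling_sequence_glm now rebinds the pair returned by strategy.finalize(). The sketch below is illustrative only, not repository code; it shows the small interface a strategy object is expected to expose (per-step forward(), an is_done flag, and finalize() returning the final tokens and mems), with a made-up greedy rule and end-token id standing in for real sampling.

    import torch

    class GreedyStrategy:
        # Toy stand-in: greedy argmax instead of sampling or beam search.
        def __init__(self, end_tokens=None):
            self.end_tokens = end_tokens or []
            self._is_done = False

        @property
        def is_done(self):
            return self._is_done

        def forward(self, logits, tokens, mems, temperature=None):
            pred = logits.argmax(dim=-1, keepdim=True)  # (batch, 1) next-token ids
            if pred.item() in self.end_tokens:
                self._is_done = True
            return torch.cat((tokens, pred), dim=1), mems

        def finalize(self, tokens, mems):
            # BaseStrategy returns its inputs unchanged; a beam-search strategy may
            # return a different tensor (its best hypothesis), which is why the
            # caller now rebinds both values instead of discarding the result.
            return tokens, mems

    strategy = GreedyStrategy(end_tokens=[2])
    tokens, mems = torch.zeros(1, 1, dtype=torch.long), []
    tokens, mems = strategy.forward(torch.randn(1, 10), tokens, mems)
    tokens, mems = strategy.finalize(tokens, mems)
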
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
index 2f71e09..2e6b4f6 100644
--- a/generation/sampling_strategies/__init__.py
+++ b/generation/sampling_strategies/__init__.py
@@ -1,2 +1,3 @@
 from .base_strategy import BaseStrategy
-from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
+from .iterative_entfilter_strategy import IterativeEntfilterStrategy
+from .beam_search_strategy import BeamSearchStrategy
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index e46a8ca..2033994 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -14,26 +14,65 @@ import random
 import torch
 import torch.nn.functional as F
 
-def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
-    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    logits[indices_to_remove] = filter_value
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
     return logits
 
+
 class BaseStrategy:
-    def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
+    def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.invalid_slices = invalid_slices
         self.temperature = temperature
-        self.topk = topk
+        self.topk = top_k
+        self.top_p = top_p
         self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = False
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done
+
     def forward(self, logits, tokens, mems, temperature=None):
         if temperature is None:
-            temperature = self.temperature
+            temperature = self.temperature
         logits = logits / temperature
         for invalid_slice in self.invalid_slices:
             logits[..., invalid_slice] = -65504
-
-        logits = top_k_logits_(logits, self.topk)
-        probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
         pred = torch.multinomial(probs, num_samples=1)
+        if pred.item() in self.end_tokens:
+            self._is_done = True
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
+
+    def finalize(self, tokens, mems):
+        return tokens, mems
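
A minimal usage sketch of the updated BaseStrategy constructor and stepping API introduced above. Illustrative only: the vocabulary size (50257), the end-token id (50000) and the random logits are placeholders; in the repository the logits come from the GLM model inside filling_sequence_glm.

    import torch
    from generation.sampling_strategies import BaseStrategy

    strategy = BaseStrategy(temperature=0.9, top_k=40, top_p=0.0, end_tokens=[50000])
    tokens = torch.zeros(1, 1, dtype=torch.long)  # one running sequence (batch size 1)
    mems = []
    logits = torch.randn(1, 50257)                # placeholder logits for the last position
    tokens, mems = strategy.forward(logits, tokens, mems)
    if strategy.is_done:                          # True once an end token has been sampled
        tokens, mems = strategy.finalize(tokens, mems)
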
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
index 98ba2d6..6ce4fd9 100644
--- a/generation/sampling_strategies/beam_search_strategy.py
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -14,34 +14,6 @@ from collections import UserDict
 from typing import Optional, Tuple, List, Iterable, Union
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
 class BeamScorer(ABC):
     """
     Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
@@ -226,7 +198,7 @@ class BeamSearchScorer(BeamScorer):
             mems=None
     ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
         batch_size = len(self._beam_hyps)
-
+        breakpoint()
         # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
             if self._done[batch_idx]:
diff --git a/inference_cogview.py b/inference_cogview.py
index dd4c90f..f5ecb78 100644
--- a/inference_cogview.py
+++ b/inference_cogview.py
@@ -37,8 +37,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy = BaseStrategy(invalid_slices,
-                            temperature=args.temperature, topk=args.top_k)
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, top_k=args.top_k)
 
     def process(raw_text):
         if args.with_id:
diff --git a/inference_cogview2.py b/inference_cogview2.py
index 75ad1c6..c69e589 100644
--- a/inference_cogview2.py
+++ b/inference_cogview2.py
@@ -42,8 +42,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1]' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy0 = BaseStrategy(invalid_slices,
-                             temperature=args.temperature, topk=args.top_k)
+    strategy0 = BaseStrategy(invalid_slices,
+                             temperature=args.temperature, top_k=args.top_k)
     strategy1 = IterativeEntfilterStrategy(invalid_slices, temperature=args.temperature, topk=10) # temperature not used
 
     tr = transforms.Compose([
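
With the duplicate helpers deleted from glm_sampling.py and beam_search_strategy.py, the only remaining top_k_logits definition is the one added to base_strategy.py above. A hypothetical call against that shared helper, on a made-up tensor (this patch itself does not show other modules importing it):

    import torch
    from generation.sampling_strategies.base_strategy import top_k_logits

    logits = torch.randn(1, 100)
    filtered = top_k_logits(logits, top_k=5)  # everything outside the 5 largest logits becomes -inf
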
diff --git a/inference_glm.py b/inference_glm.py
index cfba93f..26d2471 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -20,6 +20,7 @@ from arguments import get_args
 from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
 from generation.glm_sampling import filling_sequence_glm
+from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 
 
 def read_context(tokenizer, args, output=None):
@@ -130,8 +131,14 @@ def generate_samples(model, tokenizer, args):
                     position = position_ids[0, mask_position].item()
                 else:
                     position = mask_position
-                new_tokens, mems = filling_sequence_glm(model, tokenizer, position, args, mems=mems,
-                                                        end_tokens=end_tokens)
+                if args.num_beams > 1:
+                    strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
+                                                  length_penalty=args.length_penalty, end_tokens=end_tokens)
+                else:
+                    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                            end_tokens=end_tokens)
+                new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
+                                                        end_tokens=end_tokens)
                 tokens = torch.cat((tokens, new_tokens), dim=1)
                 output_tokens_list = tokens.view(-1).contiguous()
                 if mpu.get_model_parallel_rank() == 0:
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index 30e6d9e..c3c99c9 100644
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -22,6 +22,8 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --mode inference \
        --model-parallel-size $MPSIZE \
        $MODEL_ARGS \
+       --num-beams 4 \
+       --length-penalty 0.7 \
        --fp16 \
        --out-seq-length $MAXSEQLEN \
        --temperature $TEMP \
-- 
GitLab
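
The new --length-penalty value is passed straight into BeamSearchStrategy. As a hedged reference only: HuggingFace-style beam scorers, which this repo's BeamSearchScorer is adapted from, rank finished hypotheses by a length-normalised score. The helper name below is made up, and the exact behaviour should be checked against BeamSearchStrategy's implementation.

    def length_normalized_score(sum_log_probs: float, num_tokens: int, length_penalty: float) -> float:
        # length_penalty 0.0 keeps the raw log-prob sum (shorter hypotheses tend to win);
        # larger values divide by a growing factor, so longer hypotheses rank better.
        return sum_log_probs / (num_tokens ** length_penalty)

    assert length_normalized_score(-10.0, 20, 0.0) == -10.0
    assert length_normalized_score(-10.0, 20, 0.7) > -10.0  # less negative, i.e. ranked higher

Under this convention the script's --length-penalty 0.7 favours longer continuations than the 0.0 default added in arguments.py, which applies no length normalisation at all.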