diff --git a/arguments.py b/arguments.py
index ebefb2fb78ef1af4c57ff506c29654be8391926d..9d87a2344d6e8998deaeb0e64fa5ba4e1a20fb27 100755
--- a/arguments.py
+++ b/arguments.py
@@ -158,6 +158,7 @@ def add_text_generate_args(parser):
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
     group.add_argument("--num-beams", type=int, default=1)
+    group.add_argument("--length-penalty", type=float, default=0.0)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
diff --git a/generation/glm_sampling.py b/generation/glm_sampling.py
index 704b54d6ed696aca4f58868a994bfccf0bd52d6a..f024ce326fa8481d9ff613a54be0f2b707365622 100644
--- a/generation/glm_sampling.py
+++ b/generation/glm_sampling.py
@@ -5,34 +5,6 @@ from .autoregressive_sampling import update_mems
 from .sampling_strategies.beam_search_strategy import BeamSearchScorer
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
 def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=None, end_tokens=None, device='cuda'):
     tokens = torch.full((1, 1), tokenizer.get_command('sop').Id, device=device, dtype=torch.long)
     counter = 0
@@ -40,7 +12,7 @@ def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=N
         mems = []
     # if end_tokens is None:
     #     end_tokens = [tokenizer.get_command('eos').Id]
-    while counter < args.out_seq_length:
+    while counter < args.out_seq_length - 1:
         last_beam_num = tokens.size(0)
         if args.block_lm:
             if args.no_block_position:
@@ -73,5 +45,5 @@ def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=N
         # prev = prev.view(1, 1)
         # tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
         counter += 1
-    strategy.finalize(tokens, mems)
+    tokens, mems = strategy.finalize(tokens, mems)
     return tokens, mems
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
index 2f71e09703c38106088167808d3be758ae8c9b24..2e6b4f6f481d2dbe24aa6c899656173deeb0b163 100644
--- a/generation/sampling_strategies/__init__.py
+++ b/generation/sampling_strategies/__init__.py
@@ -1,2 +1,3 @@
 from .base_strategy import BaseStrategy
-from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
+from .iterative_entfilter_strategy import IterativeEntfilterStrategy
+from .beam_search_strategy import BeamSearchStrategy
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index e46a8ca4e505c2891bae2599ebe10746cc54d876..20339941e16802cb60c88613a5bafee76db75b74 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -14,26 +14,65 @@ import random
 import torch
 import torch.nn.functional as F
 
-def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
-    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    logits[indices_to_remove] = filter_value
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+    return logits
+
 
 class BaseStrategy:
-    def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
+    def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.invalid_slices = invalid_slices
         self.temperature = temperature
-        self.topk = topk
+        self.topk = top_k
+        self.top_p = top_p
         self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = False
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done
+
 
     def forward(self, logits, tokens, mems, temperature=None):
         if temperature is None:
-            temperature = self.temperature
+            temperature = self.temperature
         logits = logits / temperature
         for invalid_slice in self.invalid_slices:
             logits[..., invalid_slice] = -65504
-
-        logits = top_k_logits_(logits, self.topk)
-        probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in Pytorch
         pred = torch.multinomial(probs, num_samples=1)
+        if pred.item() in self.end_tokens:
+            self._is_done = True
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
+
+    def finalize(self, tokens, mems):
+        return tokens, mems
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
index 98ba2d62dc45f9ef528888510180d8f8f919f131..6ce4fd959a266562b64a8522cd59ff7ca1accc8f 100644
--- a/generation/sampling_strategies/beam_search_strategy.py
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -14,34 +14,6 @@ from collections import UserDict
 from typing import Optional, Tuple, List, Iterable, Union
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
 class BeamScorer(ABC):
     """
     Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
diff --git a/inference_cogview.py b/inference_cogview.py
index dd4c90f8934ab576f8cbc4e783a7cbe1eea065c8..f5ecb78d9d8ee061c8eb757569381d904e4ab953 100644
--- a/inference_cogview.py
+++ b/inference_cogview.py
@@ -37,8 +37,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy = BaseStrategy(invalid_slices,
-                            temperature=args.temperature, topk=args.top_k)
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, top_k=args.top_k)
 
     def process(raw_text):
         if args.with_id:
diff --git a/inference_cogview2.py b/inference_cogview2.py
index 75ad1c6c269d4bd0eb8eb03a65e1b05f45d78c3b..c69e589d71009b40d3620f5052086dbb9509992b 100644
--- a/inference_cogview2.py
+++ b/inference_cogview2.py
@@ -42,8 +42,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1]' if not args.full_query else '{}'
    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy0 = BaseStrategy(invalid_slices,
-                             temperature=args.temperature, topk=args.top_k)
+    strategy0 = BaseStrategy(invalid_slices,
+                             temperature=args.temperature, top_k=args.top_k)
     strategy1 = IterativeEntfilterStrategy(invalid_slices,
         temperature=args.temperature, topk=10) # temperature not used
     tr = transforms.Compose([
diff --git a/inference_glm.py b/inference_glm.py
index cfba93f1fc1715be50b8280ce4ae96ac9b472df4..26d247104888abc71da184169371e8fbd5942bec 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -20,6 +20,7 @@ from arguments import get_args
 from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
 from generation.glm_sampling import filling_sequence_glm
+from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 
 
 def read_context(tokenizer, args, output=None):
@@ -130,8 +131,14 @@
                 position = position_ids[0, mask_position].item()
             else:
                 position = mask_position
-            new_tokens, mems = filling_sequence_glm(model, tokenizer, position, args, mems=mems,
-                                                    end_tokens=end_tokens)
+            if args.num_beams > 1:
+                strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
+                                              length_penalty=args.length_penalty, end_tokens=end_tokens)
+            else:
+                strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                        end_tokens=end_tokens)
+            new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
+                                                    end_tokens=end_tokens)
             tokens = torch.cat((tokens, new_tokens), dim=1)
             output_tokens_list = tokens.view(-1).contiguous()
             if mpu.get_model_parallel_rank() == 0:
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index 30e6d9eb1a367b61acb782012cd90889fb1ef2f6..c3c99c99f01d8aaa10b973e0795ff8d1ce53867c 100644
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -22,6 +22,8 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
       --mode inference \
       --model-parallel-size $MPSIZE \
       $MODEL_ARGS \
+      --num-beams 4 \
+      --length-penalty 0.7 \
      --fp16 \
       --out-seq-length $MAXSEQLEN \
       --temperature $TEMP \