diff --git a/generation/glm_sampling.py b/generation/glm_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..704b54d6ed696aca4f58868a994bfccf0bd52d6a
--- /dev/null
+++ b/generation/glm_sampling.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn.functional as F
+import mpu
+from .autoregressive_sampling import update_mems
+from .sampling_strategies.beam_search_strategy import BeamSearchScorer
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D (assumes batch size 1)
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=None, end_tokens=None, device='cuda'):
+    tokens = torch.full((1, 1), tokenizer.get_command('sop').Id, device=device, dtype=torch.long)
+    counter = 0
+    if mems is None:
+        mems = []
+    # if end_tokens is None:
+    #     end_tokens = [tokenizer.get_command('eos').Id]
+    while counter < args.out_seq_length:
+        last_beam_num = tokens.size(0)
+        if args.block_lm:
+            if args.no_block_position:
+                position_ids = torch.full((last_beam_num, 1), mask_position + counter, device=device, dtype=torch.long)
+            else:
+                position_ids = torch.ones(last_beam_num, 2, 1, device=device, dtype=torch.long)
+                position_ids[:, 0] = mask_position
+                position_ids[:, 1] = counter + 1
+            attention_mask = torch.ones(1, 1, device=device, dtype=torch.float)
+        else:
+            position_ids = torch.full((last_beam_num, 1), mask_position + counter - 1, device=device, dtype=torch.long)
+            attention_mask = torch.ones(last_beam_num, 1, 1, args.mem_length + 1, device=device, dtype=torch.float)
+        if args.fp16:
+            attention_mask = attention_mask.half()
+        last_token = tokens[:, -1:]
+        logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
+        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
+        next_token_logits = logits[:, -1]
+        tokens, mems = strategy.forward(next_token_logits, tokens, mems)
+        if strategy.is_done:
+            break
+        # else:
+        #     next_token_logits /= args.temperature
+        #     next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
+        #     log_probs = F.softmax(next_token_logits, dim=-1)
+        #     prev = torch.multinomial(log_probs, num_samples=1)[0]
+        #     is_end = prev.item() in end_tokens
+        #     if is_end:
+        #         break
+        #     prev = prev.view(1, 1)
+        #     tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
+        counter += 1
+    tokens, mems = strategy.finalize(tokens, mems)
+    return tokens, mems
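For reference, `top_k_logits` implements the usual two-stage filter: keep only the `k` highest logits, then (for nucleus sampling) drop the lowest-scoring tail whose cumulative probability exceeds `top_p`; note that the `top_p` branch flattens the tensor to 1D, so it assumes a batch of one. A minimal standalone sketch of the top-k branch, with toy values chosen purely for illustration:

```python
import torch
import torch.nn.functional as F

# Toy logits for a 5-token vocabulary, batch size 1 (values are arbitrary).
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

# Same idea as the top_k branch above: mask everything below the k-th best logit.
top_k = 2
kth_best = torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(logits < kth_best, -float('Inf'))

print(F.softmax(filtered, dim=-1))
# tensor([[0.7311, 0.2689, 0.0000, 0.0000, 0.0000]]) -- only the top-2 tokens keep probability mass
```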
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ba2d62dc45f9ef528888510180d8f8f919f131
--- /dev/null
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -0,0 +1,483 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   beam_search_strategy.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import torch
+import torch.nn.functional as F
+from abc import ABC, abstractmethod
+from collections import UserDict
+from typing import Optional, Tuple, List, Iterable, Union
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D (assumes batch size 1)
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+class BeamScorer(ABC):
+    """
+    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
+    :meth:`~transformers.PretrainedModel.beam_sample`.
+    """
+
+    @abstractmethod
+    def process(
+        self,
+        input_ids: torch.LongTensor,
+        next_scores: torch.FloatTensor,
+        next_tokens: torch.LongTensor,
+        next_indices: torch.LongTensor,
+        **kwargs
+    ) -> Tuple[torch.Tensor]:
+        raise NotImplementedError("This is an abstract method.")
+
+    @abstractmethod
+    def finalize(
+        self,
+        input_ids: torch.LongTensor,
+        next_scores: torch.FloatTensor,
+        next_tokens: torch.LongTensor,
+        next_indices: torch.LongTensor,
+        **kwargs
+    ) -> torch.LongTensor:
+        raise NotImplementedError("This is an abstract method.")
+
+
+class BeamSearchScorer(BeamScorer):
+    r"""
+    :class:`transformers.BeamScorer` implementing standard beam search decoding.
+
+    Adapted in part from `Facebook's XLM beam search code
+    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
+
+    Args:
+        batch_size (:obj:`int`):
+            Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
+        max_length (:obj:`int`):
+            The maximum length of the sequence to be generated.
+        num_beams (:obj:`int`):
+            Number of beams for beam search.
+        device (:obj:`torch.device`):
+            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
+            :obj:`BeamSearchScorer` will be allocated.
+        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
+            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
+            sequences.
+        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
+        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
+            The number of beam hypotheses that shall be returned upon calling
+            :meth:`~transformer.BeamSearchScorer.finalize`.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        max_length: int,
+        num_beams: int,
+        device: Union[torch.device, str],
+        length_penalty: Optional[float] = 1.0,
+        do_early_stopping: Optional[bool] = False,
+        num_beam_hyps_to_keep: Optional[int] = 1,
+    ):
+        self.max_length = max_length
+        self.num_beams = num_beams
+        self.device = device
+        self.length_penalty = length_penalty
+        self.do_early_stopping = do_early_stopping
+        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
+
+        self._is_init = False
+        self._beam_hyps = [
+            BeamHypotheses(
+                num_beams=self.num_beams,
+                max_length=self.max_length,
+                length_penalty=self.length_penalty,
+                early_stopping=self.do_early_stopping,
+            )
+            for _ in range(batch_size)
+        ]
+        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+
+        # if not isinstance(num_beams, int) or num_beams <= 1:
+        #     raise ValueError(
+        #         f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
+        #     )
+
+    @property
+    def is_done(self) -> bool:
+        return self._done.all()
+
+    def process(
+        self,
+        input_ids: torch.LongTensor,
+        next_scores: torch.FloatTensor,
+        next_tokens: torch.LongTensor,
+        next_indices: torch.LongTensor,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        mems=None
+    ) -> Tuple[torch.Tensor]:
+        cur_len = input_ids.shape[-1]
+        batch_size = len(self._beam_hyps)
+        assert batch_size == (input_ids.shape[0] // self.num_beams)
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        device = next_scores.device
+        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
+        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
+        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
+
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                assert (
+                    len(beam_hyp) >= self.num_beams
+                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
+                assert (
+                    eos_token_id is not None and pad_token_id is not None
+                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                # pad the batch
+                next_beam_scores[batch_idx, :] = 0
+                next_beam_tokens[batch_idx, :] = pad_token_id
+                next_beam_indices[batch_idx, :] = 0
+                continue
+
+            # next tokens for this sentence
+            beam_idx = 0
+            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
+                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
+            ):
+                batch_beam_idx = batch_idx * self.num_beams + next_index
+                # add to generated hypotheses if end of sentence
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
+                    # if beam_token does not belong to top num_beams tokens, it should not be added
+                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
+                    if is_beam_token_worse_than_top_num_beams:
+                        continue
+                    beam_hyp.add(
+                        input_ids[batch_beam_idx].clone(),
+                        next_score.item(),
+                        mems=[mem[[next_index.item()]] for mem in mems] if mems else None
+                    )
+                else:
+                    # add next predicted token since it is not eos_token
+                    next_beam_scores[batch_idx, beam_idx] = next_score
+                    next_beam_tokens[batch_idx, beam_idx] = next_token
+                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
+                    beam_idx += 1
+
+                # once the beam for next step is full, don't add more tokens to it.
+                if beam_idx == self.num_beams:
+                    break
+
+            if beam_idx < self.num_beams:
+                raise ValueError(
+                    f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+                )
+
+            # Check if we are done so that we can save a pad step if all(done)
+            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+                next_scores[batch_idx].max().item(), cur_len
+            )
+
+        return UserDict(
+            {
+                "next_beam_scores": next_beam_scores.view(-1),
+                "next_beam_tokens": next_beam_tokens.view(-1),
+                "next_beam_indices": next_beam_indices.view(-1),
+            }
+        )
+
+    def finalize(
+        self,
+        input_ids: torch.LongTensor,
+        final_beam_scores: torch.FloatTensor,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        mems=None
+    ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
+        batch_size = len(self._beam_hyps)
+
+        # finalize all open beam hypotheses and add to generated hypotheses
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                continue
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(self.num_beams):
+                batch_beam_idx = batch_idx * self.num_beams + beam_id
+                final_score = final_beam_scores[batch_beam_idx].item()
+                final_tokens = input_ids[batch_beam_idx]
+                beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None)
+
+        # select the best hypotheses
+        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
+        best = []
+
+        # retrieve best hypotheses
+        for i, beam_hyp in enumerate(self._beam_hyps):
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+            for j in range(self.num_beam_hyps_to_keep):
+                score, best_hyp, mems = sorted_hyps.pop()
+                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+                best.append((best_hyp, mems, score))
+
+        # prepare for adding eos
+        sent_max_len = min(sent_lengths.max().item(), self.max_length)
+        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep)
+        # shorter batches are padded if needed
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            decoded.fill_(pad_token_id)
+
+        # fill with hypotheses and eos_token_id if the latter fits in
+        mems = []
+        for i, (hypo, mem, score) in enumerate(best):
+            scores[i] = score
+            decoded[i, : sent_lengths[i]] = hypo
+            if sent_lengths[i] < sent_max_len:
+                decoded[i, sent_lengths[i]] = eos_token_id
+            mems.append(mem)
+        mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None
+        return decoded, mems, scores
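To make the `process`/`finalize` contract above concrete, here is a small driver, a sketch only: it assumes the repository root is on `PYTHONPATH` so the new module can be imported, and it feeds hand-made candidate scores for one sentence decoded with two beams. Callers such as `BeamSearchStrategy.forward` pass `2 * num_beams` candidates per sentence so that candidates hitting an end token can be retired without starving the beam.

```python
import torch

# Sketch: assumes the repository root is importable as a package.
from generation.sampling_strategies.beam_search_strategy import BeamSearchScorer

num_beams = 2
scorer = BeamSearchScorer(batch_size=1, max_length=20, num_beams=num_beams, device='cpu')

# Two beams, three tokens generated so far (token ids are placeholders).
input_ids = torch.tensor([[5, 8, 3],
                          [5, 8, 7]])
# 2 * num_beams candidate continuations: log-probabilities, token ids, and source beams.
next_scores = torch.tensor([[-0.1, -0.7, -1.2, -2.0]])
next_tokens = torch.tensor([[11, 42, 11, 99]])
next_indices = torch.tensor([[0, 1, 1, 0]])

out = scorer.process(input_ids, next_scores, next_tokens, next_indices,
                     pad_token_id=0, eos_token_id=[99], mems=None)
print(out["next_beam_scores"], out["next_beam_tokens"], out["next_beam_indices"])
# The two best non-EOS candidates survive; an EOS candidate would instead be
# moved into the per-sentence BeamHypotheses heap.
```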
+ """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty) + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp, mems)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret + + +class LogitsProcessor(ABC): + """Abstract base class for all logit processors that can be applied during generation.""" + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """Torch method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class LogitsProcessorList(list): + """ + This class can be used to create a list of :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from + list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsProcessor` to the inputs. + """ + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (:obj:`int`): + The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length: int, eos_token_id: int): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len < self.min_length: + scores[:, self.eos_token_id] = -float("inf") + return scores + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. 
See `Fairseq + <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. + + Args: + ngram_size (:obj:`int`): + All ngrams of size :obj:`ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + def _calc_banned_ngram_tokens( + self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int + ) -> List[Iterable[int]]: + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < self.ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - self.ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + return banned_tokens + + +class BeamSearchStrategy: + def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda'): + self.num_beams = num_beams + self.max_length = max_length + self.length_penalty = length_penalty + self.end_tokens = end_tokens + self.beam_scorer = BeamSearchScorer( + batch_size=1, + max_length=max_length, + num_beams=num_beams, + device=device, + length_penalty=length_penalty, + do_early_stopping=False, + ) + self.beam_scores = torch.zeros(1, dtype=torch.float, device=device) + + @property + def is_done(self) -> bool: + return self.beam_scorer.is_done + + def forward(self, logits, tokens, mems): + last_beam_num = tokens.size(0) + next_token_scores = F.log_softmax(logits, dim=-1) + next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores) + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size) + + probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + # stateless + tokens = tokens.expand((self.num_beams, -1)) + beam_outputs = self.beam_scorer.process( + tokens, + next_token_scores, + next_tokens, + next_indices, + eos_token_id=self.end_tokens, + mems=mems + 
+
+
+class BeamSearchStrategy:
+    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda'):
+        self.num_beams = num_beams
+        self.max_length = max_length
+        self.length_penalty = length_penalty
+        self.end_tokens = end_tokens
+        self.beam_scorer = BeamSearchScorer(
+            batch_size=1,
+            max_length=max_length,
+            num_beams=num_beams,
+            device=device,
+            length_penalty=length_penalty,
+            do_early_stopping=False,
+        )
+        self.beam_scores = torch.zeros(1, dtype=torch.float, device=device)
+
+    @property
+    def is_done(self) -> bool:
+        return self.beam_scorer.is_done
+
+    def forward(self, logits, tokens, mems):
+        last_beam_num = tokens.size(0)
+        next_token_scores = F.log_softmax(logits, dim=-1)
+        next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
+        vocab_size = next_token_scores.shape[-1]
+        next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+
+        probs = F.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)
+        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = torch.gather(next_tokens, -1, _indices)
+
+        next_indices = next_tokens // vocab_size
+        next_tokens = next_tokens % vocab_size
+        # stateless
+        tokens = tokens.expand((self.num_beams, -1))
+        beam_outputs = self.beam_scorer.process(
+            tokens,
+            next_token_scores,
+            next_tokens,
+            next_indices,
+            eos_token_id=self.end_tokens,
+            mems=mems
+        )
+        self.beam_scores = beam_outputs["next_beam_scores"]
+        beam_next_tokens = beam_outputs["next_beam_tokens"]
+        beam_idx = beam_outputs["next_beam_indices"]
+        beam_next_tokens = beam_next_tokens.unsqueeze(-1)
+        tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
+        mems = [mem[beam_idx] for mem in mems] if mems else None
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        # TODO check the eos token here
+        tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
+                                                         eos_token_id=self.end_tokens[0],
+                                                         mems=mems)
+        return tokens, mems
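Putting the two new pieces together: `filling_sequence_glm` owns the token-by-token loop and the memory cache, while the strategy object decides which beams survive and when to stop. The sketch below shows one plausible wiring; `model`, `tokenizer` and `args` stand for objects prepared as in `inference_glm.py`, and the strategy is built from the same arguments the old `sample_sequence` used, so treat the exact parameter mapping as an assumption rather than the definitive call pattern.

```python
# Sketch of the intended wiring; model/tokenizer/args are assumed to come from
# inference_glm.py (GLMModel, prepare_tokenizer, get_args).
from generation.glm_sampling import filling_sequence_glm
from generation.sampling_strategies.beam_search_strategy import BeamSearchStrategy


def fill_one_blank(model, tokenizer, mask_position, args):
    end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
    strategy = BeamSearchStrategy(num_beams=args.num_beams,
                                  max_length=args.out_seq_length,
                                  length_penalty=args.length_penalty,
                                  end_tokens=end_tokens)
    # filling_sequence_glm drives the model one token at a time; the strategy's
    # forward()/is_done/finalize() hooks decide which candidates are kept.
    new_tokens, mems = filling_sequence_glm(model, tokenizer, mask_position, strategy,
                                            args, end_tokens=end_tokens)
    return new_tokens
```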
diff --git a/inference_glm.py b/inference_glm.py
index 0456cc550df2065cc954d9509fd3f370fa1aa26d..cfba93f1fc1715be50b8280ce4ae96ac9b472df4 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -19,10 +19,7 @@ import mpu
 from arguments import get_args
 from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
-from tokenization import get_tokenizer
-from generation.sampling_strategies import BaseStrategy
-from generation.autoregressive_sampling import update_mems
-from generation.utils import timed_name, save_multiple_images, generate_continually
+from generation.glm_sampling import filling_sequence_glm
 
 
 def read_context(tokenizer, args, output=None):
@@ -101,127 +98,6 @@
     return tokens, attention_mask, position_ids
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
-def sample_sequence(model, tokenizer, context_tokens, context_length, args, mems=None, end_tokens=None):
-    tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
-    counter = 0
-    if mems is None:
-        mems = []
-    if end_tokens is None:
-        end_tokens = [args.eod_token]
-    if args.num_beams > 1:
-        beam_scorer = BeamSearchScorer(
-            batch_size=1,
-            max_length=args.out_seq_length,
-            num_beams=args.num_beams,
-            device=context_tokens.device,
-            length_penalty=args.length_penalty,
-            do_early_stopping=False,
-        )
-        beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
-    last_beam_num = 1
-    while counter < args.out_seq_length:
-        if args.block_lm:
-            if args.no_block_position:
-                position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
-            else:
-                position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
-                position_ids[:, 0] = context_length
-                position_ids[:, 1] = counter + 1
-            attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
-                                                     dtype=torch.long)
-        else:
-            position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
-            attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
-                                                     device=context_tokens.device, dtype=torch.float)
-        last_token = tokens[:, -1:]
-        next_token_logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
-        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
-        next_token_logits = next_token_logits[:, -1]
-        if args.num_beams > 1:
-            next_token_scores = F.log_softmax(next_token_logits, dim=-1)
-            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
-
-            probs = F.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = next_tokens // vocab_size
-            next_tokens = next_tokens % vocab_size
-            # stateless
-            tokens = tokens.expand((args.num_beams, -1))
-            beam_outputs = beam_scorer.process(
-                tokens,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                eos_token_id=end_tokens,
-                mems=mems
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-            beam_next_tokens = beam_next_tokens.unsqueeze(-1)
-            tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
-            mems = [mem[beam_idx] for mem in mems] if mems else None
-            if beam_scorer.is_done:
-                break
-            last_beam_num = args.num_beams
-        else:
-            next_token_logits /= args.temperature
-            next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
-            log_probs = F.softmax(next_token_logits, dim=-1)
-            prev = torch.multinomial(log_probs, num_samples=1)[0]
-            is_end = prev.item() in end_tokens
-            if is_end:
-                break
-            prev = prev.view(1, 1)
-            tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
-        counter += 1
-        if not args.block_lm and mpu.get_model_parallel_rank() == 0 and counter % 16 == 0:
-            output_tokens_list = tokens.view(-1).contiguous()
-            decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
-            if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0 or is_end):
-                os.system('clear')
-                trim_decode_tokens = decode_tokens
-                print(trim_decode_tokens, flush=True)
-    if args.num_beams > 1:
-        tokens, mems = beam_scorer.finalize(tokens, beam_scores, next_tokens, next_indices, eos_token_id=args.eod_token,
-                                            mems=mems)
-    return torch.cat((context_tokens, tokens), dim=1), mems
-
-
 def generate_samples(model, tokenizer, args):
     model.eval()
     output_path = "./samples"
@@ -240,7 +116,7 @@
         tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
         mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
         mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-        end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
+        end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
         mask_positions = []
         for token in mask_tokens:
             mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
@@ -254,8 +130,9 @@
             position = position_ids[0, mask_position].item()
         else:
            position = mask_position
-        tokens, mems = sample_sequence(model, tokenizer, tokens, position, args, mems=mems,
+        new_tokens, mems = filling_sequence_glm(model, tokenizer, position, args, mems=mems,
                                        end_tokens=end_tokens)
+        tokens = torch.cat((tokens, new_tokens), dim=1)
         output_tokens_list = tokens.view(-1).contiguous()
         if mpu.get_model_parallel_rank() == 0:
             os.system('clear')
@@ -272,7 +149,6 @@
 def main(args):
     initialize_distributed(args)
     tokenizer = prepare_tokenizer(args)
-    args.eod_token = tokenizer.get_command('eos').Id
     # build model
     model = GLMModel(args)
     if args.fp16:
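The mask lookup kept as context above works by direct tensor comparison against the prompt; on a toy tensor (with `5` standing in for a mask token id) the same pattern yields the positions that the generation loop then fills in one by one:

```python
import torch

# Toy version of the mask lookup in generate_samples: 5 stands in for a mask token id.
mask_id = 5
context_tokens_tensor = torch.tensor([9, 2, 5, 7, 5, 3])

mask_positions = (context_tokens_tensor == mask_id).nonzero(as_tuple=True)[0].tolist()
print(mask_positions)  # [2, 4]
```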
diff --git a/move_weights_glm.py b/move_weights_glm.py
deleted file mode 100644
index 6d3755d22f33f6137add7a4dcee26d31fd106e50..0000000000000000000000000000000000000000
--- a/move_weights_glm.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import sys
-import os
-import torch
-import copy
-
-checkpoint = sys.argv[1]
-target_path = sys.argv[2]
-
-assert os.path.isdir(checkpoint)
-iteration_file = os.path.join(checkpoint, 'latest_checkpointed_iteration.txt')
-if os.path.exists(iteration_file):
-    with open(iteration_file) as fin:
-        iteration = int(fin.read().strip())
-    checkpoint = os.path.join(checkpoint, str(iteration))
-else:
-    iteration = None
-
-os.makedirs(target_path, exist_ok=True)
-if iteration is not None:
-    with open(os.path.join(target_path, "latest"), "w") as output:
-        output.write(str(iteration))
-    target_path = os.path.join(target_path, str(iteration))
-    os.makedirs(target_path, exist_ok=True)
-
-
-filenames = os.listdir(checkpoint)
-filenames = [filename for filename in filenames if filename.startswith("mp_rank_")]
-filenames = sorted(filenames,
-                   key=lambda x: int(x.split('_')[2]))
-filenames = [os.path.join(checkpoint, x) for x in filenames]
-
-for filename in filenames:
-    data = torch.load(filename)
-    state_dict = data['module']
-    state_dict['transformer.word_embeddings.weight'] = state_dict['word_embeddings.weight']
-    del state_dict['word_embeddings.weight']
-    # print(f"Target path: {os.path.join(target_path, os.path.basename(filename))}")
-    torch.save(data, os.path.join(target_path, os.path.basename(filename)))