diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py
index cccc7c79955c5231f47e6b9ace7e556b83ddb511..541d9091d565e27685c64b68775c271842252438 100644
--- a/generation/autoregressive_sampling.py
+++ b/generation/autoregressive_sampling.py
@@ -14,7 +14,7 @@ import random
 import torch
 from .sampling_strategies import BaseStrategy
 
-def get_masks_and_position_ids(seq):
+def get_masks_and_position_ids_default(seq):
     tokens = seq.unsqueeze(0)
 
     attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
@@ -36,7 +36,6 @@ def update_mems(hiddens, mems, max_memory_length):
     memory_length = mems.shape[2] if mems is not None else 0
     query_length = hiddens.shape[2]
     new_memory_length = min(max_memory_length, memory_length + query_length)
-    new_mems = []
     with torch.no_grad():
         if new_memory_length <= query_length:
             return hiddens[:, :, -new_memory_length:]
@@ -55,10 +54,16 @@ def filling_sequence(
         batch_size,
         strategy=BaseStrategy(),
         max_memory_length=100000,
-        log_attention_weights=None
+        log_attention_weights=None,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None
         ):
     '''
         seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be the first mems.shape[2] tokens of context_tokens.
+            mems are the first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
     '''
 
     assert len(seq.shape) == 1
@@ -72,9 +77,7 @@ def filling_sequence(
     attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
     # initialize generation
     counter = context_length - 1 # Last fixed index is ``counter''
-    index = 0 # Next forward starting index, also the length of cache.
-    mems = None # mems are the first-level citizens here, but we don't assume what is memorized.
-
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
     # step-by-step generation
     while counter < len(seq) - 1:
         # Now, we want to generate seq[counter + 1],
@@ -83,7 +86,7 @@
         if seq[counter + 1] >= 0: # provided
             tokens = torch.cat(
                 (
-                    tokens,
+                    tokens, 
                     seq[counter+1: counter+2].expand(tokens.shape[0], 1)
                 ), dim=1
             )
@@ -92,13 +95,16 @@
 
         # forward
         if log_attention_weights is not None:
-            model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
-        kw_tensors = {'mems': mems} if mems is not None else {}
+            log_attention_weights_part = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
+        else:
+            log_attention_weights_part = None
+
         logits, *mem_kv = model(
             tokens[:, index:],
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
-            **kw_tensors # if no mems, cannot pass
+            mems=mems,
+            log_attention_weights=log_attention_weights_part
         )
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
@@ -107,6 +113,6 @@
         logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
         tokens = tokens.expand(batch_size, -1)
         tokens, mems = strategy.forward(logits, tokens, mems)
-
-    model.log_attention_weights = None
-    return tokens
\ No newline at end of file
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index 20339941e16802cb60c88613a5bafee76db75b74..a5268e5ab4db13e6c8ede944f81291505560e4c6 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -15,7 +15,7 @@ import torch
 import torch.nn.functional as F
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
     # This function has been mostly taken from huggingface conversational ai code at
     # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
 
@@ -69,10 +69,11 @@ class BaseStrategy:
         logits = top_k_logits(logits, self.topk, self.top_p)
         probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
         pred = torch.multinomial(probs, num_samples=1)
-        if pred.item() in self.end_tokens:
+        if pred.numel() == 1 and pred.item() in self.end_tokens:
             self._is_done = True
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
 
     def finalize(self, tokens, mems):
+        self._is_done = False
         return tokens, mems
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
index ec205e071590c1683d2e96bbac80ff57db31d2df..b30926294ce26e4a890b51b89896b82de88b285a 100644
--- a/generation/sampling_strategies/beam_search_strategy.py
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -9,459 +9,99 @@
 # here put the import lib
 import torch
 import torch.nn.functional as F
-from abc import ABC, abstractmethod
-from collections import UserDict
-from typing import Optional, Tuple, List, Iterable, Union
-
-
-class BeamScorer(ABC):
-    """
-    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
-    :meth:`~transformers.PretrainedModel.beam_sample`.
- """ - - @abstractmethod - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - **kwargs - ) -> Tuple[torch.Tensor]: - raise NotImplementedError("This is an abstract method.") - - @abstractmethod - def finalize( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - **kwargs - ) -> torch.LongTensor: - raise NotImplementedError("This is an abstract method.") - - -class BeamSearchScorer(BeamScorer): - r""" - :class:`transformers.BeamScorer` implementing standard beam search decoding. - - Adapted in part from `Facebook's XLM beam search code - <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. - - Args: - batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. - max_length (:obj:`int`): - The maximum length of the sequence to be generated. - num_beams (:obj:`int`): - Number of beams for beam search. - device (:obj:`torch.device`): - Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of - :obj:`BeamSearchScorer` will be allocated. - length_penalty (:obj:`float`, `optional`, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. - do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. - num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): - The number of beam hypotheses that shall be returned upon calling - :meth:`~transformer.BeamSearchScorer.finalize`. - """ - - def __init__( - self, - batch_size: int, - max_length: int, - num_beams: int, - device: Union[torch.device, str], - length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - ): - self.max_length = max_length - self.num_beams = num_beams - self.device = device - self.length_penalty = length_penalty - self.do_early_stopping = do_early_stopping - self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - - self._is_init = False - self._beam_hyps = [ - BeamHypotheses( - num_beams=self.num_beams, - max_length=self.max_length, - length_penalty=self.length_penalty, - early_stopping=self.do_early_stopping, - ) - for _ in range(batch_size) - ] - self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) - - # if not isinstance(num_beams, int) or num_beams <= 1: - # raise ValueError( - # f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." 
- # ) - - @property - def is_done(self) -> bool: - return self._done.all() - - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - mems=None - ) -> Tuple[torch.Tensor]: - cur_len = input_ids.shape[-1] - batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.num_beams) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - device = next_scores.device - next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) - - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - assert ( - len(beam_hyp) >= self.num_beams - ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) - assert ( - eos_token_id is not None and pad_token_id is not None - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" - # pad the batch - next_beam_scores[batch_idx, :] = 0 - next_beam_tokens[batch_idx, :] = pad_token_id - next_beam_indices[batch_idx, :] = 0 - continue - - # next tokens for this sentence - beam_idx = 0 - for beam_token_rank, (next_token, next_score, next_index) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) - ): - batch_beam_idx = batch_idx * self.num_beams + next_index - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() in eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams - if is_beam_token_worse_than_top_num_beams: - continue - beam_hyp.add( - input_ids[batch_beam_idx].clone(), - next_score.item(), - mems=[mem[[next_index.item()]] for mem in mems] if mems else None - ) - else: - # add next predicted token since it is not eos_token - next_beam_scores[batch_idx, beam_idx] = next_score - next_beam_tokens[batch_idx, beam_idx] = next_token - next_beam_indices[batch_idx, beam_idx] = batch_beam_idx - beam_idx += 1 - - # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.num_beams: - break - - if beam_idx < self.num_beams: - raise ValueError( - f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." 
- ) - - # Check if we are done so that we can save a pad step if all(done) - self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( - next_scores[batch_idx].max().item(), cur_len - ) - - return UserDict( - { - "next_beam_scores": next_beam_scores.view(-1), - "next_beam_tokens": next_beam_tokens.view(-1), - "next_beam_indices": next_beam_indices.view(-1), - } - ) - - def finalize( - self, - input_ids: torch.LongTensor, - final_beam_scores: torch.FloatTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - mems=None - ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]: - batch_size = len(self._beam_hyps) - # finalize all open beam hypotheses and add to generated hypotheses - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - continue - - # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() - final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None) - - # select the best hypotheses - sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) - best = [] - - # retrieve best hypotheses - for i, beam_hyp in enumerate(self._beam_hyps): - sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) - for j in range(self.num_beam_hyps_to_keep): - score, best_hyp, mems = sorted_hyps.pop() - sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - best.append((best_hyp, mems, score)) - - # prepare for adding eos - sent_max_len = min(sent_lengths.max().item(), self.max_length) - decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) - scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep) - # shorter batches are padded if needed - if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`pad_token_id` has to be defined" - decoded.fill_(pad_token_id) - - # fill with hypotheses and eos_token_id if the latter fits in - mems = [] - for i, (hypo, mem, score) in enumerate(best): - scores[i] = score - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < sent_max_len: - decoded[i, sent_lengths[i]] = eos_token_id - mems.append(mem) - mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None - return decoded, mems, scores - - -class BeamHypotheses: - def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): - """ - Initialize n-best list of hypotheses. - """ - self.max_length = max_length - 1 # ignoring bos_token - self.length_penalty = length_penalty - self.early_stopping = early_stopping - self.num_beams = num_beams - self.beams = [] - self.worst_score = 1e9 - - def __len__(self): - """ - Number of hypotheses in the list. - """ - return len(self.beams) - - def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): - """ - Add a new hypothesis to the list. 
- """ - score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty) - if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp, mems)) - if len(self) > self.num_beams: - sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) - del self.beams[sorted_next_scores[0][1]] - self.worst_score = sorted_next_scores[1][0] - else: - self.worst_score = min(score, self.worst_score) - - def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: - """ - If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst - one in the heap, then we are done with this sentence. - """ - - if len(self) < self.num_beams: - return False - elif self.early_stopping: - return True - else: - cur_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= cur_score - return ret - - -class LogitsProcessor(ABC): - """Abstract base class for all logit processors that can be applied during generation.""" - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - """Torch method for processing logits.""" - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." - ) - - -class LogitsProcessorList(list): - """ - This class can be used to create a list of :class:`~transformers.LogitsProcessor` or - :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from - list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or - :class:`~transformers.LogitsProcessor` to the inputs. - """ - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - for processor in self: - scores = processor(input_ids, scores) - return scores - - -class MinLengthLogitsProcessor(LogitsProcessor): - r""" - :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. - - Args: - min_length (:obj:`int`): - The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. - eos_token_id (:obj:`int`): - The id of the `end-of-sequence` token. - """ - - def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]): - if not isinstance(min_length, int) or min_length < 0: - raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") - if isinstance(eos_token_ids, int): - eos_token_ids = [eos_token_ids] - for eos_token_id in eos_token_ids: - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") - - self.min_length = min_length - self.eos_token_ids = eos_token_ids - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cur_len = input_ids.shape[-1] - if cur_len < self.min_length: - for eos_token_id in self.eos_token_ids: - scores[:, eos_token_id] = -float("inf") - return scores - - -class NoRepeatNGramLogitsProcessor(LogitsProcessor): - r""" - :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq - <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. - - Args: - ngram_size (:obj:`int`): - All ngrams of size :obj:`ngram_size` can only occur once. 
-    """
-
-    def __init__(self, ngram_size: int):
-        if not isinstance(ngram_size, int) or ngram_size <= 0:
-            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
-        self.ngram_size = ngram_size
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        num_batch_hypotheses = scores.shape[0]
-        cur_len = input_ids.shape[-1]
-        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
-
-        for i, banned_tokens in enumerate(banned_batch_tokens):
-            scores[i, banned_tokens] = -float("inf")
-
-        return scores
-
-    def _calc_banned_ngram_tokens(
-        self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
-    ) -> List[Iterable[int]]:
-        """Copied from fairseq for no_repeat_ngram in beam_search"""
-        if cur_len + 1 < self.ngram_size:
-            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
-            return [[] for _ in range(num_hypos)]
-        generated_ngrams = [{} for _ in range(num_hypos)]
-        for idx in range(num_hypos):
-            gen_tokens = prev_input_ids[idx].tolist()
-            generated_ngram = generated_ngrams[idx]
-            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
-                prev_ngram_tuple = tuple(ngram[:-1])
-                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
-
-        def _get_generated_ngrams(hypo_idx):
-            # Before decoding the next token, prevent decoding of ngrams that have already appeared
-            start_idx = cur_len + 1 - self.ngram_size
-            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
-            return generated_ngrams[hypo_idx].get(ngram_idx, [])
-
-        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
-        return banned_tokens
-
 class BeamSearchStrategy:
-    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0,
-                 min_tgt_length=0):
+    def __init__(self, num_beams, length_penalty=1., return_only_end=False,
+                 end_tokens=[], invalid_slices=[], no_repeat_ngram_size=0, min_tgt_length=0):
         self.num_beams = num_beams
-        self.max_length = max_length
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
-        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.ngram = no_repeat_ngram_size
         self.min_tgt_length = min_tgt_length
-        self.processors = LogitsProcessorList()
-        if min_tgt_length > 0:
-            processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens)
-            self.processors.append(processor)
-        if no_repeat_ngram_size > 0:
-            processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
-            self.processors.append(processor)
-        self.beam_scorer = BeamSearchScorer(
-            batch_size=1,
-            max_length=max_length,
-            num_beams=num_beams,
-            device=device,
-            length_penalty=length_penalty,
-            do_early_stopping=False,
-        )
-        self.beam_scores = torch.zeros(1, dtype=torch.float, device=device)
-
-    @property
-    def is_done(self) -> bool:
-        return self.beam_scorer.is_done
+        self.invalid_slices = invalid_slices
+        self.return_only_end = return_only_end
+        self._init_cache()
+
+    def _init_cache(self):
+        self.end_beams = [] # list of LongTensors
+        self.end_beams_penalized_scores = [] # list of float scores
+        self.cached_beam_scores = 0 # [batch_size]
+        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.is_done = False
+
+    def _add_end_beams(self, score, beam):
+        score = score / ((5. + len(beam)) / 6) ** self.length_penalty # Magic number for OpenNMT
+        for i in range(len(self.end_beams), -1, -1):
+            if i == 0 or score < self.end_beams_penalized_scores[i-1]:
+                break
+        self.end_beams.insert(i, beam)
+        self.end_beams_penalized_scores.insert(i, score)
+
+        self.end_beams = self.end_beams[:self.num_beams]
+        self.end_beams_penalized_scores = self.end_beams_penalized_scores[:self.num_beams]
 
     def forward(self, logits, tokens, mems):
-        last_beam_num = tokens.size(0)
-        logits = self.processors(tokens, logits)
-        next_token_scores = F.log_softmax(logits, dim=-1)
-        next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
-        vocab_size = next_token_scores.shape[-1]
-        next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+        batch_size, vocab_size = logits.shape
+        seq_len = tokens.shape[-1]
+        logits = logits.float()
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+        if self.min_tgt_length > seq_len:
+            for end_token in self.end_tokens:
+                logits[..., end_token] = -65504
+        if self.ngram > 0 and seq_len > self.ngram:
+            for i in range(batch_size):
+                ngram_prefix = tokens[i, -(self.ngram-1):].tolist() # TODO ngram=1
+                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
+                    logits[i, banned_index] = -65504
+
+        next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
+        next_token_scores = next_token_scores + (self.cached_beam_scores[:, None] if isinstance(self.cached_beam_scores, torch.Tensor) else self.cached_beam_scores)
+
+        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+
+        probs = F.softmax(next_token_scores, dim=0)
+        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) # [2*nb]
+        next_token_scores = next_token_scores[next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
+        next_tokens = next_tokens[_indices]
+
+        next_indices = next_tokens // vocab_size # batch idx
+        next_tokens = next_tokens % vocab_size
-        probs = F.softmax(next_token_scores, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)
-        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-        next_tokens = torch.gather(next_tokens, -1, _indices)
+        # select out end beams or continue beams
+        beam_continue = []
+        scores_continue = []
+        bans_continue = []
+        mems_continue = []
+        for i in range(len(next_tokens)):
+            beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
+            if int(next_tokens[i]) in self.end_tokens:
+                self._add_end_beams(next_token_scores[i], beam)
+            elif len(beam_continue) < batch_size:
+                beam_continue.append(beam)
+                mems_continue.append(mems[:, next_indices[i]])
+                # update caches
+                scores_continue.append(next_token_scores[i])
+                if self.ngram > 0:
+                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
+                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram-1):].tolist()) # TODO ngram=1
+                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
+                    bans_continue.append(bans)
+            else:
+                break
+        tokens = torch.stack(beam_continue)
+        mems = torch.stack(mems_continue, dim=1)
+        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
+        self.cached_beam_ngram_bans = bans_continue
-        next_indices = next_tokens // vocab_size
-        next_tokens = next_tokens % vocab_size
-        # stateless
-        tokens = tokens.expand((self.num_beams, -1))
-        beam_outputs = self.beam_scorer.process(
-            tokens,
-            next_token_scores,
-
next_tokens, - next_indices, - eos_token_id=self.end_tokens, - mems=mems - ) - self.beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - beam_next_tokens = beam_next_tokens.unsqueeze(-1) - tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1) - mems = [mem[beam_idx] for mem in mems] if mems else None + # TODO is_done return tokens, mems - def finalize(self, tokens, mems): - tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores, - eos_token_id=self.end_tokens[0], - mems=mems) - return tokens, mems + def finalize(self, tokens): + if not self.return_only_end: + for i in range(tokens.shape[0]): + self._add_end_beams(self.cached_beam_scores[i], tokens[i]) + ret = self.end_beams + self._init_cache() + return ret diff --git a/generation/sampling_strategies/beam_search_strategy_old.py b/generation/sampling_strategies/beam_search_strategy_old.py new file mode 100644 index 0000000000000000000000000000000000000000..aeab7989930250aa93b3231399152cd514b4162b --- /dev/null +++ b/generation/sampling_strategies/beam_search_strategy_old.py @@ -0,0 +1,467 @@ +# -*- encoding: utf-8 -*- +''' +@File : base_strategy.py +@Time : 2021/10/08 22:22:42 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import torch +import torch.nn.functional as F +from abc import ABC, abstractmethod +from collections import UserDict +from typing import Optional, Tuple, List, Iterable, Union + + +class BeamScorer(ABC): + """ + Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and + :meth:`~transformers.PretrainedModel.beam_sample`. + """ + + @abstractmethod + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs + ) -> Tuple[torch.Tensor]: + raise NotImplementedError("This is an abstract method.") + + @abstractmethod + def finalize( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs + ) -> torch.LongTensor: + raise NotImplementedError("This is an abstract method.") + + +class BeamSearchScorer(BeamScorer): + r""" + :class:`transformers.BeamScorer` implementing standard beam search decoding. + + Adapted in part from `Facebook's XLM beam search code + <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. + + Args: + batch_size (:obj:`int`): + Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + num_beams (:obj:`int`): + Number of beams for beam search. + device (:obj:`torch.device`): + Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of + :obj:`BeamSearchScorer` will be allocated. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. 
+ num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + :meth:`~transformer.BeamSearchScorer.finalize`. + """ + + def __init__( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: Union[torch.device, str], + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + max_length=self.max_length, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + # if not isinstance(num_beams, int) or num_beams <= 1: + # raise ValueError( + # f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + # ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + mems=None + ) -> Tuple[torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + assert batch_size == (input_ids.shape[0] // self.num_beams) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + device = next_scores.device + next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + assert ( + len(beam_hyp) >= self.num_beams + ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.num_beams + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() in eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + mems=[mem[[next_index.item()]] for mem in mems] if mems else None + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + 
next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.num_beams: + break + + if beam_idx < self.num_beams: + raise ValueError( + f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + mems=None + ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]: + batch_size = len(self._beam_hyps) + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + score, best_hyp, mems = sorted_hyps.pop() + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + best.append((best_hyp, mems, score)) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item(), self.max_length) + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + mems = [] + for i, (hypo, mem, score) in enumerate(best): + scores[i] = score + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < sent_max_len: + decoded[i, sent_lengths[i]] = eos_token_id + mems.append(mem) + mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None + return decoded, mems, scores + + +class BeamHypotheses: + def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. 
+ """ + return len(self.beams) + + def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty) + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp, mems)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret + + +class LogitsProcessor(ABC): + """Abstract base class for all logit processors that can be applied during generation.""" + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """Torch method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class LogitsProcessorList(list): + """ + This class can be used to create a list of :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from + list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsProcessor` to the inputs. + """ + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (:obj:`int`): + The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] + for eos_token_id in eos_token_ids: + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_ids = eos_token_ids + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len < self.min_length: + for eos_token_id in self.eos_token_ids: + scores[:, eos_token_id] = -float("inf") + return scores + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq + <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. 
+ + Args: + ngram_size (:obj:`int`): + All ngrams of size :obj:`ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + def _calc_banned_ngram_tokens( + self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int + ) -> List[Iterable[int]]: + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < self.ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - self.ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + return banned_tokens + + +class BeamSearchStrategy: + def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0, + min_tgt_length=0): + self.num_beams = num_beams + self.max_length = max_length + self.length_penalty = length_penalty + self.end_tokens = end_tokens + self.no_repeat_ngram_size = no_repeat_ngram_size + self.min_tgt_length = min_tgt_length + self.processors = LogitsProcessorList() + if min_tgt_length > 0: + processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens) + self.processors.append(processor) + if no_repeat_ngram_size > 0: + processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size) + self.processors.append(processor) + self.beam_scorer = BeamSearchScorer( + batch_size=1, + max_length=max_length, + num_beams=num_beams, + device=device, + length_penalty=length_penalty, + do_early_stopping=False, + ) + self.beam_scores = torch.zeros(1, dtype=torch.float, device=device) + + @property + def is_done(self) -> bool: + return self.beam_scorer.is_done + + def forward(self, logits, tokens, mems): + last_beam_num = tokens.size(0) + logits = self.processors(tokens, logits.float()) + next_token_scores = F.log_softmax(logits, dim=-1) + next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores) + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size) + + probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, 
dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + # stateless + tokens = tokens.expand((self.num_beams, -1)) + beam_outputs = self.beam_scorer.process( + tokens, + next_token_scores, + next_tokens, + next_indices, + eos_token_id=self.end_tokens, + mems=mems + ) + self.beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + beam_next_tokens = beam_next_tokens.unsqueeze(-1) + tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1) + mems = [mem[beam_idx] for mem in mems] if mems else None + return tokens, mems + + def finalize(self, tokens, mems): + tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores, + eos_token_id=self.end_tokens[0], + mems=mems) + return tokens, mems diff --git a/inference_glm.py b/inference_glm.py index 792efd5e030fa1564a7f52310b261844eb113c03..0e5e44d4a58a54a373140028994d1e5856d94bb9 100644 --- a/inference_glm.py +++ b/inference_glm.py @@ -1,12 +1,13 @@ # -*- encoding: utf-8 -*- ''' -@File : inference_cogview.py -@Time : 2021/10/09 19:41:58 +@File : inference_glm.py +@Time : 2021/10/22 19:41:58 @Author : Ming Ding @Contact : dm18@mail.tsinghua.edu.cn ''' # here put the import lib +from functools import partial import os import sys import random @@ -14,163 +15,138 @@ import time from datetime import datetime import torch import torch.nn.functional as F - +import argparse +import stat import mpu +from functools import partial from arguments import get_args from model.glm_model import GLMModel +from model.cached_autoregressive_model import CachedAutoregressiveMixin from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer -from generation.glm_sampling import filling_sequence_glm +from generation.autoregressive_sampling import filling_sequence from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy +from generation.utils import timed_name, generate_continually +def get_masks_and_position_ids_glm(seq, mask_position, context_length): + tokens = seq.unsqueeze(0) -def read_context(tokenizer, args, output=None): - terminate_runs, skip_run = 0, 0 - if mpu.get_model_parallel_rank() == 0: - while True: - raw_text = input("\nContext prompt (stop to exit) >>> ") - if not raw_text: - print('Prompt should not be empty!') - continue - if raw_text == "stop": - terminate_runs = 1 - break - generation_mask = '[gMASK]' if args.task_mask else '[MASK]' - if args.block_lm and 'MASK]' not in raw_text: - raw_text += ' ' + generation_mask - if output is not None: - output.write(raw_text) - context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization - if args.block_lm: - context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens - if not raw_text.endswith('MASK]'): - context_tokens = context_tokens + [tokenizer.get_command('eos').Id] - context_length = len(context_tokens) - - if context_length >= args.max_sequence_length: - print("\nContext length", context_length, - "\nPlease give smaller context than the window length!") - continue - break - else: - context_length = 0 - - terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) - torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group()) - terminate_runs = terminate_runs_tensor[0].item() - - if terminate_runs == 1: - return terminate_runs, None, None, None - - 
context_length_tensor = torch.cuda.LongTensor([context_length]) - - torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group()) - context_length = context_length_tensor[0].item() - if mpu.get_model_parallel_rank() == 0: - context_tokens_tensor = torch.cuda.LongTensor(context_tokens) - else: - context_tokens_tensor = torch.cuda.LongTensor([0] * context_length) - torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group()) - if mpu.get_model_parallel_rank() != 0: - raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist()) - return terminate_runs, raw_text, context_tokens_tensor, context_length + attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device) + attention_mask.tril_() + attention_mask.unsqueeze_(1) + position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long) + torch.arange(0, context_length, out=position_ids[0, :context_length]) + position_ids[0, context_length:] = mask_position + torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:]) -def get_batch(context_tokens, args): - tokens = context_tokens - tokens = tokens.view(1, -1).contiguous() - tokens = tokens.to('cuda') - - # Get the masks and postition ids. - if args.block_lm: - attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long) - if args.fp16: - attention_mask = attention_mask.half() - position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long) - if not args.no_block_position: - block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long) - position_ids = torch.stack((position_ids, block_position_ids), dim=0) - position_ids = position_ids.unsqueeze(0) - else: - raise NotImplementedError - + position_ids = position_ids.unsqueeze(0) return tokens, attention_mask, position_ids -def generate_samples(model, tokenizer, args): - model.eval() - output_path = "./samples" - if not os.path.exists(output_path): - os.makedirs(output_path) - output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt") - with torch.no_grad(), open(output_path, "w") as output: - while True: - torch.distributed.barrier(group=mpu.get_model_parallel_group()) - terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output) - if terminate_runs == 1: - return - start_time = time.time() - if args.block_lm: - mems = [] - tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args) - mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK'] - mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens] - end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id] - mask_positions = [] - for token in mask_tokens: - mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist() - mask_positions.sort() - if args.no_block_position: - for mask_position in mask_positions: - position_ids[0, mask_position + 1:] += args.out_seq_length - _, *mems = model(tokens, position_ids, attention_mask, *mems) - for mask_position in mask_positions: - if args.no_block_position: - position = position_ids[0, mask_position].item() - else: - position = mask_position - if args.num_beams > 1: - strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length, - length_penalty=args.length_penalty, end_tokens=end_tokens, - 
no_repeat_ngram_size=args.no_repeat_ngram_size, - min_tgt_length=args.min_tgt_length) - else: - strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, - end_tokens=end_tokens) - new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems, - end_tokens=end_tokens) - tokens = torch.cat((tokens, new_tokens), dim=1) - output_tokens_list = tokens.view(-1).contiguous() - if mpu.get_model_parallel_rank() == 0: - os.system('clear') - print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) - print("\nContext:", raw_text, flush=True) - decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) - trim_decode_tokens = decode_tokens - print("\nGLM:", trim_decode_tokens, flush=True) - output.write(trim_decode_tokens + "\n") - - torch.distributed.barrier(group=mpu.get_model_parallel_group()) - - def main(args): initialize_distributed(args) tokenizer = prepare_tokenizer(args) - # build model + # build model model = GLMModel(args) + model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) if args.fp16: model = model.half() model = model.to(args.device) load_checkpoint(model, args) set_random_seed(args.seed) model.eval() - generate_samples(model, tokenizer, args) + end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id] + # define function for each query + strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens) + + def process(raw_text): + if args.with_id: + query_id, raw_text = raw_text.split('\t') + # add MASK + generation_mask = '[gMASK]' if args.task_mask else '[MASK]' + if 'MASK]' not in raw_text: + raw_text += ' ' + generation_mask + seq = tokenizer.EncodeAsIds(raw_text).tokenization + seq = [tokenizer.get_command('ENC').Id] + seq + if not raw_text.endswith('MASK]'): + seq = seq + [tokenizer.get_command('eos').Id] + print('raw text: ', raw_text) + if len(seq) > args.max_sequence_length: + raise ValueError('text too long.') + + # find mask tokens positions + mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK'] + mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens] + mask_positions = [] + context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device) + for token in mask_tokens: + mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist() + mask_positions.sort() + + # generation + mbz = args.max_inference_batch_size + assert args.batch_size < mbz or args.batch_size % mbz == 0 + output_list = [] + # call for each position + for mp_idx, mask_position in enumerate(mask_positions): + get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq)) + for tim in range(max(args.batch_size // mbz, 1)): + input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device) + output, _mems = filling_sequence(model, input_seq, + batch_size=min(args.batch_size, mbz), + strategy=strategy, + log_attention_weights=None, + get_masks_and_position_ids=get_func + ) # we don't use mems, fill back + if isinstance(output, torch.Tensor): # different strategies + output = list(output) + + output_list.extend(output) + + # clip -1s and fill back generated things into seq + for i in range(len(output_list)): + output = output_list[i].tolist() + try: + unfinished = output.index(-1) + except ValueError: + unfinished = len(output) + bog = output.index(tokenizer.get_command('sop').Id) + 
output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog] + + # prepare the next auto-regressive generation + if mp_idx < len(mask_positions) - 1: + # TODO, here to select the best for this time, inverse prompting? + seq = output_list[0] + output_list = [] + + # decoding + txts = [] + for seq in output_list: + decode_tokens = tokenizer.DecodeIds(seq) + txts.append(decode_tokens) + + # save + if args.with_id: + full_path = os.path.join(args.output_path, query_id + '.txt') + else: + prefix = raw_text.replace('/', '')[:20] + full_path = timed_name(prefix, '.txt', args.output_path) + print(txts[0]) # print the first. + with open(full_path, 'w') as fout: + for txt in txts: + fout.write(txt + '\n') + os.chmod(full_path, stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU) + + os.makedirs(args.output_path, exist_ok=True) + generate_continually(process, args.input_source) if __name__ == "__main__": - args = get_args() + py_parser = argparse.ArgumentParser(add_help=False) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + with torch.no_grad(): - main(args) + main(args) \ No newline at end of file diff --git a/inference_glm_old.py b/inference_glm_old.py new file mode 100644 index 0000000000000000000000000000000000000000..792efd5e030fa1564a7f52310b261844eb113c03 --- /dev/null +++ b/inference_glm_old.py @@ -0,0 +1,176 @@ +# -*- encoding: utf-8 -*- +''' +@File : inference_cogview.py +@Time : 2021/10/09 19:41:58 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import random +import time +from datetime import datetime +import torch +import torch.nn.functional as F + +import mpu +from arguments import get_args +from model.glm_model import GLMModel +from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer +from generation.glm_sampling import filling_sequence_glm +from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy + + +def read_context(tokenizer, args, output=None): + terminate_runs, skip_run = 0, 0 + if mpu.get_model_parallel_rank() == 0: + while True: + raw_text = input("\nContext prompt (stop to exit) >>> ") + if not raw_text: + print('Prompt should not be empty!') + continue + if raw_text == "stop": + terminate_runs = 1 + break + generation_mask = '[gMASK]' if args.task_mask else '[MASK]' + if args.block_lm and 'MASK]' not in raw_text: + raw_text += ' ' + generation_mask + if output is not None: + output.write(raw_text) + context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization + if args.block_lm: + context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens + if not raw_text.endswith('MASK]'): + context_tokens = context_tokens + [tokenizer.get_command('eos').Id] + context_length = len(context_tokens) + + if context_length >= args.max_sequence_length: + print("\nContext length", context_length, + "\nPlease give smaller context than the window length!") + continue + break + else: + context_length = 0 + + terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) + torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + terminate_runs = terminate_runs_tensor[0].item() + + if terminate_runs == 1: + return terminate_runs, None, None, None + + context_length_tensor = torch.cuda.LongTensor([context_length]) + + torch.distributed.broadcast(context_length_tensor, 
mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + context_length = context_length_tensor[0].item() + if mpu.get_model_parallel_rank() == 0: + context_tokens_tensor = torch.cuda.LongTensor(context_tokens) + else: + context_tokens_tensor = torch.cuda.LongTensor([0] * context_length) + torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + if mpu.get_model_parallel_rank() != 0: + raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist()) + return terminate_runs, raw_text, context_tokens_tensor, context_length + + +def get_batch(context_tokens, args): + tokens = context_tokens + tokens = tokens.view(1, -1).contiguous() + tokens = tokens.to('cuda') + + # Get the masks and postition ids. + if args.block_lm: + attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long) + if args.fp16: + attention_mask = attention_mask.half() + position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long) + if not args.no_block_position: + block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long) + position_ids = torch.stack((position_ids, block_position_ids), dim=0) + position_ids = position_ids.unsqueeze(0) + else: + raise NotImplementedError + + return tokens, attention_mask, position_ids + + +def generate_samples(model, tokenizer, args): + model.eval() + output_path = "./samples" + if not os.path.exists(output_path): + os.makedirs(output_path) + output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt") + with torch.no_grad(), open(output_path, "w") as output: + while True: + torch.distributed.barrier(group=mpu.get_model_parallel_group()) + terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output) + if terminate_runs == 1: + return + start_time = time.time() + if args.block_lm: + mems = [] + tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args) + mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK'] + mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens] + end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id] + mask_positions = [] + for token in mask_tokens: + mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist() + mask_positions.sort() + if args.no_block_position: + for mask_position in mask_positions: + position_ids[0, mask_position + 1:] += args.out_seq_length + _, *mems = model(tokens, position_ids, attention_mask, *mems) + for mask_position in mask_positions: + if args.no_block_position: + position = position_ids[0, mask_position].item() + else: + position = mask_position + if args.num_beams > 1: + strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length, + length_penalty=args.length_penalty, end_tokens=end_tokens, + no_repeat_ngram_size=args.no_repeat_ngram_size, + min_tgt_length=args.min_tgt_length) + else: + strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, + end_tokens=end_tokens) + new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems, + end_tokens=end_tokens) + tokens = torch.cat((tokens, new_tokens), dim=1) + output_tokens_list = tokens.view(-1).contiguous() + if mpu.get_model_parallel_rank() == 0: + os.system('clear') + print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) + print("\nContext:", 
raw_text, flush=True) + decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) + trim_decode_tokens = decode_tokens + print("\nGLM:", trim_decode_tokens, flush=True) + output.write(trim_decode_tokens + "\n") + + torch.distributed.barrier(group=mpu.get_model_parallel_group()) + + +def main(args): + initialize_distributed(args) + tokenizer = prepare_tokenizer(args) + # build model + model = GLMModel(args) + if args.fp16: + model = model.half() + model = model.to(args.device) + load_checkpoint(model, args) + set_random_seed(args.seed) + model.eval() + generate_samples(model, tokenizer, args) + + +if __name__ == "__main__": + args = get_args() + + with torch.no_grad(): + main(args) diff --git a/model/base_model.py b/model/base_model.py index 0e4b5b3dcb2d0fa8f69823d79fe70f016d0c54fd..b3495496e581783616ff151958c06035310c5546 100644 --- a/model/base_model.py +++ b/model/base_model.py @@ -14,11 +14,13 @@ import random import torch from mpu import BaseTransformer +from .mixins import BaseMixin class BaseModel(torch.nn.Module): def __init__(self, args, transformer=None): super(BaseModel, self).__init__() - self.hooks = self.collect_hooks() + self.mixins = torch.nn.ModuleDict() + self.collect_hooks_() if transformer is not None: self.transformer = transformer else: @@ -37,12 +39,25 @@ class BaseModel(torch.nn.Module): parallel_output=True, hooks=self.hooks ) - self.mixins = torch.nn.ModuleList() - def reinit(self): + def reinit(self): # will be called when loading model # if some mixins are loaded, overrides this function - for m in self.mixins: + for m in self.mixins.values(): m.reinit(self.transformer) + + def add_mixin(self, name, new_mixin, reinit=False): + assert name not in self.mixins + assert isinstance(new_mixin, BaseMixin) + + self.mixins[name] = new_mixin # will auto-register parameters + object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr + + if reinit: + new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins + self.collect_hooks_() + + def get_mixin(self, name): + return self.mixins[name] def forward(self, *args, **kwargs): # update hooks as the current model (overrided forwards) @@ -51,16 +66,28 @@ class BaseModel(torch.nn.Module): self.transformer.hooks.update(self.hooks) return self.transformer(*args, **kwargs) - def collect_hooks(self): + def collect_hooks_(self): names = ['word_embedding_forward', 'position_embedding_forward', 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward', 'branch_embedding_forward', 'branch_final_forward' ] hooks = {} + hook_origins = {} for name in names: + for mixin_name, m in self.mixins.items(): + if hasattr(m, name): + if name in hooks: # conflict + raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.') + hooks[name] = getattr(m, name) + hook_origins[name] = mixin_name if hasattr(self, name): + # if name in hooks: # defined in mixins, can override + # print(f'Override {name} in {hook_origins[name]}...') hooks[name] = getattr(self, name) + hook_origins[name] = 'model' + self.hooks = hooks + self.hook_origins = hook_origins return hooks - + def disable_untrainable_params(self): pass \ No newline at end of file diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py index 34702b0cb6f1135406bae1e780c4973dfa814731..97b58b2c464fe02a90752752e8440af2e5852a1c 100755 --- a/model/cached_autoregressive_model.py +++ b/model/cached_autoregressive_model.py @@ -13,15 +13,15 @@ import math import random import torch 
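# --- Illustrative sketch (not part of this patch): using the new add_mixin()/hook machinery.
# With collect_hooks_() above, a hook no longer has to be a method of the model class: any
# registered BaseMixin that defines a method named after one of the hook names
# ('word_embedding_forward', 'position_embedding_forward', 'attention_forward', 'mlp_forward',
# 'final_forward', 'layer_forward', ...) is picked up automatically, and two mixins claiming
# the same hook raise the conflict ValueError. The class and mixin name below are hypothetical.
import torch.nn.functional as F

from model.mixins import BaseMixin


class ScaledFinalForwardMixin(BaseMixin):
    """Rescale the output logits via the 'final_forward' hook (a toy example)."""

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    def final_forward(self, logits, **kw_args):
        # self.transformer is attached by BaseModel.add_mixin via object.__setattr__, so the
        # word-embedding weight is available for the output projection. The default
        # model-parallel copy/gather is omitted in this sketch.
        logits = F.linear(logits, self.transformer.word_embeddings.weight)
        return logits * self.scale


# Usage sketch, for any BaseModel subclass instance `model`:
#     model.add_mixin('scaled-final', ScaledFinalForwardMixin(scale=0.5))
#     model.get_mixin('scaled-final')   # mixins stay addressable by name
# ----------------------------------------------------------------------------------------------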
+from .mixins import BaseMixin from .base_model import BaseModel from mpu.transformer import standard_attention, split_tensor_along_last_dim -class CachedAutoregressiveModel(BaseModel): - def __init__(self, args, transformer=None): - super().__init__(args, transformer=transformer) - self.log_attention_weights = None +class CachedAutoregressiveMixin(BaseMixin): + def __init__(self): + super().__init__() - def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, **kwargs): + def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs): attn_module = self.transformer.layers[layer_id].attention mem = mems[layer_id] if mems is not None else None @@ -40,7 +40,7 @@ class CachedAutoregressiveModel(BaseModel): query_layer = attn_module._transpose_for_scores(mixed_query_layer) key_layer = attn_module._transpose_for_scores(mixed_key_layer) value_layer = attn_module._transpose_for_scores(mixed_value_layer) - context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=self.log_attention_weights) + context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,) @@ -51,3 +51,8 @@ class CachedAutoregressiveModel(BaseModel): new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous() return output, new_mem + +class CachedAutoregressiveModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.add_mixin('auto-regressive', CachedAutoregressiveMixin()) diff --git a/model/cuda2d_model.py b/model/cuda2d_model.py index c8699e7a66089e6d0dc210a04a2d0c84b27cbf12..cf9867a7a41122f7f37f465631fdad8e55cc8830 100644 --- a/model/cuda2d_model.py +++ b/model/cuda2d_model.py @@ -28,10 +28,10 @@ class Cuda2dModel(BaseModel): def __init__(self, args, transformer=None): super().__init__(args, transformer=transformer) additional_seqlen = args.new_sequence_length - args.max_sequence_length - self.mixins.append(PositionEmbeddingMixin( + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( additional_seqlen, args.hidden_size )) - self.mixins.append(AttentionMixin( + self.add_mixin('attention_plus', AttentionMixin( num_layers=args.num_layers, hidden_size=args.hidden_size )) @@ -41,23 +41,24 @@ class Cuda2dModel(BaseModel): self.kernel_size2 = args.kernel_size2 self.log_attention_weights = None - def position_embedding_forward(self, position_ids, **kw_tensors): + def position_embedding_forward(self, position_ids, **kw_args): position = position_ids[..., :self.layout[1]] position_plus = position_ids[..., self.layout[1]:] position_embeddings = torch.cat( ( self.transformer.position_embeddings(position), - self.mixins[0].position_embeddings(position_plus) + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) ), dim=-2 ) return position_embeddings - def attention_forward(self, hidden_states, mask, layer_id=None, **kw_tensors): + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): attn_module = self.transformer.layers[layer_id].attention # attention_plus on all layers - query_key_value_plus = self.mixins[1].query_key_value[layer_id] - dense_plus = self.mixins[1].dense[layer_id] + query_key_value_plus = 
self.get_mixin('attention_plus').query_key_value[layer_id] + dense_plus = self.get_mixin('attention_plus').dense[layer_id] # split two parts hidden_states_plus = hidden_states[:, self.layout[1]:] @@ -81,7 +82,7 @@ class Cuda2dModel(BaseModel): kernel_size=self.kernel_size, kernel_size2=self.kernel_size2, attention_dropout=dropout_fn, - log_attention_weights=self.log_attention_weights + log_attention_weights=log_attention_weights ) output_0 = attn_module.dense(context_layer0) diff --git a/model/glm_model.py b/model/glm_model.py index 96502d1f7751228b752806f0b80376ece9410170..30593c4866195ef4601972c72871cfabe927924e 100644 --- a/model/glm_model.py +++ b/model/glm_model.py @@ -3,16 +3,26 @@ import torch.nn as nn from .base_model import BaseModel from .cached_autoregressive_model import CachedAutoregressiveModel +from .mixins import BaseMixin - -class GLMModel(CachedAutoregressiveModel): - def __init__(self, args, transformer=None): - super().__init__(args, transformer=transformer) - self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size) - torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02) - - def position_embedding_forward(self, position_ids, *other_tensors): +class BlockPositionEmbeddingMixin(BaseMixin): + def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02): + super(BlockPositionEmbeddingMixin, self).__init__() + self.max_sequence_length = max_sequence_length + self.hidden_size = hidden_size + self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size) + torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std) + + def position_embedding_forward(self, position_ids, **kwargs): position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1] position_embeddings = self.transformer.position_embeddings(position_ids) - block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids) + block_position_embeddings = self.block_position_embeddings(block_position_ids) return position_embeddings + block_position_embeddings + +class GLMModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.add_mixin('block_position_embedding', + BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size) + ) + diff --git a/model/mixins.py b/model/mixins.py index 125f8b0e2c7ccda41c4125892637046a041f7186..f42b0f270d88d01a33e18d8df11b44d5fd3fbba5 100644 --- a/model/mixins.py +++ b/model/mixins.py @@ -20,9 +20,11 @@ class BaseMixin(torch.nn.Module): def __init__(self): super(BaseMixin, self).__init__() # define new params - def reinit(self, transformer, *pre_mixins): + def reinit(self, *pre_mixins): # reload the initial params from previous trained modules pass + # can also define hook-functions here + # ... 
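# --- Illustrative sketch (not part of this patch): the new reinit() convention.
# reinit() no longer receives the transformer as an argument; BaseModel.add_mixin() attaches
# it as self.transformer before reinit() can run, so weight surgery reads the parent
# transformer directly (as PositionEmbeddingMixin and AttentionMixin below now do).
# The mixin name and use case here are hypothetical.
import torch

from model.mixins import BaseMixin


class ExtraVocabMixin(BaseMixin):
    """Add embedding rows for new tokens; seed them from the mean of the existing rows."""

    def __init__(self, additional_tokens, hidden_size, init_method_std=0.02):
        super().__init__()
        self.extra_word_embeddings = torch.nn.Embedding(additional_tokens, hidden_size)
        torch.nn.init.normal_(self.extra_word_embeddings.weight, mean=0.0, std=init_method_std)

    def reinit(self, *pre_mixins):
        # self.transformer is set when the mixin is registered; pre_mixins (the transformer
        # or previously added mixins, depending on the caller) can usually be ignored.
        mean_row = self.transformer.word_embeddings.weight.data.mean(dim=0, keepdim=True)
        self.extra_word_embeddings.weight.data.copy_(
            mean_row.expand_as(self.extra_word_embeddings.weight))
# ----------------------------------------------------------------------------------------------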
class PositionEmbeddingMixin(BaseMixin): def __init__(self, additional_sequence_length, hidden_size, @@ -32,8 +34,8 @@ class PositionEmbeddingMixin(BaseMixin): self.reinit_slice = reinit_slice self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) - def reinit(self, transformer, *pre_mixins): - old_weights = transformer.position_embeddings.weight.data[self.reinit_slice] + def reinit(self, *pre_mixins): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] old_len, hidden_size = old_weights.shape assert hidden_size == self.position_embeddings.weight.shape[-1] self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) @@ -58,11 +60,11 @@ class AttentionMixin(BaseMixin): init_method=output_layer_init_method) for layer_id in range(num_layers) ]) - def reinit(self, transformer, *pre_mixins): - start_layer = len(transformer.layers) - self.num_layers + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers assert start_layer >= 0 for layer_id in range(self.num_layers): - old_attention = transformer.layers[start_layer + layer_id].attention + old_attention = self.transformer.layers[start_layer + layer_id].attention self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data) diff --git a/mpu/transformer.py b/mpu/transformer.py index cdc34a84fc0681178f7795cf9d3a43b1d92b257b..2cf38521f49b7f8ee207863fb90926c06e622ba2 100755 --- a/mpu/transformer.py +++ b/mpu/transformer.py @@ -111,9 +111,9 @@ class SelfAttention(torch.nn.Module): tensor = tensor.view(*new_tensor_shape) return tensor.permute(0, 2, 1, 3) - def forward(self, hidden_states, mask, **kw_tensors): + def forward(self, hidden_states, mask, **kw_args): if 'attention_forward' in self.hooks: - return self.hooks['attention_forward'](hidden_states, mask, **kw_tensors, layer_id=self.layer_id) + return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id) else: mixed_raw_layer = self.query_key_value(hidden_states) (mixed_query_layer, @@ -162,9 +162,9 @@ class MLP(torch.nn.Module): ) self.dropout = torch.nn.Dropout(output_dropout_prob) - def forward(self, hidden_states, **kw_tensors): + def forward(self, hidden_states, **kw_args): if 'mlp_forward' in self.hooks: - output = self.hooks['mlp_forward'](hidden_states, **kw_tensors, layer_id=self.layer_id) + output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id) else: intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = gelu(intermediate_parallel) @@ -227,7 +227,7 @@ class BaseTransformerLayer(torch.nn.Module): hooks=hooks ) - def forward(self, hidden_states, mask, **kw_tensors): + def forward(self, hidden_states, mask, **kw_args): ''' hidden_states: [batch, seq_len, hidden_size] mask: [(1, 1), seq_len, seq_len] @@ -236,7 +236,7 @@ class BaseTransformerLayer(torch.nn.Module): # Layer norm at the begining of the transformer layer. layernorm_output1 = self.input_layernorm(hidden_states) # Self attention. 
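# NOTE (annotation, not a patch line): kw_args may carry non-tensor entries such as cached
# mems or log_attention_weights; BaseTransformerLayer.forward passes them through unchanged
# to the attention and MLP hooks registered on the model or its mixins.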
- attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_tensors) + attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args) # Third LayerNorm if self.sandwich_ln: @@ -247,7 +247,7 @@ class BaseTransformerLayer(torch.nn.Module): # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. - mlp_output = self.mlp(layernorm_output, **kw_tensors) + mlp_output = self.mlp(layernorm_output, **kw_args) # Fourth LayerNorm if self.sandwich_ln: @@ -316,27 +316,25 @@ class BaseTransformer(torch.nn.Module): # Final layer norm before output. self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) - def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_tensors): + def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_args): # sanity check assert len(input_ids.shape) == 2 batch_size, query_length = input_ids.shape assert len(attention_mask.shape) == 2 or \ len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor) - for k, v in kw_tensors.items(): - assert isinstance(v, torch.Tensor) # branch_input is a new part of input need layer-by-layer update, # but with different hidden_dim and computational routine. # In most cases, you can just ignore it. # embedding part if 'word_embedding_forward' in self.hooks: - hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_tensors) + hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args) else: # default hidden_states = self.word_embeddings(input_ids) if 'position_embedding_forward' in self.hooks: - position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_tensors) + position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_args) else: assert len(position_ids.shape) <= 2 assert position_ids.shape[-1] == query_length @@ -346,7 +344,7 @@ class BaseTransformer(torch.nn.Module): # branch related embedding if branch_input is None and 'branch_embedding_forward' in self.hooks: - branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_tensors) + branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args) # define custom_forward for checkpointing output_per_layers = [] @@ -361,14 +359,14 @@ class BaseTransformer(torch.nn.Module): for i, layer in enumerate(layers_): if len(inputs) > 2: x_, branch_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_tensors + x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args ) elif 'layer_forward' in self.hooks: x_, output_this_layer = self.hooks['layer_forward']( - x_, mask, layer_id=layer.layer_id, **kw_tensors + x_, mask, layer_id=layer.layer_id, **kw_args ) else: - x_, output_this_layer = layer(x_, mask, **kw_tensors) + x_, output_this_layer = layer(x_, mask, **kw_args) output_per_layers_part.append(output_this_layer) return x_, output_per_layers_part return custom_forward @@ -387,25 +385,25 @@ class BaseTransformer(torch.nn.Module): for i, layer in enumerate(self.layers): args = [hidden_states, attention_mask] if branch_input is not None: # customized layer_forward with branch_input - hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_tensors) + hidden_states, 
branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args) elif 'layer_forward' in self.hooks: # customized layer_forward - hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_tensors) + hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args) else: - hidden_states, output_this_layer = layer(*args, **kw_tensors) + hidden_states, output_this_layer = layer(*args, **kw_args) output_per_layers.append(output_this_layer) # Final layer norm. logits = self.final_layernorm(hidden_states) if 'final_forward' in self.hooks: - logits_parallel = self.hooks['final_forward'](logits, **kw_tensors) + logits_parallel = self.hooks['final_forward'](logits, **kw_args) else: logits_parallel = copy_to_model_parallel_region(logits) logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight) # branch related embedding if branch_input is None and 'branch_final_forward' in self.hooks: - branch_input = self.hooks['branch_final_forward'](branch_input, **kw_tensors) + branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args) if self.parallel_output: logits_parallel = gather_from_model_parallel_region(logits_parallel) diff --git a/scripts/finetune_into_cogview2.sh b/scripts/finetune_into_cogview2.sh index 3bdf5ea4b78e71165750b9760d2ce2a58f447ae5..1be6d8ff24309d4258cd66269f21f1e02ac92f7c 100755 --- a/scripts/finetune_into_cogview2.sh +++ b/scripts/finetune_into_cogview2.sh @@ -50,7 +50,7 @@ gpt_options="${gpt_options} --deepspeed \ --deepspeed_config ${config_json} \ " - + run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}" echo ${run_cmd} diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh old mode 100644 new mode 100755 index 143d7b5baa08029075b6c360f1b252a445b5a686..e007900381dd94deb395f94b401ab46a81aa789e --- a/scripts/generate_glm.sh +++ b/scripts/generate_glm.sh @@ -1,7 +1,28 @@ #!/bin/bash -CHECKPOINT_PATH=/dataset/fd5061f6/english_data/checkpoints +CHECKPOINT_PATH=pretrained/glm -source $1 +# MODEL_ARGS="--block-lm \ +# --cloze-eval \ +# --num-layers 24 \ +# --hidden-size 1024 \ +# --num-attention-heads 16 \ +# --max-sequence-length 513 \ +# --tokenizer-model-type roberta \ +# --tokenizer-type glm_GPT2BPETokenizer \ +# --load ${CHECKPOINT_PATH}/glm-roberta-large-blank" + +MODEL_TYPE="blocklm-10B" +MODEL_ARGS="--block-lm \ + --cloze-eval \ + --task-mask \ + --num-layers 48 \ + --hidden-size 4096 \ + --num-attention-heads 64 \ + --max-sequence-length 1025 \ + --tokenizer-model-type gpt2 \ + --tokenizer-type glm_GPT2BPETokenizer \ + --old-checkpoint \ + --load ${CHECKPOINT_PATH}/glm-en-10b" MPSIZE=1 MAXSEQLEN=512 @@ -29,4 +50,7 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE --out-seq-length $MAXSEQLEN \ --temperature $TEMP \ --top_k $TOPK \ - --top_p $TOPP + --output-path glm_text \ + --batch-size 1 \ + --out-seq-length 100 \ + --mode inference diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh index 5a09f7712528974023124dd510d3762623e7c8d9..c1040a3bc54649742d75be71943a0763438f6f19 100755 --- a/scripts/pretrain_multiple_nodes.sh +++ b/scripts/pretrain_multiple_nodes.sh @@ -19,17 +19,17 @@ small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/ziji config_json="$script_dir/ds_config_zero.json" gpt_options=" 
\ - --experiment-name pretrain-gpt2-cogview-test \ + --experiment-name pretrain-gpt2-cogview-small \ --tokenizer-type cogview \ --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ --model-parallel-size ${MP_SIZE} \ --mode pretrain \ - --num-layers 12 \ - --hidden-size 1024 \ - --num-attention-heads 16 \ + --num-layers 40 \ + --hidden-size 2048 \ + --num-attention-heads 32 \ --train-iters 200000 \ --resume-dataloader \ - --train-data ${small_data} \ + --train-data ${full_data} \ --split 949,50,1 \ --distributed-backend nccl \ --lr-decay-style cosine \ @@ -38,9 +38,9 @@ gpt_options=" \ --max-sequence-length 1089 \ --sandwich-ln \ --fp16 \ - --save-interval 2000 \ + --save-interval 5000 \ --eval-interval 1000 \ - --save $main_dir/checkpoints \ + --save /root/checkpoints \ " # --load pretrained/cogview/cogview-base diff --git a/scripts/text2image_cogview.sh b/scripts/text2image_cogview.sh index f6bea9668c57024c5f34974dd03d6e8ad76336a1..bcb1ecd9847a2e2e39c6f00c34c89518edee61ee 100755 --- a/scripts/text2image_cogview.sh +++ b/scripts/text2image_cogview.sh @@ -33,7 +33,7 @@ MASTER_PORT=${MASTER_PORT} python inference_cogview.py \ --sandwich-ln \ --input-source ./input.txt \ --output-path samples_text2image \ - --batch-size 8 \ + --batch-size 4 \ --max-inference-batch-size 8 \ $@ diff --git a/tokenization/__init__.py b/tokenization/__init__.py index f465ec4713f52132a32f198731a01825228613fc..427d98eb5e1dc4494b00f51fd9913b960874bd82 100644 --- a/tokenization/__init__.py +++ b/tokenization/__init__.py @@ -15,8 +15,6 @@ import torch def get_tokenizer(args=None): - kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask, - "add_decoder_mask": args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0} if not hasattr(get_tokenizer, 'tokenizer'): # the first time to load the tokenizer if args.tokenizer_type == 'cogview': @@ -25,15 +23,18 @@ def get_tokenizer(args=None): args.img_tokenizer_path, device=torch.cuda.current_device() ) - elif args.tokenizer_type == "BertWordPieceTokenizer": - from .text import BertWordPieceTokenizer - get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs) - elif args.tokenizer_type == "GPT2BPETokenizer": - from .text import GPT2BPETokenizer - get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs) - elif args.tokenizer_type == "ChineseSPTokenizer": - from .text import ChineseSPTokenizer - get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs) + elif args.tokenizer_type.startswith('glm_'): + kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask, + "add_decoder_mask": False} #args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0} + if args.tokenizer_type == "glm_BertWordPieceTokenizer": + from .text import BertWordPieceTokenizer + get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs) + elif args.tokenizer_type == "glm_GPT2BPETokenizer": + from .text import GPT2BPETokenizer + get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs) + elif args.tokenizer_type == "glm_ChineseSPTokenizer": + from .text import ChineseSPTokenizer + get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs) else: assert args.vocab_size > 0 get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size) diff --git a/tokenization/text/tokenization_gpt2.py b/tokenization/text/tokenization_gpt2.py index 318d9209990b74b9aaadc13d1f482923c77bc3ab..8782fe9cb74fc95dc27e3e15cce2339a655ecd80 100644 --- a/tokenization/text/tokenization_gpt2.py +++ 
b/tokenization/text/tokenization_gpt2.py @@ -36,12 +36,12 @@ from ..file_utils import cached_path logger = logging.getLogger(__name__) PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-vocab.json", - "roberta": "pretrained/pytorch_pretrained_bert/roberta-vocab.json" + 'gpt2': "pretrained/english_tokenizer/gpt2-vocab.json", + "roberta": "pretrained/english_tokenizer/roberta-vocab.json" } PRETRAINED_MERGES_ARCHIVE_MAP = { - 'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-merges.txt", - "roberta": "pretrained/pytorch_pretrained_bert/roberta-merges.txt" + 'gpt2': "pretrained/english_tokenizer/gpt2-merges.txt", + "roberta": "pretrained/english_tokenizer/roberta-merges.txt" } PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'gpt2': 1024, diff --git a/training/model_io.py b/training/model_io.py index 1655ed20999230346b0dd20c7c547c52a38fb246..18151863ae66a0d6a5e640f71c875d8a61cea92b 100644 --- a/training/model_io.py +++ b/training/model_io.py @@ -131,11 +131,6 @@ def load_checkpoint(model, args): else: # inference without deepspeed module = model - # Process the checkpoint for GLM - if args.block_lm and args.old_checkpoint: - sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight'] - del sd['module']['word_embeddings.weight'] - # only load module, other hyperparameters are just for recording. missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False) if len(unexpected_keys) > 0: