diff --git a/CHANGE_LOG.md b/CHANGE_LOG.md index 86ff7d8bf584558c97251c2b24ae4e65cde4c246..0daa8ae5059ac8cde69f321f95774ca80b331f5a 100644 --- a/CHANGE_LOG.md +++ b/CHANGE_LOG.md @@ -1,4 +1,4 @@ -# 2021.10.29 +# 2021.10.29 v0.1 1. change `mixins` from `ModuleList` to `ModuleDict` 2. return tokens and mems in `fill_sequence`, and mems becomes a tensor. 3. `CachedAutoRegressiveMixin` @@ -28,5 +28,8 @@ for the older framework, you also need: old['module']['transformer.word_embeddings.weight'] = old['module']['word_embeddings.weight'] del old['module']['word_embeddings.weight'] ``` +# 2021.11.5 v0.1.2 +1. Add generation.autoregressive_sampling.evalute_perplexity +2. fix Runtime Error in skipping Nan Loss diff --git a/README.md b/README.md index d0aaa05592a06cce0d0c067bc7379d3aa14e60b4..16472dea64f607025fb7abefcfb8e7aebda98e3f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,12 @@ ``` ### Run GLM +1. Prepare input.txt. Example: "Welcome! This is the main page of SwissArmyTransformer". +2. Run the following commands: ``` cd examples/glm ./scripts/generate_glm.sh config/model_glm_10B_chinese.sh -``` \ No newline at end of file +``` + +Output: +[CLS]Welcome! This is the main page of SwissArmyTransformer. It is a comprehensive and clear explanation of the technical problems in the transformer. It is also an introduction to the development of the SwissArmy transformers. Welcome to Swiss Army Transforters. This is the main page of Swiss army tranforter. It's a complete and clean explaination of technology problem in the Tranformer, which is an integral part of the army's technological development. It also anintroduction of the developments of the Army technicians. Well, if you have any questions, please feel free to contact the official webs diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py index e71ba6bce05966bf153be0966edfd26e63277a4a..37a24a88a004bfff63a87ebdcf3c909b29256a20 100644 --- a/SwissArmyTransformer/generation/autoregressive_sampling.py +++ b/SwissArmyTransformer/generation/autoregressive_sampling.py @@ -3,7 +3,7 @@ @File : autoregressive_sampling.py @Time : 2021/10/08 15:43:59 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/generation/cuda2d_sampling.py b/SwissArmyTransformer/generation/cuda2d_sampling.py index c19154012ad0c8be6a8369f0e01af0dea09e140a..cc77208bc4b688e3107146671564045c5349c239 100644 --- a/SwissArmyTransformer/generation/cuda2d_sampling.py +++ b/SwissArmyTransformer/generation/cuda2d_sampling.py @@ -3,7 +3,7 @@ @File : cuda2d_sampling.py @Time : 2021/10/09 00:46:04 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/generation/sampling_strategies/base_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/base_strategy.py index a5268e5ab4db13e6c8ede944f81291505560e4c6..2d8e4a6afd1acd014690b74d05d257c20f042b8d 100644 --- a/SwissArmyTransformer/generation/sampling_strategies/base_strategy.py +++ b/SwissArmyTransformer/generation/sampling_strategies/base_strategy.py @@ -3,7 +3,7 @@ @File : base_strategy.py @Time : 2021/10/08 22:22:42 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py index 88b5c41f20eed71a98bf00d9eaf0e9dcc0110d03..ccaafdff34fdb06eb3947d1bf2593890680e9017 100644 --- a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py +++ b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py @@ -3,7 +3,7 @@ @File : base_strategy.py @Time : 2021/10/08 22:22:42 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy_old.py b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy_old.py deleted file mode 100644 index aeab7989930250aa93b3231399152cd514b4162b..0000000000000000000000000000000000000000 --- a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy_old.py +++ /dev/null @@ -1,467 +0,0 @@ -# -*- encoding: utf-8 -*- -''' -@File : base_strategy.py -@Time : 2021/10/08 22:22:42 -@Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn -''' - -# here put the import lib -import torch -import torch.nn.functional as F -from abc import ABC, abstractmethod -from collections import UserDict -from typing import Optional, Tuple, List, Iterable, Union - - -class BeamScorer(ABC): - """ - Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and - :meth:`~transformers.PretrainedModel.beam_sample`. - """ - - @abstractmethod - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - **kwargs - ) -> Tuple[torch.Tensor]: - raise NotImplementedError("This is an abstract method.") - - @abstractmethod - def finalize( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - **kwargs - ) -> torch.LongTensor: - raise NotImplementedError("This is an abstract method.") - - -class BeamSearchScorer(BeamScorer): - r""" - :class:`transformers.BeamScorer` implementing standard beam search decoding. - - Adapted in part from `Facebook's XLM beam search code - <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. - - Args: - batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. - max_length (:obj:`int`): - The maximum length of the sequence to be generated. - num_beams (:obj:`int`): - Number of beams for beam search. - device (:obj:`torch.device`): - Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of - :obj:`BeamSearchScorer` will be allocated. - length_penalty (:obj:`float`, `optional`, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. - do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. - num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): - The number of beam hypotheses that shall be returned upon calling - :meth:`~transformer.BeamSearchScorer.finalize`. - """ - - def __init__( - self, - batch_size: int, - max_length: int, - num_beams: int, - device: Union[torch.device, str], - length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - ): - self.max_length = max_length - self.num_beams = num_beams - self.device = device - self.length_penalty = length_penalty - self.do_early_stopping = do_early_stopping - self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - - self._is_init = False - self._beam_hyps = [ - BeamHypotheses( - num_beams=self.num_beams, - max_length=self.max_length, - length_penalty=self.length_penalty, - early_stopping=self.do_early_stopping, - ) - for _ in range(batch_size) - ] - self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) - - # if not isinstance(num_beams, int) or num_beams <= 1: - # raise ValueError( - # f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." - # ) - - @property - def is_done(self) -> bool: - return self._done.all() - - def process( - self, - input_ids: torch.LongTensor, - next_scores: torch.FloatTensor, - next_tokens: torch.LongTensor, - next_indices: torch.LongTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - mems=None - ) -> Tuple[torch.Tensor]: - cur_len = input_ids.shape[-1] - batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.num_beams) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - device = next_scores.device - next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) - - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - assert ( - len(beam_hyp) >= self.num_beams - ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) - assert ( - eos_token_id is not None and pad_token_id is not None - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" - # pad the batch - next_beam_scores[batch_idx, :] = 0 - next_beam_tokens[batch_idx, :] = pad_token_id - next_beam_indices[batch_idx, :] = 0 - continue - - # next tokens for this sentence - beam_idx = 0 - for beam_token_rank, (next_token, next_score, next_index) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) - ): - batch_beam_idx = batch_idx * self.num_beams + next_index - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() in eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams - if is_beam_token_worse_than_top_num_beams: - continue - beam_hyp.add( - input_ids[batch_beam_idx].clone(), - next_score.item(), - mems=[mem[[next_index.item()]] for mem in mems] if mems else None - ) - else: - # add next predicted token since it is not eos_token - next_beam_scores[batch_idx, beam_idx] = next_score - next_beam_tokens[batch_idx, beam_idx] = next_token - next_beam_indices[batch_idx, beam_idx] = batch_beam_idx - beam_idx += 1 - - # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.num_beams: - break - - if beam_idx < self.num_beams: - raise ValueError( - f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." - ) - - # Check if we are done so that we can save a pad step if all(done) - self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( - next_scores[batch_idx].max().item(), cur_len - ) - - return UserDict( - { - "next_beam_scores": next_beam_scores.view(-1), - "next_beam_tokens": next_beam_tokens.view(-1), - "next_beam_indices": next_beam_indices.view(-1), - } - ) - - def finalize( - self, - input_ids: torch.LongTensor, - final_beam_scores: torch.FloatTensor, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - mems=None - ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]: - batch_size = len(self._beam_hyps) - # finalize all open beam hypotheses and add to generated hypotheses - for batch_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_idx]: - continue - - # need to add best num_beams hypotheses to generated hyps - for beam_id in range(self.num_beams): - batch_beam_idx = batch_idx * self.num_beams + beam_id - final_score = final_beam_scores[batch_beam_idx].item() - final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None) - - # select the best hypotheses - sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) - best = [] - - # retrieve best hypotheses - for i, beam_hyp in enumerate(self._beam_hyps): - sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) - for j in range(self.num_beam_hyps_to_keep): - score, best_hyp, mems = sorted_hyps.pop() - sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - best.append((best_hyp, mems, score)) - - # prepare for adding eos - sent_max_len = min(sent_lengths.max().item(), self.max_length) - decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) - scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep) - # shorter batches are padded if needed - if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`pad_token_id` has to be defined" - decoded.fill_(pad_token_id) - - # fill with hypotheses and eos_token_id if the latter fits in - mems = [] - for i, (hypo, mem, score) in enumerate(best): - scores[i] = score - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < sent_max_len: - decoded[i, sent_lengths[i]] = eos_token_id - mems.append(mem) - mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None - return decoded, mems, scores - - -class BeamHypotheses: - def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): - """ - Initialize n-best list of hypotheses. - """ - self.max_length = max_length - 1 # ignoring bos_token - self.length_penalty = length_penalty - self.early_stopping = early_stopping - self.num_beams = num_beams - self.beams = [] - self.worst_score = 1e9 - - def __len__(self): - """ - Number of hypotheses in the list. - """ - return len(self.beams) - - def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): - """ - Add a new hypothesis to the list. - """ - score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty) - if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp, mems)) - if len(self) > self.num_beams: - sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) - del self.beams[sorted_next_scores[0][1]] - self.worst_score = sorted_next_scores[1][0] - else: - self.worst_score = min(score, self.worst_score) - - def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: - """ - If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst - one in the heap, then we are done with this sentence. - """ - - if len(self) < self.num_beams: - return False - elif self.early_stopping: - return True - else: - cur_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= cur_score - return ret - - -class LogitsProcessor(ABC): - """Abstract base class for all logit processors that can be applied during generation.""" - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - """Torch method for processing logits.""" - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." - ) - - -class LogitsProcessorList(list): - """ - This class can be used to create a list of :class:`~transformers.LogitsProcessor` or - :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from - list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or - :class:`~transformers.LogitsProcessor` to the inputs. - """ - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - for processor in self: - scores = processor(input_ids, scores) - return scores - - -class MinLengthLogitsProcessor(LogitsProcessor): - r""" - :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. - - Args: - min_length (:obj:`int`): - The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. - eos_token_id (:obj:`int`): - The id of the `end-of-sequence` token. - """ - - def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]): - if not isinstance(min_length, int) or min_length < 0: - raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") - if isinstance(eos_token_ids, int): - eos_token_ids = [eos_token_ids] - for eos_token_id in eos_token_ids: - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") - - self.min_length = min_length - self.eos_token_ids = eos_token_ids - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cur_len = input_ids.shape[-1] - if cur_len < self.min_length: - for eos_token_id in self.eos_token_ids: - scores[:, eos_token_id] = -float("inf") - return scores - - -class NoRepeatNGramLogitsProcessor(LogitsProcessor): - r""" - :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq - <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. - - Args: - ngram_size (:obj:`int`): - All ngrams of size :obj:`ngram_size` can only occur once. - """ - - def __init__(self, ngram_size: int): - if not isinstance(ngram_size, int) or ngram_size <= 0: - raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") - self.ngram_size = ngram_size - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - num_batch_hypotheses = scores.shape[0] - cur_len = input_ids.shape[-1] - banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len) - - for i, banned_tokens in enumerate(banned_batch_tokens): - scores[i, banned_tokens] = -float("inf") - - return scores - - def _calc_banned_ngram_tokens( - self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int - ) -> List[Iterable[int]]: - """Copied from fairseq for no_repeat_ngram in beam_search""" - if cur_len + 1 < self.ngram_size: - # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet - return [[] for _ in range(num_hypos)] - generated_ngrams = [{} for _ in range(num_hypos)] - for idx in range(num_hypos): - gen_tokens = prev_input_ids[idx].tolist() - generated_ngram = generated_ngrams[idx] - for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]): - prev_ngram_tuple = tuple(ngram[:-1]) - generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] - - def _get_generated_ngrams(hypo_idx): - # Before decoding the next token, prevent decoding of ngrams that have already appeared - start_idx = cur_len + 1 - self.ngram_size - ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) - return generated_ngrams[hypo_idx].get(ngram_idx, []) - - banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] - return banned_tokens - - -class BeamSearchStrategy: - def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0, - min_tgt_length=0): - self.num_beams = num_beams - self.max_length = max_length - self.length_penalty = length_penalty - self.end_tokens = end_tokens - self.no_repeat_ngram_size = no_repeat_ngram_size - self.min_tgt_length = min_tgt_length - self.processors = LogitsProcessorList() - if min_tgt_length > 0: - processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens) - self.processors.append(processor) - if no_repeat_ngram_size > 0: - processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size) - self.processors.append(processor) - self.beam_scorer = BeamSearchScorer( - batch_size=1, - max_length=max_length, - num_beams=num_beams, - device=device, - length_penalty=length_penalty, - do_early_stopping=False, - ) - self.beam_scores = torch.zeros(1, dtype=torch.float, device=device) - - @property - def is_done(self) -> bool: - return self.beam_scorer.is_done - - def forward(self, logits, tokens, mems): - last_beam_num = tokens.size(0) - logits = self.processors(tokens, logits.float()) - next_token_scores = F.log_softmax(logits, dim=-1) - next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores) - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size) - - probs = F.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) - next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) - next_tokens = torch.gather(next_tokens, -1, _indices) - - next_indices = next_tokens // vocab_size - next_tokens = next_tokens % vocab_size - # stateless - tokens = tokens.expand((self.num_beams, -1)) - beam_outputs = self.beam_scorer.process( - tokens, - next_token_scores, - next_tokens, - next_indices, - eos_token_id=self.end_tokens, - mems=mems - ) - self.beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - beam_next_tokens = beam_next_tokens.unsqueeze(-1) - tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1) - mems = [mem[beam_idx] for mem in mems] if mems else None - return tokens, mems - - def finalize(self, tokens, mems): - tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores, - eos_token_id=self.end_tokens[0], - mems=mems) - return tokens, mems diff --git a/SwissArmyTransformer/generation/sampling_strategies/iterative_entfilter_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/iterative_entfilter_strategy.py index 9017a42a3c4dc5092be7c57f64b81e1ef0f8e0bf..7b16ef1d14fabb9653c4f7df0548fea6d8f537e6 100644 --- a/SwissArmyTransformer/generation/sampling_strategies/iterative_entfilter_strategy.py +++ b/SwissArmyTransformer/generation/sampling_strategies/iterative_entfilter_strategy.py @@ -3,7 +3,7 @@ @File : iterative_entfilter_strategy.py @Time : 2021/10/09 14:32:29 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/generation/utils.py b/SwissArmyTransformer/generation/utils.py index 01c05726a682e2245575517ba18d3ca27deea9c7..7d947dbf3e774a14e9343657c38d5ee9fb4d5dc2 100644 --- a/SwissArmyTransformer/generation/utils.py +++ b/SwissArmyTransformer/generation/utils.py @@ -3,7 +3,7 @@ @File : utils.py @Time : 2021/10/09 17:18:26 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/model/__init__.py b/SwissArmyTransformer/model/__init__.py index c9546c8176b1c11320e764c2444c4f2e7f4e6965..32f46e4bce6d09b3f4780f9adbef6120960158be 100755 --- a/SwissArmyTransformer/model/__init__.py +++ b/SwissArmyTransformer/model/__init__.py @@ -2,3 +2,4 @@ from .base_model import BaseModel from .cached_autoregressive_model import CachedAutoregressiveModel from .cuda2d_model import Cuda2dModel from .glm_model import GLMModel +from .encoder_decoder_model import EncoderDecoderModel \ No newline at end of file diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py index 4a19cb174c4980d1980c49bdf801fd3a6b63e4ec..c9e1c9017782b073546fa12423690b71888a51b0 100644 --- a/SwissArmyTransformer/model/base_model.py +++ b/SwissArmyTransformer/model/base_model.py @@ -3,7 +3,7 @@ @File : base_model.py @Time : 2021/10/01 22:40:33 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py index 769ff57d3ae829d9c307747ec8b961634c977d4c..a0e1699e923e9002e7a46ac69692ae7934032a57 100755 --- a/SwissArmyTransformer/model/cached_autoregressive_model.py +++ b/SwissArmyTransformer/model/cached_autoregressive_model.py @@ -3,7 +3,7 @@ @File : cached_autoregressive_model.py @Time : 2021/10/02 01:36:24 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/model/common_layers.py b/SwissArmyTransformer/model/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..90d24334511192b727ed0a0b4eeff588b8fdb960 --- /dev/null +++ b/SwissArmyTransformer/model/common_layers.py @@ -0,0 +1,91 @@ +# -*- encoding: utf-8 -*- +''' +@File : components.py +@Time : 2021/11/23 18:20:22 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim +from SwissArmyTransformer.mpu.transformer import standard_attention, LayerNorm + +class CrossAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, + attention_dropout_prob, output_dropout_prob, + init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None): + super(CrossAttention, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + if inner_hidden_size is None: + inner_hidden_size = hidden_size + self.inner_hidden_size = inner_hidden_size + if enc_hidden_size is None: + enc_hidden_size = hidden_size + self.enc_hidden_size = enc_hidden_size + + # To make user understand better, temporally not support model parallel + world_size = 1 + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) + self.num_attention_heads_per_partition = divide(num_attention_heads, world_size) + + # To map encoder outputs + self.kv_linear = torch.nn.Linear( + enc_hidden_size, inner_hidden_size * 2 + ) + init_method(self.kv_linear.weight) + + # To map self + self.q_linear = torch.nn.Linear( + hidden_size, inner_hidden_size + ) + init_method(self.q_linear.weight) + + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + self.dense = torch.nn.Linear( + inner_hidden_size, + hidden_size, + ) + output_layer_init_method(self.dense.weight) + self.output_dropout = torch.nn.Dropout(output_dropout_prob) + + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, mask, encoder_outputs, **kw_args): + + query_layer = self.q_linear(hidden_states) + key_layer, value_layer = split_tensor_along_last_dim(self.kv_linear(encoder_outputs), 2) + + dropout_fn = self.attention_dropout if self.training else None + + query_layer = self._transpose_for_scores(query_layer) + key_layer = self._transpose_for_scores(key_layer) + value_layer = self._transpose_for_scores(value_layer) + + context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + output = self.dense(context_layer) + + if self.training: + output = self.output_dropout(output) + + return output diff --git a/SwissArmyTransformer/model/cuda2d_model.py b/SwissArmyTransformer/model/cuda2d_model.py index ab3411f9d6b7bebe2c1913d1a16da76ab3821197..cda027c5d8bcdb46cb0d343f4c0a0338dfccb5e0 100644 --- a/SwissArmyTransformer/model/cuda2d_model.py +++ b/SwissArmyTransformer/model/cuda2d_model.py @@ -3,7 +3,7 @@ @File : cuda2d_model.py @Time : 2021/10/02 01:36:32 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e70245368e68e6d43ee06d97a959868194b712d6 --- /dev/null +++ b/SwissArmyTransformer/model/encoder_decoder_model.py @@ -0,0 +1,144 @@ +# -*- encoding: utf-8 -*- +''' +@File : encoder_decoder_model.py +@Time : 2021/11/22 23:35:28 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import argparse +from .base_model import BaseModel, BaseMixin +from .common_layers import CrossAttention, LayerNorm + + +class CrossAttentionMixin(BaseMixin): + def __init__(self, num_layers, hidden_size, num_attention_heads, + attention_dropout_prob, output_dropout_prob, + init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None): + super().__init__() + + self.cross_attentions = torch.nn.ModuleList( + [CrossAttention( + hidden_size, num_attention_heads, + attention_dropout_prob, output_dropout_prob, + init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size, + output_layer_init_method=output_layer_init_method + ) for layer_id in range(num_layers)] + ) # Just copy args + self.cross_lns = torch.nn.ModuleList( + [LayerNorm(hidden_size, 1e-5) + for layer_id in range(num_layers)] + ) + + + def layer_forward(self, hidden_states, mask, layer_id, **kw_args): + layer = self.transformer.layers[layer_id] + encoder_outputs = kw_args['encoder_outputs'] + ''' + hidden_states: [batch, seq_len, hidden_size] + mask: [(1, 1), seq_len, seq_len] + encoder_outputs: [batch, enc_seq_len, enc_hidden_size] + ''' + # Layer norm at the begining of the transformer layer. + layernorm_output = layer.input_layernorm(hidden_states) + attention_output, output_this_layer = layer.attention(layernorm_output, mask, **kw_args) + # Third LayerNorm + if layer.sandwich_ln: + attention_output = layer.third_layernorm(attention_output) + # Residual connection. + hidden_states = hidden_states + attention_output + + # Cross attention. + layernorm_output = self.cross_lns[layer_id](hidden_states) + cross_attn_output = self.cross_attentions[layer_id]( + layernorm_output, + torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype), + encoder_outputs + ) + hidden_states = hidden_states + cross_attn_output + + # Layer norm post the layer attention. + layernorm_output = layer.post_attention_layernorm(hidden_states) + # MLP. + mlp_output = layer.mlp(layernorm_output, **kw_args) + + # Fourth LayerNorm + if layer.sandwich_ln: + mlp_output = layer.fourth_layernorm(mlp_output) + output = hidden_states + mlp_output + + return output, output_this_layer + + +class DecoderModel(BaseModel): + def __init__(self, args, transformer=None): + dec_args = argparse.Namespace(**vars(args)) + dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn + override_attrs = ['num_layers', 'vocab_size', + 'hidden_size', 'num_attention_heads', + 'max_sequence_length', 'sandwich_ln' # TODO + ] + for name in override_attrs: + dec_attr = getattr(dec_args, 'dec_' + name, None) + if dec_attr is not None: # else use encoder-config + setattr(dec_args, name, dec_attr) + + super().__init__(dec_args, transformer=transformer) + self.add_mixin('cross_attention', + CrossAttentionMixin( + dec_args.num_layers, + dec_args.hidden_size, dec_args.num_attention_heads, + dec_args.attention_dropout, dec_args.hidden_dropout, + self.transformer.init_method, + enc_hidden_size=dec_args.enc_hidden_size, + inner_hidden_size=getattr(dec_args, 'dec_inner_hidden_size', None), + output_layer_init_method=self.transformer.output_layer_init_method + ) + ) + +class EncoderDecoderModel(torch.nn.Module): + def __init__(self, args, encoder=None, decoder=None): + super(EncoderDecoderModel, self).__init__() + if encoder is not None: + assert isinstance(encoder, BaseModel) + self.encoder = encoder + else: + self.encoder = BaseModel(args) + + if decoder is not None: + assert isinstance(decoder, BaseModel) + self.decoder = decoder + else: + self.decoder = DecoderModel(args) + + def reinit(self): + self.encoder.reinit() + self.decoder.reinit() + + def disable_untrainable_params(self): + self.encoder.disable_untrainable_params() + self.decoder.disable_untrainable_params() + + def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, branch_input=None, **kw_args): + mask_one = torch.ones(1, 1, device=enc_input_ids.device, dtype=dec_attention_mask.dtype) + enc_outputs, *_dumps = self.encoder(enc_input_ids, enc_position_ids, mask_one, branch_input=branch_input, **kw_args) + dec_outputs, *dec_mems = self.decoder(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=enc_outputs, branch_input=branch_input, **kw_args) + return enc_outputs, dec_outputs, *dec_mems + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('EncoderDecoderModel', 'T5 or Bart') + group.add_argument("--dec_num_layers", type=int, default=None) + group.add_argument("--dec_vocab_size", type=int, default=None) + group.add_argument("--dec_hidden_size", type=int, default=None) + group.add_argument("--dec_num_attention_heads", type=int, default=None) + group.add_argument("--dec_max_sequence_length", type=int, default=None) + group.add_argument("--dec_sandwich_ln", action='store_true') + group.add_argument("--dec_inner_hidden_size", type=int, default=None) + return parser \ No newline at end of file diff --git a/SwissArmyTransformer/model/mixins.py b/SwissArmyTransformer/model/mixins.py index 493e412d82f53ccfa748b10ed56afa2539bb5311..2a76b8038265ff00b7dee078df70b169ebb34105 100644 --- a/SwissArmyTransformer/model/mixins.py +++ b/SwissArmyTransformer/model/mixins.py @@ -3,7 +3,7 @@ @File : mixins.py @Time : 2021/10/01 17:52:40 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index cdff797d197da83d71bd968e3bb02b71f7fa85b5..3764f0552b1c4163536b061700b8d5d7b6e3ff79 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -51,9 +51,10 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask, ) if log_attention_weights is not None: attention_scores += log_attention_weights - - # if attention_mask.shape[-2] > 1: # if auto-regressive, skip - attention_scores = torch.mul(attention_scores, attention_mask) - \ + + if not(attention_mask.shape[-2] == 1 and (attention_mask > 0).all()): + # if auto-regressive, skip + attention_scores = torch.mul(attention_scores, attention_mask) - \ 10000.0 * (1.0 - attention_mask) attention_probs = F.softmax(attention_scores, dim=-1) @@ -140,8 +141,9 @@ class SelfAttention(torch.nn.Module): class MLP(torch.nn.Module): def __init__(self, hidden_size, output_dropout_prob, init_method, - output_layer_init_method=None, hooks={}): + output_layer_init_method=None, layer_id=None, hooks={}): super(MLP, self).__init__() + self.layer_id = layer_id # Set output layer initialization if not provided. if output_layer_init_method is None: output_layer_init_method = init_method @@ -224,6 +226,7 @@ class BaseTransformerLayer(torch.nn.Module): output_dropout_prob, init_method, output_layer_init_method=output_layer_init_method, + layer_id=layer_id, hooks=hooks ) diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py index 5ee929afa5d2e6445edfa3a948f80098f6ec6c1f..fa93f3fd575bb4049c0b3bbbddd7d66e196d1e2b 100644 --- a/SwissArmyTransformer/tokenization/__init__.py +++ b/SwissArmyTransformer/tokenization/__init__.py @@ -3,7 +3,7 @@ @File : __init__.py @Time : 2021/10/06 17:58:04 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib @@ -47,13 +47,14 @@ def get_tokenizer(args=None, outer_tokenizer=None): return outer_tokenizer if not hasattr(get_tokenizer, 'tokenizer'): # the first time to load the tokenizer - if args.tokenizer_type == 'cogview': + if args.tokenizer_type.startswith('cogview'): # or cogview_ICE from .cogview import UnifiedTokenizer get_tokenizer.tokenizer = UnifiedTokenizer( args.img_tokenizer_path, + # txt_tokenizer_type=args.tokenizer_type, device=torch.cuda.current_device() ) - elif args.tokenizer_type.startswith('glm_'): + elif args.tokenizer_type.startswith('glm'): kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask, "add_decoder_mask": args.block_mask_prob > 0.0} if args.tokenizer_type == "glm_GPT2BPETokenizer": diff --git a/SwissArmyTransformer/tokenization/cogview/__init__.py b/SwissArmyTransformer/tokenization/cogview/__init__.py index 4df817a4fc7d2d474303795a668bf6f7e9c67699..429c4140baaa6de2f1ccd3a707f2640cceb7612d 100644 --- a/SwissArmyTransformer/tokenization/cogview/__init__.py +++ b/SwissArmyTransformer/tokenization/cogview/__init__.py @@ -3,7 +3,7 @@ @File : __init__.py @Time : 2021/10/06 18:21:15 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py index 24cb9c658af091ee4ec0a528d4f26ecb8f74059a..c7860e3a9883811732a968b168894b2a668f29e4 100644 --- a/SwissArmyTransformer/training/deepspeed_training.py +++ b/SwissArmyTransformer/training/deepspeed_training.py @@ -120,7 +120,6 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio with ExitStack() as stack: def save_on_exit(args_, model_, optimizer_, lr_scheduler_): save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_) - iteration, skipped = train(model, optimizer, lr_scheduler, train_data_iterator, @@ -131,7 +130,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio if args.do_valid: prefix = 'the end of training for val data' val_loss = evaluate_and_print_results(prefix, val_data_iterator, - model, args, timers, False) + model, args, timers, False, hooks=hooks) # final save if args.save and iteration != 0: # TODO save @@ -141,7 +140,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio if args.do_test and test_data is not None: prefix = 'the end of training for test data' evaluate_and_print_results(prefix, iter(test_data), - model, args, timers, True) + model, args, timers, True, hooks=hooks) def get_model(args, model_cls): @@ -477,6 +476,8 @@ def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, summary_writer.add_scalar(f'Train/lr', lr, step) summary_writer.add_scalar(f'Train/train_loss', loss, step) summary_writer.add_scalar(f'Train/elapsed_time', elapsed_time, step) + for key in avg_metrics: + summary_writer.add_scalar('Train/'+key, avg_metrics[key], step) def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step): diff --git a/SwissArmyTransformer/training/model_io.py b/SwissArmyTransformer/training/model_io.py index d66e40e4926281f067bab4a3f841e30e30e89679..b423e613b4ec849939a0b686b01987525ad0f35a 100644 --- a/SwissArmyTransformer/training/model_io.py +++ b/SwissArmyTransformer/training/model_io.py @@ -3,7 +3,7 @@ @File : model_io.py @Time : 2021/10/05 18:39:55 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/env/setup_connection.py b/env/setup_connection.py index 20ee88e07258283d6cf2024c9b81f33c26809821..94f6eb1d3eb531bf565753fc2111523b02a1961a 100644 --- a/env/setup_connection.py +++ b/env/setup_connection.py @@ -3,7 +3,7 @@ @File : setup_connection.py @Time : 2021/01/16 16:50:36 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/examples/cogview/inference_cogview.py b/examples/cogview/inference_cogview.py index f546e37eb561aaf12270ef57c882a2f965f1a4fa..e5c506d8b818fa16ad29a3b4907e1af31d575e80 100644 --- a/examples/cogview/inference_cogview.py +++ b/examples/cogview/inference_cogview.py @@ -3,7 +3,7 @@ @File : inference_cogview.py @Time : 2021/10/09 19:41:58 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/examples/cogview/inference_cogview_caps.py b/examples/cogview/inference_cogview_caps.py index 081be09e13d2e027c58bc304cffc9e54578637d3..fc221ba2c9badfc1dc55e0d740205021b680218b 100644 --- a/examples/cogview/inference_cogview_caps.py +++ b/examples/cogview/inference_cogview_caps.py @@ -3,7 +3,7 @@ @File : inference_cogview.py @Time : 2021/10/09 19:41:58 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/examples/cogview/pretrain_gpt2.py b/examples/cogview/pretrain_gpt2.py index bfe88df492c2f089eb1593f2ea5992fe32600148..c00fe2492764c7172f88d989cba899e270fa8703 100755 --- a/examples/cogview/pretrain_gpt2.py +++ b/examples/cogview/pretrain_gpt2.py @@ -3,7 +3,7 @@ @File : pretrain_gpt2.py @Time : 2021/10/06 00:58:32 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib @@ -99,7 +99,6 @@ def forward_step(data_iterator, model, args, timers): losses = losses.view(-1) * loss_mask loss = torch.sum(losses) / loss_mask.sum() - return loss, {} def create_dataset_function(path, args): diff --git a/examples/cogview2/inference_cogview2.py b/examples/cogview2/inference_cogview2.py index 9a423d3d7daa08c3a65c88a8cb38d10f952c1e03..bf32e8287cc176a4eb58987729a6c8f9cfe9e4e6 100644 --- a/examples/cogview2/inference_cogview2.py +++ b/examples/cogview2/inference_cogview2.py @@ -3,7 +3,7 @@ @File : inference_cogview2.py @Time : 2021/10/10 16:31:34 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/examples/cogview2/pretrain_cogview2.py b/examples/cogview2/pretrain_cogview2.py index c9f789cf8229adf4eba1b05f2bf3c33281102927..da4dc22cded477861c7cfd38d35a65741efffea5 100755 --- a/examples/cogview2/pretrain_cogview2.py +++ b/examples/cogview2/pretrain_cogview2.py @@ -3,7 +3,7 @@ @File : pretrain_cogview2.py @Time : 2021/10/06 00:58:32 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib diff --git a/examples/glm/inference_glm.py b/examples/glm/inference_glm.py index a1f5b832d5e04dcda0c3a94e787c38b07c0a08d1..ec4bd1af3bf224e4a77cefc316a0ef7e29201f7a 100644 --- a/examples/glm/inference_glm.py +++ b/examples/glm/inference_glm.py @@ -3,7 +3,7 @@ @File : inference_glm.py @Time : 2021/10/22 19:41:58 @Author : Ming Ding -@Contact : dm18@mail.tsinghua.edu.cn +@Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib @@ -23,7 +23,7 @@ from SwissArmyTransformer import mpu, get_args, get_tokenizer, load_checkpoint, from SwissArmyTransformer.model import GLMModel from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin -from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence +from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy from SwissArmyTransformer.generation.utils import timed_name, generate_continually diff --git a/setup.py b/setup.py index 8bee4e1037329491ea901be8d140b42b14a73315..636894d95dcc22a78c3c6e44697609ce9ece39e6 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def _requirements(): setup( name="SwissArmyTransformer", - version='0.1.1', + version='0.1.2', description="A transformer-based framework with finetuning as the first class citizen.", long_description=Path("README.md").read_text(), long_description_content_type="text/markdown",