From 8db7ecd79fcaa578fd02c2b0521b5484041c7b6e Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Sun, 5 Dec 2021 11:42:28 +0000
Subject: [PATCH] tmp_save

---
 SwissArmyTransformer/model/base_model.py     | 38 ++++++++++++++++---
 .../model/cached_autoregressive_model.py     |  2 +-
 SwissArmyTransformer/mpu/transformer.py      | 11 ++++--
 3 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 0b593bb..1b0569c 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -7,6 +7,7 @@
 '''
 
 # here put the import lib
+from functools import partial
 import os
 import sys
 import math
@@ -14,6 +15,11 @@ import random
 import torch
 
 from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu.transformer import standard_attention
+
+def non_conflict(func):
+    func.non_conflict = True
+    return func
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
@@ -23,9 +29,22 @@ class BaseMixin(torch.nn.Module):
     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
-    # can also define hook-functions here
+
+    # can define hook-functions here
     # ...
 
+    # If the hook is just a pre- or post- transformation,
+    # you can use @non_conflict to mark it,
+    # and run `old_impl` to make it compatible with other mixins.
+    # E.g.,
+    #
+    # @non_conflict
+    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs):
+    #     new_q, new_k, new_v = pre_hack(q, k, v)
+    #     attn_result = old_impl(new_q, new_k, new_v, mask, dropout_fn, **kwargs)
+    #     attn_result = post_hack(attn_result)
+    #     return attn_result
+
 
 class BaseModel(torch.nn.Module):
     def __init__(self, args, transformer=None, **kwargs):
@@ -87,17 +106,24 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                  'attention_forward', 'cross_attention_forward', 'mlp_forward',
                  'final_forward', 'layer_forward',
-                 'branch_embedding_forward', 'branch_final_forward'
+                 'branch_embedding_forward', 'branch_final_forward',
+                 'attention_fn'
                  ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
-                    if name in hooks: # conflict
-                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
-                    hooks[name] = getattr(m, name)
-                    hook_origins[name] = mixin_name
+                    if name in hooks: # if this hook name is already registered
+                        if hasattr(getattr(m, name), 'non_conflict'):
+                            hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
+                            hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
+                        else: # conflict
+                            raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
+                    else: # new hook
+                        hooks[name] = getattr(m, name)
+                        hook_origins[name] = mixin_name
+
             if hasattr(self, name): # if name in hooks: # defined in mixins, can override
                 # print(f'Override {name} in {hook_origins[name]}...')
diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index a0e1699..296a811 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -13,7 +13,7 @@ import math
 import random
 import torch
 
-from .base_model import BaseModel, BaseMixin
+from .base_model import BaseModel, BaseMixin, non_conflict
 from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
 
 class CachedAutoregressiveMixin(BaseMixin):
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index e9f7ab1..229e21a 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -44,7 +44,7 @@ class LayerNorm(FusedLayerNorm):
 
 
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training.
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.
 
@@ -124,6 +124,10 @@ class SelfAttention(torch.nn.Module):
         if 'attention_forward' in self.hooks:
             return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
              mixed_key_layer,
@@ -135,7 +139,8 @@ class SelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, **kw_args)
+
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.view(*new_context_layer_shape)
@@ -198,7 +203,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
-- 
GitLab
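
Reviewer note: below is a minimal usage sketch of the attention_fn hook introduced by this patch. The ScaledAttentionMixin class and its temperature argument are illustrative placeholders, not part of the patch; only BaseMixin, non_conflict, and standard_attention come from the modified files.

    # Hypothetical example (not part of the patch): a mixin that pre-scales the
    # queries, then defers to the previous attention implementation.
    from SwissArmyTransformer.model.base_model import BaseMixin, non_conflict
    from SwissArmyTransformer.mpu.transformer import standard_attention

    class ScaledAttentionMixin(BaseMixin):
        def __init__(self, temperature=1.0):
            super().__init__()
            self.temperature = temperature  # illustrative parameter, not from the patch

        @non_conflict
        def attention_fn(self, q, k, v, mask, dropout_fn,
                         old_impl=standard_attention, **kw_args):
            # Pre-transformation, then fall through to `old_impl`, which is either
            # standard_attention or the attention_fn of a previously registered
            # mixin (collect_hooks_ chains them via functools.partial).
            q = q / self.temperature
            return old_impl(q, k, v, mask, dropout_fn, **kw_args)

Because the hook is marked with @non_conflict, registering this mixin alongside another mixin that also defines attention_fn composes the two (new implementation wrapping the old) instead of raising the hook-conflict ValueError in collect_hooks_; the extra layer_id keyword passed by SelfAttention.forward is absorbed by **kw_args and, after this patch, by the **kwargs of standard_attention.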