Commit 8db7ecd7 authored by Ming Ding

tmp_save

parent 6a2fa408
@@ -7,6 +7,7 @@
 '''
 # here put the import lib
+from functools import partial
 import os
 import sys
 import math
@@ -14,6 +15,11 @@ import random
 import torch
 from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu.transformer import standard_attention
+
+def non_conflict(func):
+    func.non_conflict = True
+    return func
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
@@ -23,9 +29,22 @@ class BaseMixin(torch.nn.Module):
     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
-    # can also define hook-functions here
+    # can define hook-functions here
     # ...
+    # If the hook is just a pre- or post- transformation,
+    # you can use @non_conflict to mark it,
+    # and call `old_impl` to make it compatible with other mixins.
+    # E.g.,
+    #
+    # @non_conflict
+    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs):
+    #     new_q, new_k, new_v = pre_hack(q, k, v)
+    #     attn_result = old_impl(new_q, new_k, new_v, mask, dropout_fn, **kwargs)
+    #     attn_result = post_hack(attn_result)
+    #     return attn_result
 
 class BaseModel(torch.nn.Module):
     def __init__(self, args, transformer=None, **kwargs):
@@ -87,17 +106,24 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                  'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                 'branch_embedding_forward', 'branch_final_forward'
+                 'branch_embedding_forward', 'branch_final_forward',
+                 'attention_fn'
                  ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
-                    if name in hooks: # conflict
-                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
-                    hooks[name] = getattr(m, name)
-                    hook_origins[name] = mixin_name
+                    if name in hooks: # this hook name is already registered
+                        if hasattr(getattr(m, name), 'non_conflict'):
+                            hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
+                            hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
+                        else: # conflict
+                            raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
+                    else: # new hook
+                        hooks[name] = getattr(m, name)
+                        hook_origins[name] = mixin_name
             if hasattr(self, name):
                 # if name in hooks: # defined in mixins, can override
                 #     print(f'Override {name} in {hook_origins[name]}...')
...
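To make the chaining concrete, here is a minimal, self-contained sketch of the behaviour that collect_hooks_ implements for @non_conflict hooks. ScaleQueryMixin, LogMixin and base_attention are hypothetical stand-ins (real hooks take the full attention_fn signature); only the chaining rule mirrors the code above.

from functools import partial

def non_conflict(func):
    func.non_conflict = True
    return func

def base_attention(q, k, v):
    # stand-in for standard_attention
    return 'attn({}, {}, {})'.format(q, k, v)

class ScaleQueryMixin:
    @non_conflict
    def attention_fn(self, q, k, v, old_impl=base_attention):
        # pre-transformation: tweak q, then delegate to the previous implementation
        return old_impl(q * 2, k, v)

class LogMixin:
    @non_conflict
    def attention_fn(self, q, k, v, old_impl=base_attention):
        # post-transformation: inspect the result of the previous implementation
        out = old_impl(q, k, v)
        print('attention output:', out)
        return out

# the same chaining rule as collect_hooks_ above: a later mixin wraps the earlier one
hooks, hook_origins = {}, {}
for mixin_name, m in [('scale', ScaleQueryMixin()), ('log', LogMixin())]:
    fn = m.attention_fn
    if 'attention_fn' in hooks and getattr(fn, 'non_conflict', False):
        hooks['attention_fn'] = partial(fn, old_impl=hooks['attention_fn'])
        hook_origins['attention_fn'] = mixin_name + ' -> ' + hook_origins['attention_fn']
    else:
        hooks['attention_fn'] = fn
        hook_origins['attention_fn'] = mixin_name

print(hook_origins['attention_fn'])    # log -> scale
print(hooks['attention_fn'](1, 2, 3))  # attn(2, 2, 3): LogMixin wraps ScaleQueryMixin wraps base_attention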
@@ -13,7 +13,7 @@ import math
 import random
 import torch
-from .base_model import BaseModel, BaseMixin
+from .base_model import BaseModel, BaseMixin, non_conflict
 from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
 
 class CachedAutoregressiveMixin(BaseMixin):
...
@@ -44,7 +44,7 @@ class LayerNorm(FusedLayerNorm):
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
     # We disable the PB-relax-Attention and only change the order of computation, because that is enough for most training.
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.
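For readers unfamiliar with the trick the comment refers to, the sketch below illustrates one reading of "changing the order of computation": scaling the query by 1/sqrt(d) before the Q·K^T matmul rather than scaling the score matrix afterwards, which keeps the intermediate scores smaller in fp16. This is an assumption tied to the scaling_attention_score flag, not a copy of the real standard_attention.

import math
import torch

def attention_scores(query, key, scale_before=True):
    d = query.size(-1)
    if scale_before:
        # scale Q first, so the matmul already produces scores of moderate magnitude
        return torch.matmul(query / math.sqrt(d), key.transpose(-1, -2))
    # naive order: matmul first, then scale the (potentially large) raw scores
    return torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d)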
@@ -124,6 +124,10 @@ class SelfAttention(torch.nn.Module):
         if 'attention_forward' in self.hooks:
             return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
              mixed_key_layer,
@@ -135,7 +139,8 @@ class SelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
-            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn,
+                                         layer_id=self.layer_id, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.view(*new_context_layer_shape)
@@ -198,7 +203,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
...
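Finally, a hypothetical mixin that matches the new call site in SelfAttention.forward, which invokes the 'attention_fn' hook as attention_fn(q, k, v, mask, dropout_fn, layer_id=..., **kw_args). The class name, the num_layers argument, the 0-based layer_id assumption, and the assumption that standard_attention skips dropout when it receives None (its default) are illustrative, not taken from the commit.

from SwissArmyTransformer.mpu.transformer import standard_attention
# BaseMixin and non_conflict as defined in base_model above; adjust the import
# path to wherever base_model.py lives in this repository.
from .base_model import BaseMixin, non_conflict

class NoLastLayerDropoutMixin(BaseMixin):
    """Illustrative only: disable attention dropout in the last layer."""
    def __init__(self, num_layers):
        super().__init__()
        self.num_layers = num_layers

    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn,
                     old_impl=standard_attention, layer_id=None, **kw_args):
        if layer_id == self.num_layers - 1:  # assuming layer_id counts from 0
            dropout_fn = None                # assuming None means "no dropout" in old_impl
        return old_impl(q, k, v, mask, dropout_fn, **kw_args)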