From 8db7ecd79fcaa578fd02c2b0521b5484041c7b6e Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Sun, 5 Dec 2021 11:42:28 +0000
Subject: [PATCH] Support non-conflicting hooks via @non_conflict and add the attention_fn hook

---
 SwissArmyTransformer/model/base_model.py      | 38 ++++++++++++++++---
 .../model/cached_autoregressive_model.py      |  2 +-
 SwissArmyTransformer/mpu/transformer.py       | 11 ++++--
 3 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 0b593bb..1b0569c 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -7,6 +7,7 @@
 '''
 
 # here put the import lib
+from functools import partial
 import os
 import sys
 import math
@@ -14,6 +15,11 @@ import random
 import torch
 
 from SwissArmyTransformer.mpu import BaseTransformer
+from SwissArmyTransformer.mpu.transformer import standard_attention
+
+def non_conflict(func):
+    func.non_conflict = True
+    return func
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
@@ -23,9 +29,22 @@ class BaseMixin(torch.nn.Module):
     def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
-    # can also define hook-functions here
+
+    # can define hook-functions here
     # ...
 
+    # If the hook is just a pre- or post-transformation,
+    # you can mark it with @non_conflict
+    # and call `old_impl` inside it to stay compatible with other mixins.
+    # E.g.,
+    #
+    # @non_conflict
+    # def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs):
+    #     new_q, new_k, new_v = pre_hack(q, k, v)
+    #     attn_result = old_impl(new_q, new_k, new_v, mask, dropout_fn, **kwargs)
+    #     attn_result = post_hack(attn_result)
+    #     return attn_result
+
 
 class BaseModel(torch.nn.Module):
     def __init__(self, args, transformer=None, **kwargs):
@@ -87,17 +106,24 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                  'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                 'branch_embedding_forward', 'branch_final_forward'
+                 'branch_embedding_forward', 'branch_final_forward',
+                 'attention_fn'
                  ]
         hooks = {}
         hook_origins = {}
         for name in names:
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
-                    if name in hooks:  # conflict
-                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
-                    hooks[name] = getattr(m, name)
-                    hook_origins[name] = mixin_name
+                    if name in hooks: # if this hook name is already registered
+                        if hasattr(getattr(m, name), 'non_conflict'):
+                            hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
+                            hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
+                        else: # conflict
+                            raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
+                    else: # new hook
+                        hooks[name] = getattr(m, name)
+                        hook_origins[name] = mixin_name
+
             if hasattr(self, name):
                 # if name in hooks: # defined in mixins, can override
                 #     print(f'Override {name} in {hook_origins[name]}...')
diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index a0e1699..296a811 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -13,7 +13,7 @@ import math
 import random
 import torch
 
-from .base_model import BaseModel, BaseMixin
+from .base_model import BaseModel, BaseMixin, non_conflict
 from SwissArmyTransformer.mpu.transformer import standard_attention, split_tensor_along_last_dim
 
 class CachedAutoregressiveMixin(BaseMixin):
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index e9f7ab1..229e21a 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -44,7 +44,7 @@ class LayerNorm(FusedLayerNorm):
 
 
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
-                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True):
+                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training. 
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. 
 
@@ -124,6 +124,10 @@ class SelfAttention(torch.nn.Module):
         if 'attention_forward' in self.hooks:
             return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
              mixed_key_layer,
@@ -135,7 +139,8 @@ class SelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, **kw_args)
+
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.view(*new_context_layer_shape)
@@ -198,7 +203,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
-- 
GitLab
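
For reference, here is a minimal usage sketch (not part of the patch) of how a user-defined mixin might use the new @non_conflict marker and attention_fn hook once this change is applied. ScaleQueryMixin and its scale argument are hypothetical names chosen for illustration; BaseMixin, non_conflict, and standard_attention are the objects touched above.

# Minimal sketch, assuming SwissArmyTransformer with this patch installed.
# `ScaleQueryMixin` and `scale` are hypothetical; only the pre-transformation
# of the queries is new, everything else defers to the previous implementation.
from SwissArmyTransformer.model.base_model import BaseMixin, non_conflict
from SwissArmyTransformer.mpu.transformer import standard_attention

class ScaleQueryMixin(BaseMixin):
    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn,
                     old_impl=standard_attention, **kwargs):
        # pre-transformation of the queries, then fall back to the previously
        # registered implementation (standard_attention or another mixin's hook)
        return old_impl(q * self.scale, k, v, mask, dropout_fn, **kwargs)

When two such mixins are present, collect_hooks_ chains them with functools.partial: the mixin registered later receives the earlier mixin's hook as old_impl instead of raising the conflict error, and hook_origins records the chain as "later -> earlier".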