From bcbc61205fbf60641b731a5b709ecec593b53bc9 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Mon, 13 Dec 2021 14:57:18 +0000
Subject: [PATCH] fix basemodel typo and old hooks name

---
 SwissArmyTransformer/model/base_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 61543f5..76e8c6a 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -106,7 +106,7 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                  'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                 'branch_embedding_forward', 'branch_final_forward',
+                 'cross_layer_embedding_forward',
                  'attention_fn'
                  ]
         hooks = {}
@@ -115,7 +115,7 @@ class BaseModel(torch.nn.Module):
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
                     if name in hooks: # if this hook name is already registered
-                        if hasattr(getattr(m, name), 'non_confict'):
+                        if hasattr(getattr(m, name), 'non_conflict'):
                             hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
                             hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
                         else: # conflict
-- 
GitLab