Skip to content
Snippets Groups Projects
Commit bcbc6120 authored by Ming Ding's avatar Ming Ding
Browse files

fix basemodel typo and old hooks name

parent be5e6c90
No related branches found
No related tags found
No related merge requests found
......@@ -106,7 +106,7 @@ class BaseModel(torch.nn.Module):
def collect_hooks_(self):
names = ['word_embedding_forward', 'position_embedding_forward',
'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
'branch_embedding_forward', 'branch_final_forward',
'cross_layer_embedding_forward',
'attention_fn'
]
hooks = {}
......@@ -115,7 +115,7 @@ class BaseModel(torch.nn.Module):
for mixin_name, m in self.mixins.items():
if hasattr(m, name):
if name in hooks: # if this hook name is already registered
if hasattr(getattr(m, name), 'non_confict'):
if hasattr(getattr(m, name), 'non_conflict'):
hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
else: # conflict
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment