Skip to content
Snippets Groups Projects
Commit bcbc6120 authored by Ming Ding's avatar Ming Ding
Browse files

fix basemodel typo and old hooks name

parent be5e6c90
No related branches found
No related tags found
No related merge requests found
...@@ -106,7 +106,7 @@ class BaseModel(torch.nn.Module): ...@@ -106,7 +106,7 @@ class BaseModel(torch.nn.Module):
def collect_hooks_(self): def collect_hooks_(self):
names = ['word_embedding_forward', 'position_embedding_forward', names = ['word_embedding_forward', 'position_embedding_forward',
'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward', 'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
'branch_embedding_forward', 'branch_final_forward', 'cross_layer_embedding_forward',
'attention_fn' 'attention_fn'
] ]
hooks = {} hooks = {}
...@@ -115,7 +115,7 @@ class BaseModel(torch.nn.Module): ...@@ -115,7 +115,7 @@ class BaseModel(torch.nn.Module):
for mixin_name, m in self.mixins.items(): for mixin_name, m in self.mixins.items():
if hasattr(m, name): if hasattr(m, name):
if name in hooks: # if this hook name is already registered if name in hooks: # if this hook name is already registered
if hasattr(getattr(m, name), 'non_confict'): if hasattr(getattr(m, name), 'non_conflict'):
hooks[name] = partial(getattr(m, name), old_impl=hooks[name]) hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
hook_origins[name] = mixin_name + ' -> ' + hook_origins[name] hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
else: # conflict else: # conflict
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment