diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 61543f5d30ed922921b6cd9804b08550e894ca00..76e8c6a9e78527f6ba431c8c428add2245795917 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -106,7 +106,7 @@ class BaseModel(torch.nn.Module):
     def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                 'attention_forward', 'cross_attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
-                'branch_embedding_forward', 'branch_final_forward',
+                'cross_layer_embedding_forward',
                 'attention_fn'
                 ]
         hooks = {}
@@ -115,7 +115,7 @@ class BaseModel(torch.nn.Module):
             for mixin_name, m in self.mixins.items():
                 if hasattr(m, name):
                     if name in hooks: # if this hook name is already registered
-                        if hasattr(getattr(m, name), 'non_confict'):
+                        if hasattr(getattr(m, name), 'non_conflict'):
                             hooks[name] = partial(getattr(m, name), old_impl=hooks[name])
                             hook_origins[name] = mixin_name + ' -> ' + hook_origins[name]
                         else: # conflict
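
Note on the second hunk: before this fix, `collect_hooks_` checked for the misspelled attribute `non_confict`, so a hook marked as chainable by a mixin would never be recognized and a second registration of the same hook name would always fall into the conflict branch. With the corrected spelling, a mixin can opt into chaining by carrying a `non_conflict` attribute on the hook function; `collect_hooks_` then wraps the new hook with `functools.partial`, binding the previously registered implementation as `old_impl`. A minimal sketch of how a mixin might use this (the `non_conflict` decorator and `ScalingMixin` below are illustrative, not part of this diff):

```python
from functools import partial  # how collect_hooks_ chains implementations

def non_conflict(func):
    # Hypothetical marker decorator: collect_hooks_ only checks that an
    # attribute named 'non_conflict' exists on the hook function.
    func.non_conflict = True
    return func

class ScalingMixin:
    @non_conflict
    def attention_fn(self, q, k, v, *args, old_impl=None, **kwargs):
        # old_impl is bound by collect_hooks_ via partial(...) and points
        # at the previously registered attention_fn hook.
        return old_impl(q, k, v, *args, **kwargs) * 0.5
```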