# -*- encoding: utf-8 -*-
'''
@File    :   base_model.py
@Time    :   2021/10/01 22:40:33
@Author  :   Ming Ding
@Contact :   dm18@mail.tsinghua.edu.cn
'''

# standard library and project imports
import os
import sys
import math
import random
import torch
from functools import partial

from mpu import BaseTransformer


class BaseModel(torch.nn.Module):
    def __init__(self, args, transformer=None):
        super(BaseModel, self).__init__()
        # collect the hook methods defined on this (sub)class before building
        # the transformer, so they can be handed to it at construction time
        self.hooks = self.collect_hooks()
        if transformer is not None:
            # reuse an externally provided (possibly shared) transformer
            self.transformer = transformer
        else:
            self.transformer = BaseTransformer(
                num_layers=args.num_layers,
                vocab_size=args.vocab_size,
                hidden_size=args.hidden_size,
                num_attention_heads=args.num_attention_heads,
                max_sequence_length=args.max_position_embeddings,
                embedding_dropout_prob=args.hidden_dropout,
                attention_dropout_prob=args.attention_dropout,
                output_dropout_prob=args.hidden_dropout,
                checkpoint_activations=args.checkpoint_activations,
                checkpoint_num_layers=args.checkpoint_num_layers,
                sandwich_ln=args.sandwich_ln,
                parallel_output=True,
                hooks=self.hooks
            )
        self.mixins = torch.nn.ModuleList()

    def reinit(self):
        # reinitialize all attached mixins against the current transformer
        for m in self.mixins:
            m.reinit(self.transformer)

    def forward(self, *args, **kwargs):
        # update hooks to those of the current model (overridden forwards).
        # Attention! The transformer might be shared by multiple models.
        self.transformer.hooks.clear()
        self.transformer.hooks.update(self.hooks)
        return self.transformer(*args, **kwargs)

    def collect_hooks(self):
        names = ['word_embedding_forward', 'position_embedding_forward',
                 'attention_forward', 'mlp_forward', 'final_forward']
        hooks = {}
        for name in names:
            # register every hook method that this (sub)class defines
            if hasattr(self, name):
                hooks[name] = partial(getattr(self, name), self)
        return hooks

    def disable_untrainable_params(self):
        pass
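
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): a subclass
# customizes the transformer simply by defining a method whose name matches
# one of the hook names listed in collect_hooks. `MyCustomModel` below is a
# hypothetical name, and the exact arguments each hook receives are determined
# by BaseTransformer, so only the registration mechanism is shown here.
#
# class MyCustomModel(BaseModel):
#     def final_forward(self, *args, **kw_args):
#         # hypothetical override: post-process the output of the final stage
#         ...
#
# model = MyCustomModel(args)
# assert 'final_forward' in model.hooks  # picked up by collect_hooks()
#
# Because forward() clears and re-installs self.hooks on every call, several
# models built on top of one shared transformer can each apply their own
# overrides without interfering with one another.
# ---------------------------------------------------------------------------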