Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
mixins.py 2.88 KiB
# -*- encoding: utf-8 -*-
'''
@File    :   mixins.py
@Time    :   2021/10/01 17:52:40
@Author  :   Ming Ding 
@Contact :   dm18@mail.tsinghua.edu.cn
'''

# Imports
import os
import sys
import math
import random

import torch
from mpu import ColumnParallelLinear, RowParallelLinear
from mpu.transformer import unscaled_init_method

class BaseMixin(torch.nn.Module):
    """Base class for mixins that attach extra modules to a transformer.

    Subclasses register their new parameters in ``__init__`` and may
    override ``reinit`` to seed those parameters from an existing
    transformer. The base class itself owns no parameters.
    """

    def __init__(self):
        super().__init__()
        # Subclasses create their additional parameters here.

    def reinit(self, transformer, *pre_mixins):
        """Hook for seeding new parameters from previously trained modules.

        Args:
            transformer: the model whose weights may be copied from.
            *pre_mixins: other mixins; unused by this base implementation.

        The base implementation does nothing.
        """
        return None

class PositionEmbeddingMixin(BaseMixin):
    """Adds an extra, separately trainable position-embedding table.

    Args:
        additional_sequence_length: number of rows in the new table.
        hidden_size: embedding width; must equal the width of the
            transformer's position embeddings when ``reinit`` is called.
        init_method_std: std of the normal initialisation applied in ``__init__``.
        reinit_slice: rows of ``transformer.position_embeddings.weight`` used as
            the template in ``reinit`` (default: the last 1024 rows).
    """
    def __init__(self, additional_sequence_length, hidden_size,
                init_method_std=0.02, reinit_slice=slice(-1024, None)
        ):
        super(PositionEmbeddingMixin, self).__init__()
        self.reinit_slice = reinit_slice
        self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

    def reinit(self, transformer, *pre_mixins):
        """Seed the new table by tiling rows copied from ``transformer``.

        The rows selected by ``self.reinit_slice`` from
        ``transformer.position_embeddings.weight`` (call their count ``old_len``)
        are written into every consecutive chunk of ``old_len`` rows of the
        new table. ``copy_`` broadcasts the source across the leading dimension.
        The new table's length must therefore be a multiple of ``old_len``.

        Args:
            transformer: model exposing ``position_embeddings.weight``
                (assumed to be an ``nn.Embedding``-like module — confirm
                against the caller).
            *pre_mixins: unused.

        Raises:
            AssertionError: if the embedding widths differ, or the new table
                length is not a multiple of the number of copied rows.
        """
        old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
        old_len, hidden_size = old_weights.shape
        new_weights = self.position_embeddings.weight.data
        new_len = new_weights.shape[0]
        assert hidden_size == new_weights.shape[-1], (
            f"hidden size mismatch: transformer has {hidden_size}, "
            f"mixin has {new_weights.shape[-1]}"
        )
        # Without this check, view() below fails with an opaque shape error.
        assert old_len > 0 and new_len % old_len == 0, (
            f"additional_sequence_length ({new_len}) must be a multiple of the "
            f"number of rows selected by reinit_slice ({old_len})"
        )
        # Reshape to (repeats, old_len, hidden) so copy_ broadcasts the template
        # into every chunk. The view shares storage, so the table is updated in place.
        new_weights.view(-1, old_len, hidden_size).copy_(old_weights)

class AttentionMixin(BaseMixin):
    """Separate attention projections for the LAST ``num_layers`` transformer layers.

    For each replaced layer it holds:
      * a ``ColumnParallelLinear`` query/key/value projection,
        ``hidden_size -> 3 * hidden_size`` (``stride=3``, ``gather_output=False``);
      * a ``RowParallelLinear`` output projection,
        ``hidden_size -> hidden_size`` (``input_is_parallel=True``).
    """
    def __init__(self, num_layers,
                hidden_size,
                init_method=unscaled_init_method(0.02),
                output_layer_init_method=unscaled_init_method(0.02)
        ):
        super(AttentionMixin, self).__init__()
        # How many trailing layers of the transformer this mixin replaces.
        self.num_layers = num_layers
        # Build all qkv projections before any dense projection. This keeps the
        # order of random initialisation and module registration unchanged.
        qkv_projections = [
            ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
                                 gather_output=False, init_method=init_method)
            for _ in range(num_layers)
        ]
        self.query_key_value = torch.nn.ModuleList(qkv_projections)
        output_projections = [
            RowParallelLinear(hidden_size, hidden_size,
                              input_is_parallel=True,
                              init_method=output_layer_init_method)
            for _ in range(num_layers)
        ]
        self.dense = torch.nn.ModuleList(output_projections)

    def reinit(self, transformer, *pre_mixins):
        """Copy attention weights and biases from the transformer's last layers.

        Mixin slot ``i`` receives the parameters of
        ``transformer.layers[len(transformer.layers) - num_layers + i].attention``.
        NOTE(review): assumes those modules are partitioned identically to the
        mixin's parallel linears, so whole ``.weight`` tensors match in shape.

        Args:
            transformer: model exposing ``layers[k].attention.query_key_value``
                and ``layers[k].attention.dense``.
            *pre_mixins: unused.

        Raises:
            AssertionError: if the transformer has fewer than ``num_layers`` layers.
        """
        first_replaced = len(transformer.layers) - self.num_layers
        assert first_replaced >= 0
        for offset in range(self.num_layers):
            source = transformer.layers[first_replaced + offset].attention
            new_qkv, old_qkv = self.query_key_value[offset], source.query_key_value
            new_qkv.weight.data.copy_(old_qkv.weight.data)
            new_qkv.bias.data.copy_(old_qkv.bias.data)
            new_dense, old_dense = self.dense[offset], source.dense
            new_dense.weight.data.copy_(old_dense.weight.data)
            new_dense.bias.data.copy_(old_dense.bias.data)