diff --git a/model/base_model.py b/model/base_model.py
index 85c416e848be01fd7aef23f4dc09d2463d66d101..8e36ef71c12c82daf269262db6a29c093b6d8468 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -44,6 +44,13 @@ class BaseModel(torch.nn.Module):
         for m in self.mixins:
             m.reinit(self.transformer)
 
+    def forward(self, *args, **kwargs):
+        # update the hooks to those of the current model (overridden forwards)
+        # Attention! The transformer might be shared by multiple models.
+        self.transformer.hooks.clear()
+        self.transformer.hooks.update(self.hooks)
+        return self.transformer(*args, **kwargs)
+
     def collect_hooks(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                  'attention_forward', 'mlp_forward', 'final_forward']
@@ -51,4 +58,7 @@ class BaseModel(torch.nn.Module):
         for name in names:
             if hasattr(self, name):
                 hooks[name] = partial(getattr(self, name), self)
-        return hooks
\ No newline at end of file
+        return hooks
+
+    def disable_untrainable_params(self):
+        pass
\ No newline at end of file
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..e7234cb333c46b1560158514847f38c2efd78413
--- /dev/null
+++ b/model/cached_autoregressive_model.py
@@ -0,0 +1,51 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   cached_autoregressive_model.py
+@Time    :   2021/10/02 00:37:22
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from .base_model import BaseModel
+from mpu.transformer import standard_attention, split_tensor_along_last_dim
+
+
+class CachedAutoregressiveModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.log_attention_weights = None
+
+    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+        attn_module = self.transformer.layers[layer_id].attention
+        mem = other_tensors[layer_id] if len(other_tensors) > 0 else None
+
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        (mixed_query_layer,
+         mixed_key_layer,
+         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+        if mem is not None:  # mem is None at the first decoding step
+            memk, memv = split_tensor_along_last_dim(mem, 2)
+            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
+            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+
+        # same as training
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn=None, log_attention_weights=self.log_attention_weights)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        output = attn_module.dense(context_layer)
+
+        # new mem for this layer: the key/value part (last 2/3) of mixed_raw_layer
+        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
+
+        return output, new_mem
diff --git a/model/gpt2.py b/model/gpt2.py
deleted file mode 100755
index 514f8223ae6cfed27d4aa86c5151fa26a16fa228..0000000000000000000000000000000000000000
--- a/model/gpt2.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   gpt2_modeling.py
-@Time    :   2021/10/02 00:37:22
-@Author  :   Ming Ding
-@Contact :   dm18@mail.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-import torch
-
-from .base_model import BaseModel
-
diff --git a/mpu/transformer.py b/mpu/transformer.py
index ca2e9b94d9a84f1fd059eb2e1d2c3832ddb430d1..c19ef2a793f316fe5b8174c4f035714ca6559124 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -116,7 +116,7 @@ class SelfAttention(torch.nn.Module):
     def forward(self, hidden_states, mask, *other_tensors):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, *other_tensors)
+            return self.hooks['attention_forward'](hidden_states, mask, *other_tensors, layer_id=self.layer_id)
         else:
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,