diff --git a/model/base_model.py b/model/base_model.py
index 85c416e848be01fd7aef23f4dc09d2463d66d101..8e36ef71c12c82daf269262db6a29c093b6d8468 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -44,6 +44,13 @@ class BaseModel(torch.nn.Module):
         for m in self.mixins:
             m.reinit(self.transformer)
     
+    def forward(self, *args, **kwargs):
+        # update hooks to match the current model (overridden forwards)
+        # Attention! The transformer might be shared by multiple models.
+        self.transformer.hooks.clear()
+        self.transformer.hooks.update(self.hooks)
+        return self.transformer(*args, **kwargs)
+        
     def collect_hooks(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                     'attention_forward', 'mlp_forward', 'final_forward']
@@ -51,4 +58,7 @@ class BaseModel(torch.nn.Module):
         for name in names:
             if hasattr(self, name):
                 hooks[name] = partial(getattr(self, name), self)
-        return hooks
\ No newline at end of file
+        return hooks
+
+    def disable_untrainable_params(self):
+        pass
\ No newline at end of file
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..e7234cb333c46b1560158514847f38c2efd78413
--- /dev/null
+++ b/model/cached_autoregressive_model.py
@@ -0,0 +1,56 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   cached_autoregressive_model.py
+@Time    :   2021/10/02 00:37:22
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+from .base_model import BaseModel
+from mpu.transformer import standard_attention, split_tensor_along_last_dim
+
+class CachedAutoregressiveModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.log_attention_weights = None
+        
+    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+        attn_module = self.transformer.layers[layer_id].attention
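+        # other_tensors carries the per-layer cached key/value memories; it is empty on the first forward pass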
+        mem = other_tensors[layer_id] if len(other_tensors) > 0 else None
+        
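+        # fused QKV projection, split into query, key and value along the last dimension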
+        mixed_raw_layer = attn_module.query_key_value(hidden_states)
+        (mixed_query_layer,
+            mixed_key_layer,
+            mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+        
+        if mem is not None: # the first time, mem is None
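+            # the cached mem stores the fused [key, value] halves; prepend them along dim 1 (the sequence dimension)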
+            memk, memv = split_tensor_along_last_dim(mem, 2)
+            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
+            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+
+        # same as training
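+        # reshape to (batch, num_heads, seq_len, head_dim) and run the usual attention computation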
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn=None, log_attention_weights=self.log_attention_weights)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        output = attn_module.dense(context_layer)
+        
+        # new mem this layer
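+        # keep only the key/value parts (the last two thirds of the fused QKV output), detached from the graph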
+        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
+            
+        return output, new_mem
diff --git a/model/gpt2.py b/model/gpt2.py
deleted file mode 100755
index 514f8223ae6cfed27d4aa86c5151fa26a16fa228..0000000000000000000000000000000000000000
--- a/model/gpt2.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   gpt2_modeling.py
-@Time    :   2021/10/02 00:37:22
-@Author  :   Ming Ding 
-@Contact :   dm18@mail.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-import torch
-
-from .base_model import BaseModel
-
diff --git a/mpu/transformer.py b/mpu/transformer.py
index ca2e9b94d9a84f1fd059eb2e1d2c3832ddb430d1..c19ef2a793f316fe5b8174c4f035714ca6559124 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -116,7 +116,8 @@ class SelfAttention(torch.nn.Module):
 
     def forward(self, hidden_states, mask, *other_tensors):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, *other_tensors)
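+            # pass layer_id so the hook can locate this layer's attention module and cache entry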
+            return self.hooks['attention_forward'](hidden_states, mask, *other_tensors, layer_id=self.layer_id)
         else:
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,