diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py
index b29964f0a2fbd77945347a26a3cbfead95de7a8c..402200b669ab59f6bafef443bb96c75e04982f04 100644
--- a/generation/autoregressive_sampling.py
+++ b/generation/autoregressive_sampling.py
@@ -85,7 +85,8 @@ def filling_sequence(
             continue
 
         # forward
-        model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
+        if log_attention_weights is not None:
+            model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
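+        # when log_attention_weights is None, model.log_attention_weights is left unchanged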
         logits, *mem_kv = model(
             tokens[:, index:], 
             position_ids[..., index: counter+1],
diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py
index 20ffa38242c66421d28b8cdb84abb53373d51ec9..c7ebf31c5e52143bf4ab60201dd53f0aa72b5c2a 100644
--- a/generation/cuda2d_sampling.py
+++ b/generation/cuda2d_sampling.py
@@ -14,7 +14,7 @@ import random
 import torch
 from .sampling_strategies import IterativeEntfilterStrategy
 
-def filling_sequence(
+def filling_sequence_cuda2d(
         model, 
         seq0,
         seq1, 
@@ -38,10 +38,9 @@ def filling_sequence(
     assert seq1.shape[1] == layout[-1] - layout[-2]
     assert (seq1 >= 0).all() and (seq0 >= 0).all()
     device = seq0.device
-
     # concat and pad sequences
     batch_size = seq0.shape[0]
-    n_pad = layout[1] + 1 - len(seq0) # +1 for [EOI1]
+    n_pad = layout[1] + 1 - seq0.shape[1] # +1 for [EOI1]
     assert n_pad > 0, "You should truncate long input before filling."
     seq = torch.cat((
         torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
@@ -53,6 +52,7 @@ def filling_sequence(
     tokens = seq.clone()
     attention_mask = torch.ones(layout[1], layout[1]).tril().to(device)
     attention_mask[n_pad:, :n_pad] = 0
+    attention_mask = attention_mask.type_as(next(model.parameters())) # cast to the model's dtype (e.g. fp16)
     position_ids = torch.cat((
         torch.zeros(n_pad, dtype=torch.long),
         torch.arange(0, layout[1] - n_pad), 
@@ -60,7 +60,7 @@ def filling_sequence(
 
     # prepare for interation
     unfixed = (tokens < 0)
-    unfixed[:, -4096] = True
+    unfixed[:, -(layout[-1] - layout[-2]):] = True
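+    # the whole second-stage region (layout[-2]:layout[-1]) is re-sampled, on top of the masked (<0) positions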
     ll, rr = block_hw
     edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
     num_steps = warmup_steps + ll + rr - 2
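+    # total iterations: warmup passes followed by ll + rr - 2 blockwise refinement steps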
@@ -69,7 +69,8 @@ def filling_sequence(
         logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask)
         if step_cnt <= warmup_steps:
             real_temp = 0.1
-            tokens = strategy.forward(logits, tokens, real_temp)
+            new_tokens = strategy.forward(logits, tokens, real_temp)
+            tokens[unfixed] = new_tokens[unfixed]
         else:
             real_temp = 1.05
             new_tokens = strategy.forward(
@@ -86,4 +87,5 @@ def filling_sequence(
                     print(x,y)
                     unfixed[..., -(layout[-1] - layout[-2]):].view(
                         batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
-    return tokens
\ No newline at end of file
+
+    return tokens[:, n_pad:]
\ No newline at end of file
diff --git a/generation/sampling_strategies/iterative_entfilter_strategy.py b/generation/sampling_strategies/iterative_entfilter_strategy.py
index 84196d0c8fae8db48ed515ff265dea9e22136409..9017a42a3c4dc5092be7c57f64b81e1ef0f8e0bf 100644
--- a/generation/sampling_strategies/iterative_entfilter_strategy.py
+++ b/generation/sampling_strategies/iterative_entfilter_strategy.py
@@ -36,7 +36,7 @@ class IterativeEntfilterStrategy:
             topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
             ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
             temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
-        logits = logits / temperature
+        logits = logits.float() / temperature
         for invalid_slice in self.invalid_slices:
             logits[..., invalid_slice] = -float('Inf')
         
diff --git a/generation/utils.py b/generation/utils.py
index bad073413b6c3efa6f71b3e46eeacb883bd504b0..01c05726a682e2245575517ba18d3ca27deea9c7 100644
--- a/generation/utils.py
+++ b/generation/utils.py
@@ -68,11 +68,12 @@ def generate_continually(func, input_source='interactive'):
             raw_text = raw_text.strip()
             if len(raw_text) == 0:
                 continue
-            try:
-                start_time = time.time()
-                func(raw_text)
-                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
-            except (ValueError, FileNotFoundError) as e:
-                err_linenos.append(line_no)
-                continue
+            # try:
+            start_time = time.time()
+            func(raw_text)
+            print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+            # except (ValueError, FileNotFoundError) as e:
+            #     err_linenos.append(line_no)
+            #     print(e)
+            #     continue
         print(err_linenos)
diff --git a/inference_cogview.py b/inference_cogview.py
index 56bdd96ead78f155fce08552e6c582facbeffbdc..88311ff70e737a2f93c003c7bc51ec231d614646 100644
--- a/inference_cogview.py
+++ b/inference_cogview.py
@@ -32,6 +32,7 @@ def main(args):
     model = model.to(args.device)
     load_checkpoint(model, args)
     set_random_seed(args.seed)
+    model.eval()
     
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
diff --git a/inference_cogview2.py b/inference_cogview2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6895d363e242e283d65e6c93949f6bf01a2f6f0a
--- /dev/null
+++ b/inference_cogview2.py
@@ -0,0 +1,125 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_cogview2.py
+@Time    :   2021/10/10 16:31:34
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+import torch
+import argparse
+from torchvision import transforms
+
+from arguments import get_args
+from model.cached_autoregressive_model import CachedAutoregressiveModel
+from model.cuda2d_model import Cuda2dModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from tokenization import get_tokenizer
+from generation.sampling_strategies import BaseStrategy, IterativeEntfilterStrategy
+from generation.autoregressive_sampling import filling_sequence
+from generation.cuda2d_sampling import filling_sequence_cuda2d
+from generation.utils import timed_name, save_multiple_images, generate_continually
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model 
+    model = Cuda2dModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    model0 = CachedAutoregressiveModel(args, transformer=model.transformer)
+    set_random_seed(args.seed)
+    model.eval()
+    model0.eval()
+    # define function for each query
+    query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1]' if not args.full_query else '{}'
+    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
+    strategy0 = BaseStrategy(invalid_slices, 
+        temperature=args.temperature, topk=args.top_k)
+    strategy1 = IterativeEntfilterStrategy(invalid_slices,
+        temperature=args.temperature, topk=10) # temperature not used
+    tr = transforms.Compose([
+                transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), 
+            ])
+    
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split()
+        print('raw text: ', raw_text)
+        text = query_template.format(raw_text)
+        seq = tokenizer.parse_query(text, img_size=args.img_size)
+        if len(seq) > 1088:
+            raise ValueError('text too long.')
+        # calibrate text length
+        txt_len = seq.index(tokenizer['[BASE]'])
+        log_attention_weights = torch.zeros(len(seq), len(seq), 
+            device=args.device, dtype=torch.half if args.fp16 else torch.float32)
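+        # bias attention from the image positions toward the text tokens (stronger for short prompts)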
+        log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
+
+        # generation
+        seq = torch.cuda.LongTensor(seq, device=args.device)
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = []
+        for tim in range(max(args.batch_size // mbz, 1)):
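+            # stage 1: autoregressively fill the 1024 masked image tokens (32x32 grid)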
+            output0 = filling_sequence(model0, seq.clone(),
+                    batch_size=min(args.batch_size, mbz),
+                    strategy=strategy0,
+                    log_attention_weights=log_attention_weights
+                    )
+            imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0]
+            blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096]
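+            # stage 2: refine the re-encoded 64x64 token grid with the cuda2d model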
+            output1 = filling_sequence_cuda2d(model, output0, blur64, 
+                    warmup_steps=3, block_hw=(4, 4),
+                    strategy=strategy1
+                    )
+            output_list.append(output1)
+        output_tokens = torch.cat(output_list, dim=0)
+        # decoding
+        imgs, txts = [], []
+        for seq in output_tokens:
+            decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
+            for i in range(len(decoded_imgs)):
+                if decoded_imgs[i].shape[-1] < 512:
+                    decoded_imgs[i] = torch.nn.functional.interpolate(decoded_imgs[i], size=(512, 512))
+            if args.with_id:
+                imgs.append(decoded_imgs[-1]) # only the last image (target)
+            else:
+                imgs.extend(decoded_imgs)
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id)
+            os.makedirs(full_path, exist_ok=True)
+            save_multiple_images(imgs, full_path, False)
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.jpg', args.output_path)
+            save_multiple_images(imgs, full_path, True)
+    
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--full-query', action='store_true')
+    py_parser.add_argument('--img-size', type=int, default=256)
+    
+    Cuda2dModel.add_model_specific_args(py_parser)
+    
+    known, args_list = py_parser.parse_known_args()
+    
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    
+    args.layout = [int(x) for x in args.layout.split(',')]
+    
+    with torch.no_grad():
+        main(args)
\ No newline at end of file
diff --git a/move_weights.py b/move_weights.py
index 7d41f2b2990effb15c7022dd248c6ce8ada8ed79..e55d6b34b35c2db3afc0eaeffcc2fb029142f784 100644
--- a/move_weights.py
+++ b/move_weights.py
@@ -103,3 +103,13 @@ missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=Fals
 # %%
 torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt')
 # %%
+import torch
+old = torch.load("/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt", map_location='cpu')
+
+# %%
+from collections import OrderedDict
+new_ckpt = OrderedDict()
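+# keep only detached weight tensors from the lightning checkpoint's state_dict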
+for k,v in old['state_dict'].items():
+    new_ckpt[k] = v.detach()
+torch.save(new_ckpt, 'pretrained/vqvae/l1+ms-ssim+revd_percep.pt')
+# %%
diff --git a/mpu/transformer.py b/mpu/transformer.py
index 61f658ee4bc1e5de73f7f2e01fb77fa0adc3e21f..e04485a697504d5e7fc3d9967b081981a7503bd9 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -308,7 +308,7 @@ class BaseTransformer(torch.nn.Module):
                 layer_id,
                 output_layer_init_method=self.output_layer_init_method,
                 sandwich_ln=sandwich_ln,
-                hooks=hooks
+                hooks=self.hooks
                 )
         self.layers = torch.nn.ModuleList(
             [get_layer(layer_id) for layer_id in range(num_layers)])
diff --git a/scripts/text2image_cogview.sh b/scripts/text2image_cogview.sh
index 4fc46365914d359466343eb0185c1c147d4c7155..f205776ac5e7572f60e6582812cf3935c5a2d2c1 100755
--- a/scripts/text2image_cogview.sh
+++ b/scripts/text2image_cogview.sh
@@ -17,7 +17,7 @@ script_dir=$(dirname $script_path)
 
 MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
diff --git a/scripts/text2image_cogview2.sh b/scripts/text2image_cogview2.sh
new file mode 100755
index 0000000000000000000000000000000000000000..96e8a045bbb2f3d4f046b9949dc734b12ae0288e
--- /dev/null
+++ b/scripts/text2image_cogview2.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+NLAYERS=48
+NHIDDEN=2560
+NATT=40
+MAXSEQLEN=1089
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+MPSIZE=1
+
+#SAMPLING ARGS
+TEMP=1.03
+TOPK=200
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+MASTER_PORT=${MASTER_PORT} python inference_cogview2.py \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --mode inference \
+       --distributed-backend nccl \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --model-parallel-size $MPSIZE \
+       --num-layers $NLAYERS \
+       --hidden-size $NHIDDEN \
+       --load $CHECKPOINT_PATH \
+       --num-attention-heads $NATT \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --input-source ./input.txt \
+       --output-path samples_text2image \
+       --batch-size 4 \
+       --max-inference-batch-size 8 \
+       --device 0 \
+       $@
+
+
diff --git a/tokenization/cogview/vqvae_tokenizer.py b/tokenization/cogview/vqvae_tokenizer.py
index 56ee251126fdcb901d4b88cea114342b1dfccdb7..23df2efab1ca3cb925908a1f02ee688712f7ed15 100755
--- a/tokenization/cogview/vqvae_tokenizer.py
+++ b/tokenization/cogview/vqvae_tokenizer.py
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 
 
-from vqvae import new_model, img2code, code2img
+from vqvae import new_model, img2code, code2img, load_model_default
 from torchvision import transforms
 from PIL import Image
 
@@ -35,16 +35,17 @@ class VQVAETokenizer(object):
             model_path, 
             device='cuda'
         ):
-        ckpt = torch.load(model_path, map_location=torch.device(device))
+        # ckpt = torch.load(model_path, map_location=torch.device(device))
 
-        model = new_model()
+        # model = new_model()
 
-        if list(ckpt.keys())[0].startswith('module.'):
-            ckpt = {k[7:]: v for k, v in ckpt.items()}
+        # if list(ckpt.keys())[0].startswith('module.'):
+        #     ckpt = {k[7:]: v for k, v in ckpt.items()}
 
-        model.load_state_dict(ckpt)
-        model = model.to(device)
-        model.eval()
+        # model.load_state_dict(ckpt)
+        # model = model.to(device)
+        # model.eval()
+        model = load_model_default(device=device, path=model_path)
 
         self.model = model
         self.device = device
diff --git a/vqvae/__init__.py b/vqvae/__init__.py
index 9d2256330e77c8c1ab47ff8ddd21854d15a32f37..2652e923e0f0a7dd90da239d021f3dd4fd922f08 100755
--- a/vqvae/__init__.py
+++ b/vqvae/__init__.py
@@ -1 +1,3 @@
-from .api import new_model, img2code, code2img
\ No newline at end of file
+from .api import new_model, img2code, code2img
+from .api import new_module, load_decoder_default, load_model_default 
+from .api import test_decode, test_decode_default
\ No newline at end of file
diff --git a/vqvae/api.py b/vqvae/api.py
index 5a5f64c873baaae70b79d41b6b8f1fc9062e93f0..7d2fd0312c1276ac4cbe9c46b04e8aeef2c905e1 100755
--- a/vqvae/api.py
+++ b/vqvae/api.py
@@ -2,13 +2,206 @@
 # Can rewrite the APIs for VQGAN.
 # Don't forget to freeze the relavant .py files.
 
+import importlib
 import torch
 import math
+import os
+
+from torchvision.utils import save_image, make_grid
+from datetime import datetime
 
 # production APIs
 
 from .vqvae_zc import VQVAE
 
+def new_module(config):
+    '''
+        config keys:
+            "target": module class path, e.g. vqvae_zc.Decoder / vqvae_diffusion.Decoder / vqvae_diffusion.Decoder2
+            "params": dict of constructor kwargs
+            "ckpt": path of the checkpoint to load (optional)
+            "ckpt_prefix": key prefix to strip from the checkpoint state dict (optional)
+            "device": device to place the module on (default "cpu")
+    '''
+    if not "target" in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    module, cls = config.get("target").rsplit(".", 1)
+    model = getattr(importlib.import_module(module, package=None), cls)(**config.get("params", dict()))
+
+    device = config.get("device", "cpu")
+    model = model.to(device)
+    model.eval()
+
+    if "ckpt" in config:
+        ckpt = torch.load(config.get("ckpt"), map_location='cpu')
+        prefix = config.get("ckpt_prefix", None)
+        if "state_dict" in ckpt:
+            ckpt = ckpt["state_dict"]   
+        if prefix is not None:
+            ckpt = {k[len(prefix) + 1:]: v for k, v in ckpt.items() if k.startswith(prefix)}
+        model.load_state_dict(ckpt, strict=False)
+        del ckpt
+    return model
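+# Usage sketch (placeholder path; load_decoder_default / load_model_default below
+# build such configs for the project's checkpoints):
+#   decoder = new_module({
+#       "target": "vqvae.vqvae_diffusion.Decoder",
+#       "params": {...},                   # constructor kwargs of the target class
+#       "ckpt": "path/to/checkpoint.pt",   # placeholder
+#       "ckpt_prefix": "dec",
+#       "device": "cuda",
+#   })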
+
+def load_decoder_default(device=0, path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"):
+    # exp: load currently best decoder
+    target = "vqvae.vqvae_diffusion.Decoder"
+    params = {
+        "double_z": False,
+        "z_channels": 256,
+        "resolution": 256,
+        "in_channels": 3,
+        "out_ch": 3,
+        "ch": 128,
+        "ch_mult": [ 1,1,2,4],  # num_down = len(ch_mult)-1
+        "num_res_blocks": 2,
+        "attn_resolutions": [16],
+        "dropout": 0.0
+    }
+    ckpt_prefix = "dec"
+
+    config = {
+        "target": target,
+        "params": params,
+        "ckpt": path,
+        "ckpt_prefix": ckpt_prefix,
+        "device": device
+    }
+    return new_module(config)
+
+# path="/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt"
+def load_model_default(device=0, 
+                    path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"):
+    # exp: load currently best vqvae model
+    ddconfig = {
+        "double_z": False,
+        "z_channels": 256,
+        "resolution": 256,
+        "in_channels": 3,
+        "out_ch": 3,
+        "ch": 128,
+        "ch_mult": [1,1,2,4],
+        "num_res_blocks": 2,
+        "attn_resolutions": [16],
+        "dropout": 0.0
+    }
+    params = {
+        "in_channel": 3,
+        "channel": 512,
+        "n_res_block": 0,
+        "n_res_channel": 32,
+        "embed_dim": 256,
+        "n_embed": 8192,
+        "stride": 6,
+        "simple": True,
+        "decay": 0.99,
+        "dif": True,
+        "ddconfig": ddconfig
+    }
+
+    config = {
+        'target': "vqvae.vqvae_zc.VQVAE",
+        'params': params,
+        'ckpt': path,
+        'device': device
+    }
+    return new_module(config)
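+# VQVAETokenizer now loads its model through this helper
+# (see tokenization/cogview/vqvae_tokenizer.py):
+#   model = load_model_default(device=device, path=model_path)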
+
+def test_decode(configs, testcase, device=0, output_path=None):
+    '''
+        configs: list of configs for new_module, one per decoder to compare
+        testcase: path of a .pt file of token ids, or a tensor of code indices [B, H, W]
+    '''
+    if output_path is None:
+        output_path = os.path.join("sample", f"{datetime.now().strftime('%m-%d-%H-%M-%S')}.jpg")
+
+    quantize_config = {
+        "target": "vqvae.vqvae_zc.Quantize",
+        "params": {
+            "dim": 256,
+            "n_embed": 8192,
+        },
+        "ckpt": "/dataset/fd5061f6/cogview/zwd/pretrained/vqvae/vqvae_hard_biggerset_011.pt",
+        "ckpt_prefix": "module.quantize_t",
+        "device": device
+    }
+    quantize = new_module(quantize_config)
+
+    if isinstance(testcase, str):
+        testcase = torch.load(testcase, map_location=torch.device(device))[:, -1024:].contiguous()
+        testcase = testcase.view(testcase.shape[0], 32, 32).contiguous()
+    else:
+        testcase = testcase.to(device)
+
+    quantized_testcase = quantize.embed_code(testcase)
+    quantized_testcase = quantized_testcase.permute(0, 3, 1, 2)
+
+    outs = []
+    for config in configs:
+        decoder = new_module(config)
+        out = decoder(quantized_testcase)
+        outs.append(out.unsqueeze(0))
+    outs = torch.cat(outs).permute(1, 0, 2, 3, 4)
+    outs = outs.reshape(-1, *outs.shape[2:]).contiguous()
+    save_image(make_grid(outs, nrow=len(configs)), output_path, normalize=True)
+
+def test_decode_default(device=0):
+    # testing 3 decoders: original/l1+ms-ssim/l1+ms-ssim+perceptual
+    configs = [
+        {
+            "target": "vqvae.vqvae_zc.Decoder",
+            "params": {
+                "in_channel": 256, 
+                "out_channel": 3,
+                "channel": 512,
+                "n_res_block": 0,
+                "n_res_channel": 32,
+                "stride": 4,
+                "simple": True
+            },
+            "ckpt": "/dataset/fd5061f6/cogview/zwd/pretrained/vqvae/vqvae_hard_biggerset_011.pt",
+            "ckpt_prefix": "module.dec",
+            "device": device },
+        {
+            "target": "vqvae.vqvae_diffusion.Decoder",
+            "params": {
+                "double_z": False,
+                "z_channels": 256,
+                "resolution": 256,
+                "in_channels": 3,
+                "out_ch": 3,
+                "ch": 128,
+                "ch_mult": [ 1,1,2,4],  # num_down = len(ch_mult)-1
+                "num_res_blocks": 2,
+                "attn_resolutions": [16],
+                "dropout": 0.0
+            },
+            "ckpt": "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim/checkpoints/last.ckpt",
+            "ckpt_prefix": "dec",
+            "device": device },
+        {
+            "target": "vqvae.vqvae_diffusion.Decoder",
+            "params": {
+                "double_z": False,
+                "z_channels": 256,
+                "resolution": 256,
+                "in_channels": 3,
+                "out_ch": 3,
+                "ch": 128,
+                "ch_mult": [ 1,1,2,4],  # num_down = len(ch_mult)-1
+                "num_res_blocks": 2,
+                "attn_resolutions": [16],
+                "dropout": 0.0
+            },
+            "ckpt": "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt",
+            "ckpt_prefix": "dec",
+            "device": device },
+    ]
+    testcase_dir = "/dataset/fd5061f6/cogview/zwd/vqgan/testcase/"
+    for testcase in os.listdir(testcase_dir):
+        testcase = os.path.join(testcase_dir, testcase)
+        test_decode(configs, testcase, device)
+
 def new_model():
     '''Return a New Instance of VQVAE, the same parameters with the pretrained model.
         This is for torch.load().
diff --git a/vqvae/vqvae_diffusion.py b/vqvae/vqvae_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c566deb015c07fc2db3ff26f3fe4e7ab574f88
--- /dev/null
+++ b/vqvae/vqvae_diffusion.py
@@ -0,0 +1,782 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic
+    Models (ported from the Fairseq implementation). This matches the implementation
+    in tensor2tensor, but differs slightly from the description in Section 3.5 of
+    "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+    return emb
+
+
+def nonlinearity(x):
+    # swish
+    return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=0)
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels,
+                                             out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=3,
+                                                     stride=1,
+                                                     padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                    out_channels,
+                                                    kernel_size=1,
+                                                    stride=1,
+                                                    padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x+h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = q.reshape(b,c,h*w)
+        q = q.permute(0,2,1)   # b,hw,c
+        k = k.reshape(b,c,h*w) # b,c,hw
+        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b,c,h*w)
+        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b,c,h,w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+
+class Model(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, use_timestep=True):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+    def forward(self, x, t=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Encoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, double_z=True, **ignore_kwargs):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        2*z_channels if double_z else z_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+    def forward(self, x):
+        #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
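+# Decoder2: the same Decoder preceded by an extra 1x1 conv on the latent z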
+def Decoder2(**ddconfig):
+    return nn.Sequential(
+        torch.nn.Conv2d(ddconfig.get("z_channels"), ddconfig.get("z_channels"), 1),
+        Decoder(**ddconfig)
+    )
+
+class Decoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,)+tuple(ch_mult)
+        block_in = ch*ch_mult[self.num_resolutions-1]
+        curr_res = resolution // 2**(self.num_resolutions-1)
+        self.z_shape = (1,z_channels,curr_res,curr_res)
+        print("Working with z of shape {} = {} dimensions.".format(
+            self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(z_channels,
+                                       block_in,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, z):
+        #assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class VUNet(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 in_channels, c_channels,
+                 resolution, z_channels, use_timestep=False, **ignore_kwargs):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(c_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        self.z_in = torch.nn.Conv2d(z_channels,
+                                    block_in,
+                                    kernel_size=1,
+                                    stride=1,
+                                    padding=0)
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+
+    def forward(self, x, z, t=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        z = self.z_in(z)
+        h = torch.cat((h,z),dim=1)
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+                                     ResnetBlock(in_channels=in_channels,
+                                                 out_channels=2 * in_channels,
+                                                 temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=2 * in_channels,
+                                                out_channels=4 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=4 * in_channels,
+                                                out_channels=2 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     nn.Conv2d(2*in_channels, in_channels, 1),
+                                     Upsample(in_channels, with_conv=True)])
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(in_channels,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
+            if i in [1,2,3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                 ch_mult=(2,2), dropout=0.0):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        for k, i_level in enumerate(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[k](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
diff --git a/vqvae/vqvae_zc.py b/vqvae/vqvae_zc.py
index a018ad5e0c8449548a27bac468f598f27f57b2a3..ca2b360ad54e33e9b4da22b7f7d772917caf5f16 100755
--- a/vqvae/vqvae_zc.py
+++ b/vqvae/vqvae_zc.py
@@ -2,6 +2,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+from .vqvae_diffusion import Decoder as DifDecoder
+
 # import distributed as dist_fn
 
 # Copyright 2018 The Sonnet Authors. All Rights Reserved.
@@ -225,22 +227,27 @@ class VQVAE(nn.Module):
         n_embed=1024,
         stride=4,
         simple=True,
-        decay=0.99
+        decay=0.99,
+        dif=False,
+        ddconfig=None
     ):
         super().__init__()
         if channel == 2048:
             n_res_block = 0
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride, embed_dim, n_embed, simple)
         self.quantize_t = Quantize(embed_dim, n_embed)
-        self.dec = Decoder(
-            in_channel=embed_dim, 
-            out_channel=in_channel,
-            channel=channel,
-            n_res_block=n_res_block,
-            n_res_channel=n_res_channel,
-            stride=stride-2,
-            simple=simple
-        )
+        if dif:
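+            # use the diffusion-style decoder from vqvae_diffusion (ddconfig required)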
+            self.dec = DifDecoder(**ddconfig)
+        else:
+            self.dec = Decoder(
+                in_channel=embed_dim, 
+                out_channel=in_channel,
+                channel=channel,
+                n_res_block=n_res_block,
+                n_res_channel=n_res_channel,
+                stride=stride-2,
+                simple=simple
+            )
          
     def forward(self, input, continuous_relax=False, temperature=1., hard=False, KL=False):
         quant_t, diff, _, = self.encode(input, continuous_relax, temperature, hard, KL)