From 59585e171ac12aab60cd30b1ac537df4a787f6d9 Mon Sep 17 00:00:00 2001
From: minkowski0125 <zwd18@mails.tsinghua.edu.cn>
Date: Wed, 15 Dec 2021 14:10:32 +0000
Subject: [PATCH] add new vqvae module

---
 .../tokenization/cogview/vqvae/README.md      |  44 --
 .../cogview/vqvae/distributed/__init__.py     |  13 +
 .../cogview/vqvae/distributed/launch.py       |  92 +++++
 .../tokenization/cogview/vqvae/enc_dec.py     | 386 ++++++++++++++++++
 .../tokenization/cogview/vqvae/quantize.py    | 176 ++++++++
 .../tokenization/cogview/vqvae/vqvae.py       | 111 +++++
 6 files changed, 778 insertions(+), 44 deletions(-)
 delete mode 100755 SwissArmyTransformer/tokenization/cogview/vqvae/README.md
 create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py
 create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py
 create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py
 create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py
 create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py

diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/README.md b/SwissArmyTransformer/tokenization/cogview/vqvae/README.md
deleted file mode 100755
index b73bf2f..0000000
--- a/SwissArmyTransformer/tokenization/cogview/vqvae/README.md
+++ /dev/null
@@ -1,44 +0,0 @@
-# vq-vae-2-pytorch
-Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch
-
-## Update
-
-* 2020-06-01
-
-train_vqvae.py and vqvae.py now supports distributed training. You can use --n_gpu [NUM_GPUS] arguments for train_vqvae.py to use [NUM_GPUS] during training.
-
-## Requisite
-
-* Python >= 3.6
-* PyTorch >= 1.1
-* lmdb (for storing extracted codes)
-
-[Checkpoint of VQ-VAE pretrained on FFHQ](vqvae_560.pt)
-
-## Usage
-
-Currently supports 256px (top/bottom hierarchical prior)
-
-1. Stage 1 (VQ-VAE)
-
-> python train_vqvae.py [DATASET PATH]
-
-If you use FFHQ, I highly recommends to preprocess images. (resize and convert to jpeg)
-
-2. Extract codes for stage 2 training
-
-> python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]
-
-3. Stage 2 (PixelSNAIL)
-
-> python train_pixelsnail.py [LMDB NAME]
-
-Maybe it is better to use larger PixelSNAIL model. Currently model size is reduced due to GPU constraints.
-
-## Sample
-
-### Stage 1
-
-Note: This is a training sample
-
-![Sample from Stage 1 (VQ-VAE)](stage1_sample.png)
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py
new file mode 100644
index 0000000..b944d98
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py
@@ -0,0 +1,13 @@
+from .distributed import (
+    get_rank,
+    get_local_rank,
+    is_primary,
+    synchronize,
+    get_world_size,
+    all_reduce,
+    all_gather,
+    reduce_dict,
+    data_sampler,
+    LOCAL_PROCESS_GROUP,
+)
+from .launch import launch
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py
new file mode 100644
index 0000000..c03326d
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py
@@ -0,0 +1,92 @@
+import os
+
+import torch
+from torch import distributed as dist
+from torch import multiprocessing as mp
+
+from . import distributed as dist_fn
+
+
+def find_free_port():
+    import socket
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+
+    return port
+
+
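+# Spawn one worker process per local GPU when world_size > 1; otherwise call fn directly in the current process.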
+def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
+    world_size = n_machine * n_gpu_per_machine
+
+    if world_size > 1:
+        if "OMP_NUM_THREADS" not in os.environ:
+            os.environ["OMP_NUM_THREADS"] = "1"
+
+        if dist_url == "auto":
+            if n_machine != 1:
+                raise ValueError('dist_url="auto" not supported in multi-machine jobs')
+
+            port = find_free_port()
+            dist_url = f"tcp://127.0.0.1:{port}"
+
+        if n_machine > 1 and dist_url.startswith("file://"):
+            raise ValueError(
+                "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
+            )
+
+        mp.spawn(
+            distributed_worker,
+            nprocs=n_gpu_per_machine,
+            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
+            daemon=False,
+        )
+
+    else:
+        fn(*args)
+
+
+def distributed_worker(
+    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
+):
+    if not torch.cuda.is_available():
+        raise OSError("CUDA is not available. Please check your environments")
+
+    global_rank = machine_rank * n_gpu_per_machine + local_rank
+
+    try:
+        dist.init_process_group(
+            backend="nccl",
+            init_method=dist_url,
+            world_size=world_size,
+            rank=global_rank,
+        )
+
+    except Exception as exc:
+        raise OSError("failed to initialize the NCCL process group") from exc
+
+    dist_fn.synchronize()
+
+    if n_gpu_per_machine > torch.cuda.device_count():
+        raise ValueError(
+            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
+        )
+
+    torch.cuda.set_device(local_rank)
+
+    if dist_fn.LOCAL_PROCESS_GROUP is not None:
+        raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")
+
+    n_machine = world_size // n_gpu_per_machine
+
+    for i in range(n_machine):
+        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
+        pg = dist.new_group(ranks_on_i)
+
+        if i == machine_rank:
+            dist_fn.LOCAL_PROCESS_GROUP = pg
+
+    fn(*args)
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py b/SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py
new file mode 100644
index 0000000..20f87fc
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py
@@ -0,0 +1,386 @@
+import math
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
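+# swish / SiLU activation: x * sigmoid(x)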
+def nonlinearity(x):
+    return x * torch.sigmoid(x)
+
+def Normalize(in_channels):
+    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+class Upsample(nn.Module):
+    def __init__(self, 
+                    in_channels,
+                    with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if with_conv:
+            self.conv = nn.Conv2d(in_channels,
+                                    in_channels,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1)
+    
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2., mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+class DownSample(nn.Module):
+    def __init__(self,
+                    in_channels,
+                    with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if with_conv:
+            self.conv = nn.Conv2d(in_channels,
+                                    in_channels,
+                                    kernel_size=3,
+                                    stride=2,
+                                    padding=0)
+    
+    def forward(self, x):
+        if self.with_conv:
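+            # pad only right/bottom so the 3x3 stride-2 conv (padding=0) exactly halves the spatial size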
+            pad = (0, 1, 0, 1)
+            x = F.pad(x, pad, mode='constant', value=0)
+            x = self.conv(x)
+        else:
+            x = F.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+class ResidualDownSample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+        self.pooling_down_sampler = DownSample(in_channels, with_conv=False)
+        self.conv_down_sampler = DownSample(in_channels, with_conv=True)
+
+    def forward(self, x):
+        return self.pooling_down_sampler(x) + self.conv_down_sampler(x)
+
+class ResnetBlock(nn.Module):
+    def __init__(self,
+                    in_channels,
+                    dropout,
+                    out_channels=None,
+                    conv_shortcut=False):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = nn.Conv2d(in_channels,
+                                out_channels,
+                                kernel_size=3,
+                                stride=1,
+                                padding=1)
+
+        self.norm2 = Normalize(out_channels)
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = nn.Conv2d(out_channels,
+                                out_channels,
+                                kernel_size=3,
+                                stride=1,
+                                padding=1)
+        if in_channels != out_channels:
+            if conv_shortcut:
+                self.conv_shortcut = nn.Conv2d(in_channels,
+                                                out_channels,
+                                                kernel_size=3,
+                                                stride=1,
+                                                padding=1)
+            else:
+                self.nin_shortcut = nn.Conv2d(in_channels,
+                                                out_channels,
+                                                kernel_size=1,
+                                                stride=1,
+                                                padding=0)
+                                            
+    def forward(self, x):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = nn.Conv2d(in_channels,
+                            in_channels,
+                            kernel_size=1,
+                            stride=1,
+                            padding=0)
+        self.k = nn.Conv2d(in_channels,
+                            in_channels,
+                            kernel_size=1,
+                            stride=1,
+                            padding=0)
+        self.v = nn.Conv2d(in_channels,
+                            in_channels,
+                            kernel_size=1,
+                            stride=1,
+                            padding=0)
+        self.proj_out = nn.Conv2d(in_channels,
+                                    in_channels,
+                                    kernel_size=1, 
+                                    stride=1,
+                                    padding=0)
+    
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        B, C, H, W = q.shape
+        q = q.reshape(B, C, -1)
+        q = q.permute(0, 2, 1) # (B, H*W, C)
+        k = k.reshape(B, C, -1) # (B, C, H*W)
+        w_ = torch.bmm(q, k) # (B, H*W, H*W)
+        w_ = w_ * C**(-0.5)
+        w_ = F.softmax(w_, dim=2)
+
+        v = v.reshape(B, C, -1) # (B, C, H*W)
+        w_ = w_.permute(0, 2, 1)
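+        # aggregate values with the attention weights: (B, C, H*W) x (B, H*W, H*W) -> (B, C, H*W)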
+        h_ = torch.bmm(v, w_)
+        h_ = h_.reshape(B, C, H, W)
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
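+# Convolutional encoder: per-level stacks of ResNet blocks (with self-attention at the resolutions
+# listed in attn_resolutions), downsampling between levels, ending in a z_channels feature map.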
+class Encoder(nn.Module):
+    def __init__(self,
+                    in_channels,
+                    out_channels,
+                    z_channels,
+                    channels,
+                    num_res_blocks,
+                    resolution,
+                    attn_resolutions,
+                    resample_with_conv=True,
+                    channels_mult=(1,2,4,8),
+                    dropout=0.
+                    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.z_channels = z_channels
+        self.channels = channels
+        self.num_resolutions = len(channels_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+
+        self.conv_in = nn.Conv2d(in_channels,
+                                    channels,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1)
+
+        current_resolution = resolution
+        in_channels_mult = (1,) + tuple(channels_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = channels * in_channels_mult[i_level]
+            block_out = channels * channels_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                            out_channels=block_out,
+                                            dropout=dropout))
+                block_in = block_out
+                if current_resolution in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = DownSample(block_in,
+                                                resample_with_conv)
+                current_resolution = current_resolution // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        dropout=dropout)
+        
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = nn.Conv2d(block_in,
+                                    z_channels,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1)
+
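+    # Debug helper: same downsampling pass as forward(), but returns all intermediate activations.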
+    def test_forward(self, x):
+        # downsample
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+            
+        return hs
+
+    def forward(self, x):
+        # downsample
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+
+        return h
+
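+# Convolutional decoder: mirrors the encoder, upsampling the z_channels latent back to an
+# out_channels image.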
+class Decoder(nn.Module):
+    def __init__(self,
+                    in_channels,
+                    out_channels,
+                    z_channels,
+                    channels,
+                    num_res_blocks,
+                    resolution,
+                    attn_resolutions,
+                    channels_mult=(1,2,4,8),
+                    resample_with_conv=True,
+                    dropout=0.
+                    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.z_channels = z_channels
+        self.channels = channels
+        self.num_resolutions = len(channels_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        
+        in_channels_mult = (1,) + tuple(channels_mult)
+        block_in = channels * channels_mult[self.num_resolutions - 1]
+        current_resolution = resolution // 2**(self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, current_resolution, current_resolution)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(z_channels,
+                                    block_in,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1)
+        
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        dropout=dropout)
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                        out_channels=block_in,
+                                        dropout=dropout)
+        
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = channels * channels_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in,
+                                            out_channels=block_out,
+                                            dropout=dropout))
+                block_in = block_out
+                if current_resolution in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in,
+                                        resample_with_conv)
+                current_resolution = current_resolution * 2
+            self.up.insert(0, up)
+        
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = nn.Conv2d(block_in,
+                                    out_channels,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1)
+
+    def forward(self, z):
+        self.last_z_shape = z.shape
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+        
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+    def get_last_layer(self):
+        return self.conv_out.weight
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py b/SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py
new file mode 100644
index 0000000..be55f23
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+from torch import einsum
+from torch.nn import functional as F
+
+from . import distributed as dist_fn
+
+class VectorQuantize(nn.Module):
+    def __init__(self, 
+                    hidden_dim,
+                    embedding_dim,
+                    n_embed,
+                    commitment_cost=1):
+        super().__init__()
+        
+        self.hidden_dim = hidden_dim
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.commitment_cost = commitment_cost
+
+        self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
+        self.embed = nn.Embedding(n_embed, embedding_dim)
+        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
+
+    def forward(self, z):
+        B, C, H, W = z.shape
+
+        z_e = self.proj(z)
+        z_e = z_e.permute(0, 2, 3, 1) # (B, H, W, C)
+        flatten = z_e.reshape(-1, self.embedding_dim)
+
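+        # squared distances to every codebook vector: ||z||^2 - 2 z.e + ||e||^2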
+        dist = (
+            flatten.pow(2).sum(1, keepdim=True)
+            - 2 * flatten @ self.embed.weight.t()
+            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
+        )
+        _, embed_ind = (-dist).max(1)
+        embed_ind = embed_ind.view(B, H, W)
+
+        z_q = self.embed_code(embed_ind)
+        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
+                + (z_q - z_e.detach()).pow(2).mean()
+
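+        # straight-through estimator: gradients bypass the quantization step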
+        z_q = z_e + (z_q - z_e).detach()
+        return z_q, diff, embed_ind
+
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.embed.weight)
+    
+
+class VectorQuantizeEMA(nn.Module):
+    def __init__(self,
+                    hidden_dim,
+                    embedding_dim,
+                    n_embed,
+                    commitment_cost=1,
+                    decay=0.99,
+                    eps=1e-5,
+                    pre_proj=True,
+                    training_loc=True):
+        super().__init__()
+        
+        self.hidden_dim = hidden_dim
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.commitment_cost = commitment_cost
+        self.training_loc = training_loc
+        
+        self.pre_proj = pre_proj
+        if self.pre_proj:
+            self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
+        self.embed = nn.Embedding(n_embed, embedding_dim)
+        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
+        
+        self.register_buffer("cluster_size", torch.zeros(n_embed))
+        self.register_buffer("embed_avg", self.embed.weight.data.clone())
+        
+        self.decay = decay
+        self.eps = eps
+        
+    def forward(self, z):
+        B, C, H, W = z.shape
+        
+        if self.pre_proj:
+            z_e = self.proj(z)
+        else:
+            z_e = z
+        z_e = z_e.permute(0, 2, 3, 1) # (B, H, W, C)
+        flatten = z_e.reshape(-1, self.embedding_dim)
+
+        dist = (
+            flatten.pow(2).sum(1, keepdim=True)
+            - 2 * flatten @ self.embed.weight.t()
+            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
+        )
+        _, embed_ind = (-dist).max(1)
+        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
+        embed_ind = embed_ind.view(B, H, W)
+
+        z_q = self.embed_code(embed_ind)
+        
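+        # EMA codebook update: decay-weighted running counts and sums of encoder outputs per code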
+        if self.training_loc and self.training:
+            embed_onehot_sum = embed_onehot.sum(0)
+            embed_sum = (flatten.transpose(0, 1) @ embed_onehot).transpose(0, 1)
+            
+            dist_fn.all_reduce(embed_onehot_sum)
+            dist_fn.all_reduce(embed_sum)
+            
+            self.cluster_size.data.mul_(self.decay).add_(
+                embed_onehot_sum, alpha=1-self.decay
+            )
+            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1-self.decay)
+            n = self.cluster_size.sum()
+            cluster_size = (
+                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.weight.data.copy_(embed_normalized)
+        
+        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean()
+
+        z_q = z_e + (z_q - z_e).detach()
+        return z_q, diff, embed_ind
+
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.embed.weight)
+ 
+
+class GumbelQuantize(nn.Module):
+    def __init__(self,
+                    hidden_dim,
+                    embedding_dim,
+                    n_embed,
+                    commitment_cost=1,
+                    straight_through=True,
+                    kl_weight=5e-4,
+                    temp_init=1.,
+                    eps=1e-5):
+        super().__init__()
+        
+        self.hidden_dim = hidden_dim
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.commitment_cost = commitment_cost
+        
+        self.kl_weight = kl_weight
+        self.temperature = temp_init
+        self.eps = eps
+        
+        self.proj = nn.Conv2d(hidden_dim, n_embed, 1)
+        self.embed = nn.Embedding(n_embed, embedding_dim)
+        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
+        
+        self.straight_through = straight_through
+        
+    def forward(self, z, temp=None):
+        hard = self.straight_through if self.training else True
+        temp = self.temperature if temp is None else temp
+        
+        B, C, H, W = z.shape
+        
+        z_e = self.proj(z)
+        
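+        # relaxed (Gumbel-softmax) one-hot assignment over the codebook; hard=True gives a straight-through one-hot sample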
+        soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard)
+        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+        
+        qy = F.softmax(z_e, dim=1)
+        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean()
+        
+        embed_ind = soft_one_hot.argmax(dim=1)
+        z_q = z_q.permute(0, 2, 3, 1)
+        return z_q, diff, embed_ind
+        
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.embed.weight)
+ 
\ No newline at end of file
diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py b/SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py
new file mode 100644
index 0000000..6b93a47
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py
@@ -0,0 +1,111 @@
+import torch
+from torch import nn
+import json
+import os
+
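+# NOTE: new_module(config) is assumed to be provided by the surrounding vqvae package;
+# it instantiates a sub-module (encoder, decoder, quantizer, ...) from a config dict.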
+class VQVAE(nn.Module):
+    def __init__(self,
+                    enc_config,
+                    dec_config,
+                    quantize_config):
+        super().__init__()
+        
+        self.enc = new_module(enc_config)
+        self.dec = new_module(dec_config)
+        self.quantize = new_module(quantize_config)
+
+    def forward(self, input):
+        quant_t, diff_t, _ = self.encode(input)
+
+        return self.decode(quant_t), diff_t
+    
+    def encode(self, input):
+        logits = self.enc(input)
+        quant_t, diff_t, id_t = self.quantize.forward(logits)
+        quant_t = quant_t.permute(0, 3, 1, 2)
+        # diff_t = diff_t.unsqueeze(0)
+
+        return quant_t, diff_t, id_t
+
+    def decode(self, code):
+        return [self.dec(code)]
+    
+    def decode_code(self, code_t):
+        quant_t = self.quantize.embed_code(code_t)
+        quant_t = quant_t.permute(0, 3, 1, 2)
+        return self.decode(quant_t)
+    
+    def get_last_layer(self):
+        return self.dec.get_last_layer()
+    
+
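+# Hierarchical VQ-VAE: a single encoder, (levels - 1) down-samplers producing coarser latents,
+# a shared quantizer, and one decoder per level; codebook losses of coarser levels are scaled
+# by successive powers of codebook_scale.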
+class HVQVAE(nn.Module):
+    def __init__(
+        self,
+        levels,
+        embedding_dim,
+        enc_config,
+        quantize_config,
+        down_sampler_configs,
+        dec_configs,
+        codebook_scale=1.
+    ):
+        super().__init__()
+        
+        self.levels = levels
+
+        self.enc = new_module(enc_config)
+            
+        self.decs = nn.ModuleList()
+        for i in range(levels):
+            self.decs.append(new_module(dec_configs[i]))
+            
+        self.quantize = new_module(quantize_config)
+        self.down_samplers = nn.ModuleList()
+        for i in range(levels-1):
+            self.down_samplers.append(new_module(down_sampler_configs[i]))
+        self.codebook_scale = codebook_scale
+            
+    def forward(self, input):        
+        quants, diffs, ids = self.encode(input)
+        dec_outputs = self.decode(quants[::-1])
+        
+        total_diff = diffs[0]
+        scale = 1.
+        for diff in diffs[1:]:
+            scale *= self.codebook_scale
+            total_diff = total_diff + diff * scale
+        return dec_outputs, total_diff
+
+    def encode(self, input):
+        enc_output = self.enc(input)
+        enc_outputs = [enc_output]
+        for l in range(self.levels-1):
+            enc_outputs.append(self.down_samplers[l](enc_outputs[-1]))
+       
+        quants, diffs, ids = [], [], []
+        for enc_output in enc_outputs:
+            quant, diff, id = self.quantize(enc_output)
+            quants.append(quant.permute(0, 3, 1, 2))
+            diffs.append(diff)
+            ids.append(id)
+            
+        return quants, diffs, ids
+        
+    def decode(self, quants):
+        dec_outputs = []
+        for l in range(self.levels-1, -1, -1):
+            dec_outputs.append(self.decs[l](quants[l]))
+            
+        return dec_outputs
+
+    def decode_code(self, codes):
+        quants = []
+        for l in range(self.levels):
+            quants.append(self.quantize.embed_code(codes[l]).permute(0, 3, 1, 2))
+        dec_outputs = self.decode(quants)
+
+        return dec_outputs
+    
+    def get_last_layer(self):
+        return self.decs[-1].get_last_layer()
\ No newline at end of file
-- 
GitLab