From 59585e171ac12aab60cd30b1ac537df4a787f6d9 Mon Sep 17 00:00:00 2001 From: minkowski0125 <zwd18@mails.tsinghua.edu.cn> Date: Wed, 15 Dec 2021 14:10:32 +0000 Subject: [PATCH] add new vqvae module --- .../tokenization/cogview/vqvae/README.md | 44 -- .../cogview/vqvae/distributed/__init__.py | 13 + .../cogview/vqvae/distributed/launch.py | 92 +++++ .../tokenization/cogview/vqvae/enc_dec.py | 386 ++++++++++++++++++ .../tokenization/cogview/vqvae/quantize.py | 176 ++++++++ .../tokenization/cogview/vqvae/vqvae.py | 111 +++++ 6 files changed, 778 insertions(+), 44 deletions(-) delete mode 100755 SwissArmyTransformer/tokenization/cogview/vqvae/README.md create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py create mode 100644 SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/README.md b/SwissArmyTransformer/tokenization/cogview/vqvae/README.md deleted file mode 100755 index b73bf2f..0000000 --- a/SwissArmyTransformer/tokenization/cogview/vqvae/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# vq-vae-2-pytorch -Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch - -## Update - -* 2020-06-01 - -train_vqvae.py and vqvae.py now supports distributed training. You can use --n_gpu [NUM_GPUS] arguments for train_vqvae.py to use [NUM_GPUS] during training. - -## Requisite - -* Python >= 3.6 -* PyTorch >= 1.1 -* lmdb (for storing extracted codes) - -[Checkpoint of VQ-VAE pretrained on FFHQ](vqvae_560.pt) - -## Usage - -Currently supports 256px (top/bottom hierarchical prior) - -1. Stage 1 (VQ-VAE) - -> python train_vqvae.py [DATASET PATH] - -If you use FFHQ, I highly recommends to preprocess images. (resize and convert to jpeg) - -2. Extract codes for stage 2 training - -> python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH] - -3. Stage 2 (PixelSNAIL) - -> python train_pixelsnail.py [LMDB NAME] - -Maybe it is better to use larger PixelSNAIL model. Currently model size is reduced due to GPU constraints. 
- -## Sample - -### Stage 1 - -Note: This is a training sample - - diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py new file mode 100644 index 0000000..b944d98 --- /dev/null +++ b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/__init__.py @@ -0,0 +1,13 @@ +from .distributed import ( + get_rank, + get_local_rank, + is_primary, + synchronize, + get_world_size, + all_reduce, + all_gather, + reduce_dict, + data_sampler, + LOCAL_PROCESS_GROUP, +) +from .launch import launch diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py new file mode 100644 index 0000000..c03326d --- /dev/null +++ b/SwissArmyTransformer/tokenization/cogview/vqvae/distributed/launch.py @@ -0,0 +1,92 @@ +import os + +import torch +from torch import distributed as dist +from torch import multiprocessing as mp + +import distributed as dist_fn + + +def find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + + return port + + +def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): + world_size = n_machine * n_gpu_per_machine + + if world_size > 1: + if "OMP_NUM_THREADS" not in os.environ: + os.environ["OMP_NUM_THREADS"] = "1" + + if dist_url == "auto": + if n_machine != 1: + raise ValueError('dist_url="auto" not supported in multi-machine jobs') + + port = find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + + if n_machine > 1 and dist_url.startswith("file://"): + raise ValueError( + "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" + ) + + mp.spawn( + distributed_worker, + nprocs=n_gpu_per_machine, + args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), + daemon=False, + ) + + else: + fn(*args) + + +def distributed_worker( + local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args +): + if not torch.cuda.is_available(): + raise OSError("CUDA is not available. 
Please check your environments") + + global_rank = machine_rank * n_gpu_per_machine + local_rank + + try: + dist.init_process_group( + backend="NCCL", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + ) + + except Exception: + raise OSError("failed to initialize NCCL groups") + + dist_fn.synchronize() + + if n_gpu_per_machine > torch.cuda.device_count(): + raise ValueError( + f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" + ) + + torch.cuda.set_device(local_rank) + + if dist_fn.LOCAL_PROCESS_GROUP is not None: + raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") + + n_machine = world_size // n_gpu_per_machine + + for i in range(n_machine): + ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) + pg = dist.new_group(ranks_on_i) + + if i == machine_rank: + dist_fn.distributed.LOCAL_PROCESS_GROUP = pg + + fn(*args) diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py b/SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py new file mode 100644 index 0000000..20f87fc --- /dev/null +++ b/SwissArmyTransformer/tokenization/cogview/vqvae/enc_dec.py @@ -0,0 +1,386 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +def nonlinearity(x): + return x * torch.sigmoid(x) + +def Normalize(in_channels): + return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + +class Upsample(nn.Module): + def __init__(self, + in_channels, + with_conv): + super().__init__() + self.with_conv = with_conv + if with_conv: + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2., mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class DownSample(nn.Module): + def __init__(self, + in_channels, + with_conv): + super().__init__() + self.with_conv = with_conv + if with_conv: + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class ResidualDownSample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + self.pooling_down_sampler = DownSample(in_channels, with_conv=False) + self.conv_down_sampler = DownSample(in_channels, with_conv=True) + + def forward(self, x): + return self.pooling_down_sampler(x) + self.conv_down_sampler(x) + +class ResnetBlock(nn.Module): + def __init__(self, + in_channels, + dropout, + out_channels=None, + conv_shortcut=False): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if in_channels != out_channels: + if conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + 
padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + B, C, H, W = q.shape + q = q.reshape(B, C, -1) + q = q.permute(0, 2, 1) # (B, H*W, C) + k = k.reshape(B, C, -1) # (B, C, H*W) + w_ = torch.bmm(q, k) # (B, H*W, H*W) + w_ = w_ * C**(-0.5) + w_ = F.softmax(w_, dim=2) + + v = v.reshape(B, C, -1) # (B, C, H*W) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(B, C, H, W) + + h_ = self.proj_out(h_) + + return x + h_ + +class Encoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + z_channels, + channels, + num_res_blocks, + resolution, + attn_resolutions, + resample_with_conv=True, + channels_mult=(1,2,4,8), + dropout=0. + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.z_channels = z_channels + self.channels = channels + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.conv_in = nn.Conv2d(in_channels, + channels, + kernel_size=3, + stride=1, + padding=1) + + current_resolution = resolution + in_channels_mult = (1,) + tuple(channels_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_channels_mult[i_level] + block_out = channels * channels_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + dropout=dropout)) + block_in = block_out + if current_resolution in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = DownSample(block_in, + resample_with_conv) + current_resolution = current_resolution // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d(block_in, + z_channels, + kernel_size=3, + stride=1, + padding=1) + + def test_forward(self, x): + # downsample + import pdb + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + 
hs.append(self.down[i_level].downsample(hs[-1])) + + return hs + + def forward(self, x): + # downsample + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + return h + +class Decoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + z_channels, + channels, + num_res_blocks, + resolution, + attn_resolutions, + channels_mult=(1,2,4,8), + resample_with_conv=True, + dropout=0. + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.z_channels = z_channels + self.channels = channels + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + in_channels_mult = (1,) + tuple(channels_mult) + block_in = channels * channels_mult[self.num_resolutions - 1] + current_resolution = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, current_resolution, current_resolution) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + dropout=dropout)) + block_in = block_out + if current_resolution in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, + resample_with_conv) + current_resolution = current_resolution * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py b/SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py new file mode 100644 index 0000000..be55f23 --- /dev/null +++ b/SwissArmyTransformer/tokenization/cogview/vqvae/quantize.py @@ -0,0 +1,176 @@ +import torch +from torch import nn +from torch import einsum +from 
torch.nn import functional as F
+
+from . import distributed as dist_fn
+
+class VectorQuantize(nn.Module):
+    def __init__(self,
+                 hidden_dim,
+                 embedding_dim,
+                 n_embed,
+                 commitment_cost=1):
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.commitment_cost = commitment_cost
+
+        self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
+        self.embed = nn.Embedding(n_embed, embedding_dim)
+        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
+
+    def forward(self, z):
+        B, C, H, W = z.shape
+
+        z_e = self.proj(z)
+        z_e = z_e.permute(0, 2, 3, 1) # (B, H, W, C)
+        flatten = z_e.reshape(-1, self.embedding_dim)
+
+        dist = (
+            flatten.pow(2).sum(1, keepdim=True)
+            - 2 * flatten @ self.embed.weight.t()
+            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
+        )
+        _, embed_ind = (-dist).max(1)
+        embed_ind = embed_ind.view(B, H, W)
+
+        z_q = self.embed_code(embed_ind)
+        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
+            + (z_q - z_e.detach()).pow(2).mean()
+
+        z_q = z_e + (z_q - z_e).detach()
+        return z_q, diff, embed_ind
+
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.embed.weight)
+
+
+class VectorQuantizeEMA(nn.Module):
+    def __init__(self,
+                 hidden_dim,
+                 embedding_dim,
+                 n_embed,
+                 commitment_cost=1,
+                 decay=0.99,
+                 eps=1e-5,
+                 pre_proj=True,
+                 training_loc=True):
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.embedding_dim = embedding_dim
+        self.n_embed = n_embed
+        self.commitment_cost = commitment_cost
+        self.training_loc = training_loc
+
+        self.pre_proj = pre_proj
+        if self.pre_proj:
+            self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
+        self.embed = nn.Embedding(n_embed, embedding_dim)
+        self.embed.weight.data.uniform_(-1. / n_embed, 1.
/ n_embed) + + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", self.embed.weight.data.clone()) + + self.decay = decay + self.eps = eps + + def forward(self, z): + B, C, H, W = z.shape + + if self.pre_proj: + z_e = self.proj(z) + else: + z_e = z + z_e = z_e.permute(0, 2, 3, 1) # (B, H, W, C) + flatten = z_e.reshape(-1, self.embedding_dim) + + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed.weight.t() + + self.embed.weight.pow(2).sum(1, keepdim=True).t() + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(B, H, W) + + z_q = self.embed_code(embed_ind) + + if self.training_loc and self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = (flatten.transpose(0, 1) @ embed_onehot).transpose(0, 1) + + dist_fn.all_reduce(embed_onehot_sum) + dist_fn.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_( + embed_onehot_sum, alpha=1-self.decay + ) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1-self.decay) + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.weight.data.copy_(embed_normalized) + + diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() + + z_q = z_e + (z_q - z_e).detach() + return z_q, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.weight) + + +class GumbelQuantize(nn.Module): + def __init__(self, + hidden_dim, + embedding_dim, + n_embed, + commitment_cost=1, + straight_through=True, + kl_weight=5e-4, + temp_init=1., + eps=1e-5): + super().__init__() + + self.hidden_dim = hidden_dim + self.embedding_dim = embedding_dim + self.n_embed = n_embed + self.commitment_cost = commitment_cost + + self.kl_weight = kl_weight + self.temperature = temp_init + self.eps = eps + + self.proj = nn.Conv2d(hidden_dim, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + self.embed.weight.data.uniform_(-1. / n_embed, 1. 
/ n_embed) + + self.straight_through = straight_through + + def forward(self, z, temp=None): + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + B, C, H, W = z.shape + + z_e = self.proj(z) + + soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard) + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + qy = F.softmax(z_e, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean() + + embed_ind = soft_one_hot.argmax(dim=1) + z_q = z_q.permute(0, 2, 3, 1) + return z_q, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.weight) + \ No newline at end of file diff --git a/SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py b/SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py new file mode 100644 index 0000000..6b93a47 --- /dev/null +++ b/SwissArmyTransformer/tokenization/cogview/vqvae/vqvae.py @@ -0,0 +1,111 @@ +import torch +from torch import nn +import json +import os + +class VQVAE(nn.Module): + def __init__(self, + enc_config, + dec_config, + quantize_config): + super().__init__() + + self.enc = new_module(enc_config) + self.dec = new_module(dec_config) + self.quantize = new_module(quantize_config) + + def forward(self, input): + quant_t, diff_t, _ = self.encode(input) + + return self.decode(quant_t), diff_t + + def encode(self, input): + logits = self.enc(input) + quant_t, diff_t, id_t = self.quantize.forward(logits) + quant_t = quant_t.permute(0, 3, 1, 2) + # diff_t = diff_t.unsqueeze(0) + + return quant_t, diff_t, id_t + + def decode(self, code): + return [self.dec(code)] + + def decode_code(self, code_t): + quant_t = self.quantize.embed_code(code_t) + quant_t = quant_t.permute(0, 3, 1, 2) + return self.decode(quant_t) + + def get_last_layer(self): + return self.dec.get_last_layer() + + +class HVQVAE(nn.Module): + def __init__( + self, + levels, + embedding_dim, + enc_config, + quantize_config, + down_sampler_configs, + dec_configs, + codebook_scale=1. + ): + super().__init__() + + self.levels = levels + + self.enc = new_module(enc_config) + + self.decs = nn.ModuleList() + for i in range(levels): + self.decs.append(new_module(dec_configs[i])) + + self.quantize = new_module(quantize_config) + self.down_samplers = nn.ModuleList() + for i in range(levels-1): + self.down_samplers.append(new_module(down_sampler_configs[i])) + self.codebook_scale = codebook_scale + + def forward(self, input): + quants, diffs, ids = self.encode(input) + dec_outputs = self.decode(quants[::-1]) + + total_diff = diffs[0] + scale = 1. 
+ for diff in diffs[1:]: + scale *= self.codebook_scale + total_diff = total_diff + diff * scale + return dec_outputs, total_diff + + def encode(self, input): + enc_output = self.enc(input) + enc_outputs = [enc_output] + for l in range(self.levels-1): + enc_outputs.append(self.down_samplers[l](enc_outputs[-1])) + + quants, diffs, ids = [], [], [] + for enc_output in enc_outputs: + quant, diff, id = self.quantize(enc_output) + quants.append(quant.permute(0, 3, 1, 2)) + diffs.append(diff) + ids.append(id) + + return quants, diffs, ids + + def decode(self, quants): + dec_outputs = [] + for l in range(self.levels-1, -1, -1): + dec_outputs.append(self.decs[l](quants[l])) + + return dec_outputs + + def decode_code(self, codes): + quants = [] + for l in range(self.levels): + quants.append(self.quantize.embed_code(codes[l]).permute(0, 3, 1, 2)) + dec_outputs = self.decode(quants) + + return dec_outputs + + def get_last_layer(self): + return self.decs[-1].get_last_layer() \ No newline at end of file -- GitLab
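Usage sketch (not part of the patch): the VQVAE/HVQVAE wrappers above build their sub-modules through new_module(config), whose config schema is not shown in this diff, so the snippet below wires an Encoder, a VectorQuantizeEMA quantizer and a Decoder together by hand to illustrate the expected tensor flow. The import paths are assumed from the file locations added here, the hyperparameters are illustrative only, and the snippet assumes the distributed helpers referenced by quantize.py resolve in your environment (they are only exercised when the EMA codebook update runs).

    import torch
    from SwissArmyTransformer.tokenization.cogview.vqvae.enc_dec import Encoder, Decoder
    from SwissArmyTransformer.tokenization.cogview.vqvae.quantize import VectorQuantizeEMA

    # With channels_mult of length 3 the encoder downsamples twice, so 64px inputs
    # produce a 16x16 latent grid with z_channels channels.
    enc = Encoder(in_channels=3, out_channels=3, z_channels=256, channels=64,
                  num_res_blocks=2, resolution=64, attn_resolutions=[16],
                  channels_mult=(1, 2, 4))
    dec = Decoder(in_channels=3, out_channels=3, z_channels=256, channels=64,
                  num_res_blocks=2, resolution=64, attn_resolutions=[16],
                  channels_mult=(1, 2, 4))
    # embedding_dim matches z_channels so the quantized map can be fed straight back
    # into the decoder; training_loc=False skips the EMA codebook update, which calls
    # dist_fn.all_reduce and therefore needs an initialized process group.
    quantize = VectorQuantizeEMA(hidden_dim=256, embedding_dim=256, n_embed=1024,
                                 training_loc=False)

    x = torch.randn(2, 3, 64, 64)       # (B, C, H, W) image batch
    z_e = enc(x)                        # (2, 256, 16, 16)
    z_q, diff, ids = quantize(z_e)      # z_q: (2, 16, 16, 256), diff: commitment loss, ids: (2, 16, 16)
    z_q = z_q.permute(0, 3, 1, 2)       # back to (B, C, H, W), as VQVAE.encode does
    x_rec = dec(z_q)                    # (2, 3, 64, 64) reconstruction
    print(x_rec.shape, float(diff), ids.shape)

During multi-GPU training the EMA quantizer all-reduces its cluster-size and embedding statistics across workers, which is why quantize.py depends on the distributed helpers; launch() in distributed/launch.py spawns one process per local GPU and initializes the NCCL process group before invoking the supplied training function.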