# vq-vae-2-pytorch
Implementation of *Generating Diverse High-Fidelity Images with VQ-VAE-2* in PyTorch
## Update
* 2020-06-01: train_vqvae.py and vqvae.py now support distributed training. Pass --n_gpu [NUM_GPUS] to train_vqvae.py to train with [NUM_GPUS] GPUs.
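For example, to train on a single machine with 4 GPUs (the dataset path is a placeholder):
> python train_vqvae.py --n_gpu 4 [DATASET PATH]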
## Requirements
* Python >= 3.6
* PyTorch >= 1.1
* lmdb (for storing extracted codes)
[Checkpoint of VQ-VAE pretrained on FFHQ](vqvae_560.pt)
## Usage
Currently supports 256px images (top/bottom hierarchical prior)
1. Stage 1 (VQ-VAE)
> python train_vqvae.py [DATASET PATH]
If you use FFHQ, I highly recommend preprocessing the images (resize and convert to JPEG); see the sketch after this list.
2. Extract codes for stage 2 training
> python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]
3. Stage 2 (PixelSNAIL)
> python train_pixelsnail.py [LMDB NAME]
A larger PixelSNAIL model would likely work better; the current model size is reduced due to GPU memory constraints.
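A minimal preprocessing sketch for step 1 (illustrative, not part of the repo; assumes Pillow is installed and uses placeholder paths):

```python
from pathlib import Path
from PIL import Image

src = Path("ffhq_raw")         # placeholder: directory of original FFHQ PNGs
dst = Path("ffhq_256_jpeg")    # placeholder: output directory for training
dst.mkdir(exist_ok=True)

for p in sorted(src.glob("*.png")):
    img = Image.open(p).convert("RGB")
    img = img.resize((256, 256), Image.LANCZOS)   # match the supported 256px resolution
    img.save(dst / (p.stem + ".jpg"), quality=95)
```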
## Sample
### Stage 1
Note: This is a training sample
![Sample from Stage 1 (VQ-VAE)](stage1_sample.png)
# Package exports (presumably distributed/__init__.py, given the relative imports).
from .distributed import (
    get_rank,
    get_local_rank,
    is_primary,
    synchronize,
    get_world_size,
    all_reduce,
    all_gather,
    reduce_dict,
    data_sampler,
    LOCAL_PROCESS_GROUP,
)
from .launch import launch
# Launcher for single- and multi-machine GPU training (presumably launch.py).
import os

import torch
from torch import distributed as dist
from torch import multiprocessing as mp

import distributed as dist_fn


def find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # bind to port 0 so the OS assigns a free port
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()

    return port


def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
    world_size = n_machine * n_gpu_per_machine

    if world_size > 1:
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        if dist_url == "auto":
            if n_machine != 1:
                raise ValueError('dist_url="auto" is not supported in multi-machine jobs')

            port = find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        if n_machine > 1 and dist_url.startswith("file://"):
            raise ValueError(
                "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            distributed_worker,
            nprocs=n_gpu_per_machine,
            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )

    else:
        fn(*args)


def distributed_worker(
    local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args
):
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environment")

    global_rank = machine_rank * n_gpu_per_machine + local_rank

    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )

    except Exception as e:
        raise OSError("failed to initialize NCCL groups") from e

    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
        )

    torch.cuda.set_device(local_rank)

    if dist_fn.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("distributed.LOCAL_PROCESS_GROUP is not None")

    # create one subgroup per machine so that machine-local collectives are possible
    n_machine = world_size // n_gpu_per_machine

    for i in range(n_machine):
        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)

        if i == machine_rank:
            dist_fn.LOCAL_PROCESS_GROUP = pg

    fn(*args)
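# Usage sketch (hypothetical; not part of this commit). A training script would
# hand its per-process entry point to launch(); with dist_url="auto" a free
# local port is chosen via find_free_port above.
def _example_worker(run_name):
    # runs once per spawned process, after NCCL init and cuda.set_device
    print(f"worker on device {torch.cuda.current_device()} for {run_name}")


if __name__ == "__main__":
    launch(_example_worker, n_gpu_per_machine=torch.cuda.device_count(), dist_url="auto", args=("demo",))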
# VQ-VAE encoder/decoder backbone (file name not shown in this commit view).
import math

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def nonlinearity(x):
    # swish activation
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class DownSample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # stride-2 conv with asymmetric padding applied manually in forward
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResidualDownSample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.pooling_down_sampler = DownSample(in_channels, with_conv=False)
        self.conv_down_sampler = DownSample(in_channels, with_conv=True)

    def forward(self, x):
        return self.pooling_down_sampler(x) + self.conv_down_sampler(x)


class ResnetBlock(nn.Module):
    def __init__(self, in_channels, dropout, out_channels=None, conv_shortcut=False):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention weights
        B, C, H, W = q.shape
        q = q.reshape(B, C, -1)
        q = q.permute(0, 2, 1)    # (B, H*W, C)
        k = k.reshape(B, C, -1)   # (B, C, H*W)
        w_ = torch.bmm(q, k)      # (B, H*W, H*W)
        w_ = w_ * C ** (-0.5)
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(B, C, -1)   # (B, C, H*W)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)     # (B, C, H*W)
        h_ = h_.reshape(B, C, H, W)

        h_ = self.proj_out(h_)

        return x + h_


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        z_channels,
        channels,
        num_res_blocks,
        resolution,
        attn_resolutions,
        resample_with_conv=True,
        channels_mult=(1, 2, 4, 8),
        dropout=0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.z_channels = z_channels
        self.channels = channels
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.conv_in = nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1)

        # downsampling path
        current_resolution = resolution
        in_channels_mult = (1,) + tuple(channels_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_channels_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                # one attention block per res block at the configured resolutions,
                # matching the attn[i_block] indexing in forward
                if current_resolution in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = DownSample(block_in, resample_with_conv)
                current_resolution = current_resolution // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def test_forward(self, x):
        # debugging helper: returns all intermediate feature maps of the downsampling path
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        return hs

    def forward(self, x):
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        z_channels,
        channels,
        num_res_blocks,
        resolution,
        attn_resolutions,
        channels_mult=(1, 2, 4, 8),
        resample_with_conv=True,
        dropout=0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.z_channels = z_channels
        self.channels = channels
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        in_channels_mult = (1,) + tuple(channels_mult)
        block_in = channels * channels_mult[self.num_resolutions - 1]
        current_resolution = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, current_resolution, current_resolution)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if current_resolution in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resample_with_conv)
                current_resolution = current_resolution * 2
            self.up.insert(0, up)  # prepend so that self.up is indexed by level

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight
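# Shape check (hypothetical example; not part of this commit). With 4 levels,
# a 256px encoder maps (B, 3, 256, 256) to (B, z_channels, 32, 32).
if __name__ == "__main__":
    enc = Encoder(in_channels=3, out_channels=3, z_channels=256, channels=128,
                  num_res_blocks=2, resolution=256, attn_resolutions=[32],
                  channels_mult=(1, 2, 4, 8), dropout=0.0)
    x = torch.randn(1, 3, 256, 256)
    print(enc(x).shape)  # torch.Size([1, 256, 32, 32])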
# Vector-quantization layers (presumably quantize.py).
import torch
from torch import nn
from torch import einsum
from torch.nn import functional as F

import distributed as dist_fn


class VectorQuantize(nn.Module):
    def __init__(self, hidden_dim, embedding_dim, n_embed, commitment_cost=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost

        self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1.0 / n_embed, 1.0 / n_embed)

    def forward(self, z):
        B, C, H, W = z.shape
        z_e = self.proj(z)
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # squared L2 distance from each vector to each codebook entry
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        # commitment loss + codebook loss
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
            + (z_q - z_e.detach()).pow(2).mean()

        # straight-through estimator
        z_q = z_e + (z_q - z_e).detach()
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)


class VectorQuantizeEMA(nn.Module):
    def __init__(
        self,
        hidden_dim,
        embedding_dim,
        n_embed,
        commitment_cost=1,
        decay=0.99,
        eps=1e-5,
        pre_proj=True,
        training_loc=True,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.training_loc = training_loc

        self.pre_proj = pre_proj
        if self.pre_proj:
            self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1.0 / n_embed, 1.0 / n_embed)

        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", self.embed.weight.data.clone())
        self.decay = decay
        self.eps = eps

    def forward(self, z):
        B, C, H, W = z.shape
        if self.pre_proj:
            z_e = self.proj(z)
        else:
            z_e = z
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # squared L2 distance from each vector to each codebook entry
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        if self.training_loc and self.training:
            # EMA codebook update, reduced across all workers
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = (flatten.transpose(0, 1) @ embed_onehot).transpose(0, 1)

            dist_fn.all_reduce(embed_onehot_sum)
            dist_fn.all_reduce(embed_sum)

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

            # Laplace smoothing of the cluster sizes
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.weight.data.copy_(embed_normalized)

        # only the commitment loss; the codebook itself is updated by the EMA above
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean()

        # straight-through estimator
        z_q = z_e + (z_q - z_e).detach()
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)


class GumbelQuantize(nn.Module):
    def __init__(
        self,
        hidden_dim,
        embedding_dim,
        n_embed,
        commitment_cost=1,
        straight_through=True,
        kl_weight=5e-4,
        temp_init=1.0,
        eps=1e-5,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.kl_weight = kl_weight
        self.temperature = temp_init
        self.eps = eps

        self.proj = nn.Conv2d(hidden_dim, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1.0 / n_embed, 1.0 / n_embed)
        self.straight_through = straight_through

    def forward(self, z, temp=None):
        # hard one-hot samples at eval time; optionally straight-through during training
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        B, C, H, W = z.shape
        z_e = self.proj(z)
        soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard)
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # KL divergence of code usage to the uniform prior
        qy = F.softmax(z_e, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean()

        embed_ind = soft_one_hot.argmax(dim=1)
        z_q = z_q.permute(0, 2, 3, 1)  # (B, H, W, C), matching the other quantizers
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)
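# Smoke test (hypothetical; not part of this commit). VectorQuantize is used here
# because VectorQuantizeEMA calls dist_fn.all_reduce during training, which needs
# an initialized process group.
if __name__ == "__main__":
    vq = VectorQuantize(hidden_dim=64, embedding_dim=32, n_embed=512)
    z = torch.randn(2, 64, 16, 16)             # (B, hidden_dim, H, W)
    z_q, diff, ind = vq(z)
    print(z_q.shape, ind.shape, float(diff))   # (2, 16, 16, 32), (2, 16, 16), scalar loss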
# VQVAE / HVQVAE wrappers (presumably vqvae.py).
import torch
from torch import nn
import json
import os


class VQVAE(nn.Module):
    def __init__(self, enc_config, dec_config, quantize_config):
        super().__init__()
        # NOTE: new_module (a config-to-module builder) is not defined in this
        # commit; it is expected to come from elsewhere in the repo.
        self.enc = new_module(enc_config)
        self.dec = new_module(dec_config)
        self.quantize = new_module(quantize_config)

    def forward(self, input):
        quant_t, diff_t, _ = self.encode(input)
        return self.decode(quant_t), diff_t

    def encode(self, input):
        logits = self.enc(input)
        quant_t, diff_t, id_t = self.quantize(logits)
        quant_t = quant_t.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        # diff_t = diff_t.unsqueeze(0)
        return quant_t, diff_t, id_t

    def decode(self, code):
        return [self.dec(code)]

    def decode_code(self, code_t):
        quant_t = self.quantize.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        return self.decode(quant_t)

    def get_last_layer(self):
        return self.dec.get_last_layer()


class HVQVAE(nn.Module):
    def __init__(
        self,
        levels,
        embedding_dim,
        enc_config,
        quantize_config,
        down_sampler_configs,
        dec_configs,
        codebook_scale=1.0,
    ):
        super().__init__()
        self.levels = levels

        self.enc = new_module(enc_config)
        self.decs = nn.ModuleList()
        for i in range(levels):
            self.decs.append(new_module(dec_configs[i]))

        self.quantize = new_module(quantize_config)
        self.down_samplers = nn.ModuleList()
        for i in range(levels - 1):
            self.down_samplers.append(new_module(down_sampler_configs[i]))

        self.codebook_scale = codebook_scale

    def forward(self, input):
        quants, diffs, ids = self.encode(input)
        dec_outputs = self.decode(quants[::-1])

        # each successive level's quantization loss is scaled by an extra
        # factor of codebook_scale
        total_diff = diffs[0]
        scale = 1.0
        for diff in diffs[1:]:
            scale *= self.codebook_scale
            total_diff = total_diff + diff * scale
        return dec_outputs, total_diff

    def encode(self, input):
        enc_output = self.enc(input)
        enc_outputs = [enc_output]
        for l in range(self.levels - 1):
            enc_outputs.append(self.down_samplers[l](enc_outputs[-1]))

        quants, diffs, ids = [], [], []
        for enc_output in enc_outputs:
            quant, diff, id = self.quantize(enc_output)
            quants.append(quant.permute(0, 3, 1, 2))  # (B, H, W, C) -> (B, C, H, W)
            diffs.append(diff)
            ids.append(id)

        return quants, diffs, ids

    def decode(self, quants):
        dec_outputs = []
        for l in range(self.levels - 1, -1, -1):
            dec_outputs.append(self.decs[l](quants[l]))
        return dec_outputs

    def decode_code(self, codes):
        quants = []
        for l in range(self.levels):
            quants.append(self.quantize.embed_code(codes[l]).permute(0, 3, 1, 2))
        dec_outputs = self.decode(quants)
        return dec_outputs

    def get_last_layer(self):
        return self.decs[-1].get_last_layer()