diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py index b29964f0a2fbd77945347a26a3cbfead95de7a8c..402200b669ab59f6bafef443bb96c75e04982f04 100644 --- a/generation/autoregressive_sampling.py +++ b/generation/autoregressive_sampling.py @@ -85,7 +85,8 @@ def filling_sequence( continue # forward - model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen + if log_attention_weights is not None: + model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen logits, *mem_kv = model( tokens[:, index:], position_ids[..., index: counter+1], diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py index 20ffa38242c66421d28b8cdb84abb53373d51ec9..c7ebf31c5e52143bf4ab60201dd53f0aa72b5c2a 100644 --- a/generation/cuda2d_sampling.py +++ b/generation/cuda2d_sampling.py @@ -14,7 +14,7 @@ import random import torch from .sampling_strategies import IterativeEntfilterStrategy -def filling_sequence( +def filling_sequence_cuda2d( model, seq0, seq1, @@ -38,10 +38,9 @@ def filling_sequence( assert seq1.shape[1] == layout[-1] - layout[-2] assert (seq1 >= 0).all() and (seq0 >= 0).all() device = seq0.device - # concat and pad sequences batch_size = seq0.shape[0] - n_pad = layout[1] + 1 - len(seq0) # +1 for [EOI1] + n_pad = layout[1] + 1 - seq0.shape[1] # +1 for [EOI1] assert n_pad > 0, "You should truncate long input before filling." seq = torch.cat(( torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype) @@ -53,6 +52,7 @@ def filling_sequence( tokens = seq.clone() attention_mask = torch.ones(layout[1], layout[1]).tril().to(device) attention_mask[n_pad:, :n_pad] = 0 + attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16 position_ids = torch.cat(( torch.zeros(n_pad, dtype=torch.long), torch.arange(0, layout[1] - n_pad), @@ -60,7 +60,7 @@ def filling_sequence( # prepare for interation unfixed = (tokens < 0) - unfixed[:, -4096] = True + unfixed[:, -layout[-1] + layout[-2]:] = True ll, rr = block_hw edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4) num_steps = warmup_steps + ll + rr - 2 @@ -69,7 +69,8 @@ def filling_sequence( logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask) if step_cnt <= warmup_steps: real_temp = 0.1 - tokens = strategy.forward(logits, tokens, real_temp) + new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] else: real_temp = 1.05 new_tokens = strategy.forward( @@ -86,4 +87,5 @@ def filling_sequence( print(x,y) unfixed[..., -(layout[-1] - layout[-2]):].view( batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False - return tokens \ No newline at end of file + + return tokens[:, n_pad:] \ No newline at end of file diff --git a/generation/sampling_strategies/iterative_entfilter_strategy.py b/generation/sampling_strategies/iterative_entfilter_strategy.py index 84196d0c8fae8db48ed515ff265dea9e22136409..9017a42a3c4dc5092be7c57f64b81e1ef0f8e0bf 100644 --- a/generation/sampling_strategies/iterative_entfilter_strategy.py +++ b/generation/sampling_strategies/iterative_entfilter_strategy.py @@ -36,7 +36,7 @@ class IterativeEntfilterStrategy: topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1) ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length] temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2 - logits = logits / 
temperature + logits = logits.float() / temperature for invalid_slice in self.invalid_slices: logits[..., invalid_slice] = -float('Inf') diff --git a/generation/utils.py b/generation/utils.py index bad073413b6c3efa6f71b3e46eeacb883bd504b0..01c05726a682e2245575517ba18d3ca27deea9c7 100644 --- a/generation/utils.py +++ b/generation/utils.py @@ -68,11 +68,12 @@ def generate_continually(func, input_source='interactive'): raw_text = raw_text.strip() if len(raw_text) == 0: continue - try: - start_time = time.time() - func(raw_text) - print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) - except (ValueError, FileNotFoundError) as e: - err_linenos.append(line_no) - continue + # try: + start_time = time.time() + func(raw_text) + print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) + # except (ValueError, FileNotFoundError) as e: + # err_linenos.append(line_no) + # print(e) + # continue print(err_linenos) diff --git a/inference_cogview.py b/inference_cogview.py index 56bdd96ead78f155fce08552e6c582facbeffbdc..88311ff70e737a2f93c003c7bc51ec231d614646 100644 --- a/inference_cogview.py +++ b/inference_cogview.py @@ -32,6 +32,7 @@ def main(args): model = model.to(args.device) load_checkpoint(model, args) set_random_seed(args.seed) + model.eval() # define function for each query query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}' diff --git a/inference_cogview2.py b/inference_cogview2.py new file mode 100644 index 0000000000000000000000000000000000000000..6895d363e242e283d65e6c93949f6bf01a2f6f0a --- /dev/null +++ b/inference_cogview2.py @@ -0,0 +1,125 @@ +# -*- encoding: utf-8 -*- +''' +@File : inference_cogview2.py +@Time : 2021/10/10 16:31:34 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +import torch +import argparse +from torchvision import transforms + +from arguments import get_args +from model.cached_autoregressive_model import CachedAutoregressiveModel +from model.cuda2d_model import Cuda2dModel +from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer +from tokenization import get_tokenizer +from generation.sampling_strategies import BaseStrategy, IterativeEntfilterStrategy +from generation.autoregressive_sampling import filling_sequence +from generation.cuda2d_sampling import filling_sequence_cuda2d +from generation.utils import timed_name, save_multiple_images, generate_continually + +def main(args): + initialize_distributed(args) + tokenizer = prepare_tokenizer(args) + # build model + model = Cuda2dModel(args) + if args.fp16: + model = model.half() + model = model.to(args.device) + load_checkpoint(model, args) + model0 = CachedAutoregressiveModel(args, transformer=model.transformer) + set_random_seed(args.seed) + model.eval() + model0.eval() + # define function for each query + query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1]' if not args.full_query else '{}' + invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)] + strategy0 = BaseStrategy(invalid_slices, + temperature=args.temperature, topk=args.top_k) + strategy1 = IterativeEntfilterStrategy(invalid_slices, + temperature=args.temperature, topk=10) # temperature not used + tr = transforms.Compose([ + transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), + ]) + + def process(raw_text): + if args.with_id: + query_id, raw_text = raw_text.split() + print('raw text: ', raw_text) + 
text = query_template.format(raw_text) + seq = tokenizer.parse_query(text, img_size=args.img_size) + if len(seq) > 1088: + raise ValueError('text too long.') + # calibrate text length + txt_len = seq.index(tokenizer['[BASE]']) + log_attention_weights = torch.zeros(len(seq), len(seq), + device=args.device, dtype=torch.half if args.fp16 else torch.float32) + log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args + + # generation + seq = torch.cuda.LongTensor(seq, device=args.device) + mbz = args.max_inference_batch_size + assert args.batch_size < mbz or args.batch_size % mbz == 0 + output_list = [] + for tim in range(max(args.batch_size // mbz, 1)): + output0 = filling_sequence(model0, seq.clone(), + batch_size=min(args.batch_size, mbz), + strategy=strategy0, + log_attention_weights=log_attention_weights + ) + imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0] + blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096] + output1 = filling_sequence_cuda2d(model, output0, blur64, + warmup_steps=3, block_hw=(4, 4), + strategy=strategy1 + ) + output_list.append(output1) + output_tokens = torch.cat(output_list, dim=0) + # decoding + imgs, txts = [], [] + for seq in output_tokens: + decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist()) + for i in range(len(decoded_imgs)): + if decoded_imgs[i].shape[-1] < 512: + decoded_imgs[i] = torch.nn.functional.interpolate(decoded_imgs[i], size=(512, 512)) + if args.with_id: + imgs.append(decoded_imgs[-1]) # only the last image (target) + else: + imgs.extend(decoded_imgs) + # save + if args.with_id: + full_path = os.path.join(args.output_path, query_id) + os.makedirs(full_path, exist_ok=True) + save_multiple_images(imgs, full_path, False) + else: + prefix = raw_text.replace('/', '')[:20] + full_path = timed_name(prefix, '.jpg', args.output_path) + save_multiple_images(imgs, full_path, True) + + os.makedirs(args.output_path, exist_ok=True) + generate_continually(process, args.input_source) + +if __name__ == "__main__": + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--full-query', action='store_true') + py_parser.add_argument('--img-size', type=int, default=256) + + Cuda2dModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.layout = [int(x) for x in args.layout.split(',')] + + with torch.no_grad(): + main(args) \ No newline at end of file diff --git a/move_weights.py b/move_weights.py index 7d41f2b2990effb15c7022dd248c6ce8ada8ed79..e55d6b34b35c2db3afc0eaeffcc2fb029142f784 100644 --- a/move_weights.py +++ b/move_weights.py @@ -103,3 +103,13 @@ missing_keys, unexpected_keys = model.load_state_dict(old['module'], strict=Fals # %% torch.save(old, 'pretrained/cogview/cogview2-base/6000/mp_rank_00_model_states.pt') # %% +import torch +old = torch.load("/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt", map_location='cpu') + +# %% +from collections import OrderedDict +new_ckpt = OrderedDict() +for k,v in old['state_dict'].items(): + new_ckpt[k] = v.detach() +torch.save(new_ckpt, 'pretrained/vqvae/l1+ms-ssim+revd_percep.pt') +# %% diff --git a/mpu/transformer.py b/mpu/transformer.py index 61f658ee4bc1e5de73f7f2e01fb77fa0adc3e21f..e04485a697504d5e7fc3d9967b081981a7503bd9 100755 --- a/mpu/transformer.py +++ 
b/mpu/transformer.py
@@ -308,7 +308,7 @@ class BaseTransformer(torch.nn.Module):
                 layer_id,
                 output_layer_init_method=self.output_layer_init_method,
                 sandwich_ln=sandwich_ln,
-                hooks=hooks
+                hooks=self.hooks
                 )
         self.layers = torch.nn.ModuleList(
             [get_layer(layer_id) for layer_id in range(num_layers)])
diff --git a/scripts/text2image_cogview.sh b/scripts/text2image_cogview.sh
index 4fc46365914d359466343eb0185c1c147d4c7155..f205776ac5e7572f60e6582812cf3935c5a2d2c1 100755
--- a/scripts/text2image_cogview.sh
+++ b/scripts/text2image_cogview.sh
@@ -17,7 +17,7 @@ script_dir=$(dirname $script_path)
 
 MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
diff --git a/scripts/text2image_cogview2.sh b/scripts/text2image_cogview2.sh
new file mode 100755
index 0000000000000000000000000000000000000000..96e8a045bbb2f3d4f046b9949dc734b12ae0288e
--- /dev/null
+++ b/scripts/text2image_cogview2.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+NLAYERS=48
+NHIDDEN=2560
+NATT=40
+MAXSEQLEN=1089
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+MPSIZE=1
+
+#SAMPLING ARGS
+TEMP=1.03
+TOPK=200
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+MASTER_PORT=${MASTER_PORT} python inference_cogview2.py \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --mode inference \
+       --distributed-backend nccl \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --model-parallel-size $MPSIZE \
+       --num-layers $NLAYERS \
+       --hidden-size $NHIDDEN \
+       --load $CHECKPOINT_PATH \
+       --num-attention-heads $NATT \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --sandwich-ln \
+       --input-source ./input.txt \
+       --output-path samples_text2image \
+       --batch-size 4 \
+       --max-inference-batch-size 8 \
+       --device 0 \
+       $@
+
+
diff --git a/tokenization/cogview/vqvae_tokenizer.py b/tokenization/cogview/vqvae_tokenizer.py
index 56ee251126fdcb901d4b88cea114342b1dfccdb7..23df2efab1ca3cb925908a1f02ee688712f7ed15 100755
--- a/tokenization/cogview/vqvae_tokenizer.py
+++ b/tokenization/cogview/vqvae_tokenizer.py
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 
 
-from vqvae import new_model, img2code, code2img
+from vqvae import new_model, img2code, code2img, load_model_default
 from torchvision import transforms
 from PIL import Image
 
@@ -35,16 +35,17 @@ class VQVAETokenizer(object):
             model_path,
             device='cuda'
         ):
-        ckpt = torch.load(model_path, map_location=torch.device(device))
+        # ckpt = torch.load(model_path, map_location=torch.device(device))
 
-        model = new_model()
+        # model = new_model()
 
-        if list(ckpt.keys())[0].startswith('module.'):
-            ckpt = {k[7:]: v for k, v in ckpt.items()}
+        # if list(ckpt.keys())[0].startswith('module.'):
+        #     ckpt = {k[7:]: v for k, v in ckpt.items()}
 
-        model.load_state_dict(ckpt)
-        model = model.to(device)
-        model.eval()
+        # model.load_state_dict(ckpt)
+        # model = model.to(device)
+        # model.eval()
+        model = load_model_default(device=device, path=model_path)
 
         self.model = model
         self.device = device
diff --git a/vqvae/__init__.py b/vqvae/__init__.py
index 9d2256330e77c8c1ab47ff8ddd21854d15a32f37..2652e923e0f0a7dd90da239d021f3dd4fd922f08 100755
--- a/vqvae/__init__.py
+++ b/vqvae/__init__.py
@@ -1 +1,3 @@
-from .api import new_model, img2code, code2img
\ No newline at end of file
+from .api import new_model, img2code, code2img +from .api import new_module, load_decoder_default, load_model_default +from .api import test_decode, test_decode_default \ No newline at end of file diff --git a/vqvae/api.py b/vqvae/api.py index 5a5f64c873baaae70b79d41b6b8f1fc9062e93f0..7d2fd0312c1276ac4cbe9c46b04e8aeef2c905e1 100755 --- a/vqvae/api.py +++ b/vqvae/api.py @@ -2,13 +2,206 @@ # Can rewrite the APIs for VQGAN. # Don't forget to freeze the relavant .py files. +import importlib import torch import math +import os + +from torchvision.utils import save_image, make_grid +from datetime import datetime # production APIs from .vqvae_zc import VQVAE +def new_module(config): + ''' + in config: + "target": module type, vqvae_zc.Decoder/vqvae_diffusion.Decoder/vqvae_diffusion.Decoder2 + "ckpt": path of checkpoint + "ckpt_prefix": prefix to remove in ckpt state dict + "device": device + "params": dict of params + ''' + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + module, cls = config.get("target").rsplit(".", 1) + model = getattr(importlib.import_module(module, package=None), cls)(**config.get("params", dict())) + + device = config.get("device", "cpu") + model = model.to(device) + model.eval() + + if "ckpt" in config: + ckpt = torch.load(config.get("ckpt"), map_location='cpu') + prefix = config.get("ckpt_prefix", None) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + if prefix is not None: + ckpt = {k[len(prefix) + 1:]: v for k, v in ckpt.items() if k.startswith(prefix)} + model.load_state_dict(ckpt, strict=False) + del ckpt + return model + +def load_decoder_default(device=0, path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"): + # exp: load currently best decoder + target = "vqvae.vqvae_diffusion.Decoder" + params = { + "double_z": False, + "z_channels": 256, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [ 1,1,2,4], # num_down = len(ch_mult)-1 + "num_res_blocks": 2, + "attn_resolutions": [16], + "dropout": 0.0 + } + ckpt_prefix = "dec" + + config = { + "target": target, + "params": params, + "ckpt": path, + "ckpt_prefix": ckpt_prefix, + "device": device + } + return new_module(config) + +# path="/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt" +def load_model_default(device=0, + path="pretrained/vqvae/l1+ms-ssim+revd_percep.pt"): + # exp: load currently best vqvae model + ddconfig = { + "double_z": False, + "z_channels": 256, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1,1,2,4], + "num_res_blocks": 2, + "attn_resolutions": [16], + "dropout": 0.0 + } + params = { + "in_channel": 3, + "channel": 512, + "n_res_block": 0, + "n_res_channel": 32, + "embed_dim": 256, + "n_embed": 8192, + "stride": 6, + "simple": True, + "decay": 0.99, + "dif": True, + "ddconfig": ddconfig + } + + config = { + 'target': "vqvae.vqvae_zc.VQVAE", + 'params': params, + 'ckpt': path, + 'device': device + } + return new_module(config) + +def test_decode(configs, testcase, device=0, output_path=None): + ''' + configs: list of config for new module + testcases: pt file path or tensor of [B, D, H, W] + ''' + if output_path is None: + output_path = os.path.join("sample", f"{datetime.now().strftime('%m-%d-%H-%M-%S')}.jpg") + + quantize_config = { + "target": "vqvae.vqvae_zc.Quantize", + "params": { + "dim": 256, + "n_embed": 8192, + }, + "ckpt": "/dataset/fd5061f6/cogview/zwd/pretrained/vqvae/vqvae_hard_biggerset_011.pt", + "ckpt_prefix": "module.quantize_t", + 
"device": device + } + quantize = new_module(quantize_config) + + if type(testcase) is str: + testcase = torch.load(testcase, map_location=torch.device(device))[:, -1024:].contiguous() + testcase = testcase.view(testcase.shape[0], 32, 32).contiguous() + else: + testcase = testcase.to(device) + + quantized_testcase = quantize.embed_code(testcase) + quantized_testcase = quantized_testcase.permute(0, 3, 1, 2) + + outs = [] + for config in configs: + decoder = new_module(config) + out = decoder(quantized_testcase) + outs.append(out.unsqueeze(0)) + outs = torch.cat(outs).permute(1, 0, 2, 3, 4) + outs = outs.reshape(-1, *outs.shape[2:]).contiguous() + save_image(make_grid(outs, nrow=len(configs)), output_path, normalize=True) + +def test_decode_default(device=0): + # testing 3 decoders: original/l1+ms-ssim/l1+ms-ssim+perceptual + configs = [ + { + "target": "vqvae.vqvae_zc.Decoder", + "params": { + "in_channel": 256, + "out_channel": 3, + "channel": 512, + "n_res_block": 0, + "n_res_channel": 32, + "stride": 4, + "simple": True + }, + "ckpt": "/dataset/fd5061f6/cogview/zwd/pretrained/vqvae/vqvae_hard_biggerset_011.pt", + "ckpt_prefix": "module.dec", + "device": device }, + { + "target": "vqvae.vqvae_diffusion.Decoder", + "params": { + "double_z": False, + "z_channels": 256, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [ 1,1,2,4], # num_down = len(ch_mult)-1 + "num_res_blocks": 2, + "attn_resolutions": [16], + "dropout": 0.0 + }, + "ckpt": "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim/checkpoints/last.ckpt", + "ckpt_prefix": "dec", + "device": device }, + { + "target": "vqvae.vqvae_diffusion.Decoder", + "params": { + "double_z": False, + "z_channels": 256, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [ 1,1,2,4], # num_down = len(ch_mult)-1 + "num_res_blocks": 2, + "attn_resolutions": [16], + "dropout": 0.0 + }, + "ckpt": "/dataset/fd5061f6/cogview/zwd/vqgan/l1+ms-ssim+revd_percep/checkpoints/last.ckpt", + "ckpt_prefix": "dec", + "device": device }, + ] + testcase_dir = "/dataset/fd5061f6/cogview/zwd/vqgan/testcase/" + for testcase in os.listdir(testcase_dir): + testcase = os.path.join(testcase_dir, testcase) + test_decode(configs, testcase, device) + def new_model(): '''Return a New Instance of VQVAE, the same parameters with the pretrained model. This is for torch.load(). diff --git a/vqvae/vqvae_diffusion.py b/vqvae/vqvae_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c566deb015c07fc2db3ff26f3fe4e7ab574f88 --- /dev/null +++ b/vqvae/vqvae_diffusion.py @@ -0,0 +1,782 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = 
torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == 
x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h 
= self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +def Decoder2(**ddconfig): + return nn.Sequential( + torch.nn.Conv2d(ddconfig.get("z_channels"), ddconfig.get("z_channels"), 1), + Decoder(**ddconfig) + ) + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep 
= use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = 
self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + diff --git a/vqvae/vqvae_zc.py b/vqvae/vqvae_zc.py index a018ad5e0c8449548a27bac468f598f27f57b2a3..ca2b360ad54e33e9b4da22b7f7d772917caf5f16 100755 --- a/vqvae/vqvae_zc.py +++ b/vqvae/vqvae_zc.py @@ -2,6 +2,8 @@ import torch from torch import nn from torch.nn import functional as F +from .vqvae_diffusion import Decoder as DifDecoder + # import distributed as dist_fn # Copyright 2018 The Sonnet Authors. All Rights Reserved. 
@@ -225,22 +227,27 @@ class VQVAE(nn.Module):
         n_embed=1024,
         stride=4,
         simple=True,
-        decay=0.99
+        decay=0.99,
+        dif=False,
+        ddconfig=None
     ):
         super().__init__()
         if channel == 2048:
             n_res_block = 0
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride, embed_dim, n_embed, simple)
         self.quantize_t = Quantize(embed_dim, n_embed)
-        self.dec = Decoder(
-            in_channel=embed_dim,
-            out_channel=in_channel,
-            channel=channel,
-            n_res_block=n_res_block,
-            n_res_channel=n_res_channel,
-            stride=stride-2,
-            simple=simple
-        )
+        if dif:
+            self.dec = DifDecoder(**ddconfig)
+        else:
+            self.dec = Decoder(
+                in_channel=embed_dim,
+                out_channel=in_channel,
+                channel=channel,
+                n_res_block=n_res_block,
+                n_res_channel=n_res_channel,
+                stride=stride-2,
+                simple=simple
+            )
 
     def forward(self, input, continuous_relax=False, temperature=1., hard=False, KL=False):
         quant_t, diff, _, = self.encode(input, continuous_relax, temperature, hard, KL)
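
A minimal usage sketch (not part of the patch) of the config-driven loading API added in vqvae/api.py above. It assumes the repo-relative checkpoint paths used elsewhere in this diff and the default decoder config (ch_mult [1,1,2,4], i.e. 8x upsampling from a 32x32 code grid); the quantizer checkpoint path is an assumed local copy of the old VQVAE weights.

import torch
from vqvae import new_module, load_decoder_default

# Build the diffusion-style decoder and load the "dec"-prefixed weights
# from the new l1+ms-ssim+revd_percep checkpoint.
decoder = load_decoder_default(device=0, path='pretrained/vqvae/l1+ms-ssim+revd_percep.pt')

# Restore the quantizer the same way; this mirrors the config used in test_decode(),
# but with an assumed repo-relative checkpoint path instead of the cluster path.
quantize = new_module({
    'target': 'vqvae.vqvae_zc.Quantize',
    'params': {'dim': 256, 'n_embed': 8192},
    'ckpt': 'pretrained/vqvae/vqvae_hard_biggerset_011.pt',
    'ckpt_prefix': 'module.quantize_t',
    'device': 0,
})

# Decode a [B, 32, 32] grid of image token ids into RGB images.
codes = torch.randint(0, 8192, (1, 32, 32), device='cuda')
with torch.no_grad():
    z = quantize.embed_code(codes).permute(0, 3, 1, 2)  # [B, 256, 32, 32]
    imgs = decoder(z)                                    # [B, 3, 256, 256]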