Commit 39e6d723 authored by Ming Ding

del old

parent 5db7dfe4
Showing 0 additions and 1567 deletions
assets/coco_new.png (8.67 MiB)
assets/cogviewcase.png (1.55 MiB)
assets/logo.png (21.5 KiB)

# %%
p = 'people.jpeg'
from data_utils.vqvae_tokenizer import VQVAETokenizer
model = VQVAETokenizer(
'pretrained/vqvae/vqvae_hard_biggerset_011.pt'
)
img = model.read_img(p, img_size=512)
# %%
test_dir = 'tmp'
import os
import torch
from torchvision.utils import save_image
img = model.EncodeAsIds(img)
imgs = model.DecodeIds(torch.tensor(img))
save_image(imgs, os.path.join(test_dir, 'show512_people.jpg'), normalize=True)
# %%
import numpy as np
import torch
def entropy(x):
a = np.array(x)
a = a / a.sum()
return - np.sum(a * np.log(a))
print(entropy([0.9999,0.001]))
def loadbao(name):
ret1, ret2 = [], []
with open(name, 'r') as fin:
for line in fin:
a = line.split()
aa = [float(x) for x in a[1:5]]
b = entropy(aa)
c = sum(aa)
ret1.append(b)
ret2.append(c)
return np.array(ret1), np.array(ret2)
def loadlion(name):
ret1, ret2 = [], []
with open(name, 'r') as fin:
for line in fin:
a, b = line.split()
ret1.append(abs(float(a)))
ret2.append(abs(float(b)))
return ret1, ret2
import torchvision
import torchvision.transforms as transforms
def sq(img, x, y, lx, ly):
assert len(img.shape) == 3
img[:,x:x+lx,y] = torch.tensor([0,1,0]).unsqueeze(-1)
img[:,x:x+lx,y+ly] = torch.tensor([0,1,0]).unsqueeze(-1)
img[:,x,y:y+ly] = torch.tensor([0,1,0]).unsqueeze(-1)
img[:,x+lx,y:y+ly] = torch.tensor([0,1,0]).unsqueeze(-1)
transform = transforms.Compose([
transforms.Resize(512),
transforms.CenterCrop(512),
])
img = torchvision.io.read_image('cat2.jpeg')
img = transform(img) / 255.
# a,b = np.array(loadlion('bed6.txt'))
b, c = np.array(loadbao('bed1.txt'))
for t in (b>1.35).nonzero()[0]:
x,y = t // 64, t % 64
sq(img, x*8, y*8, 7, 7)
print(b.sum())
torchvision.utils.save_image(img, 'example_cat.jpg')
from torch.utils.data import IterableDataset
import PIL
from PIL import Image
import csv
import torch
from io import BytesIO
import base64
class TsvDataset(IterableDataset):
def __init__(self, path, transform=None, caption_only=False):
self.f = open(path, "r")
self.tsvreader = csv.reader(self.f, delimiter='\t')
self.transform = transform
self.caption_only = caption_only
def callback_fn(image_base64, id, caption):
try:
img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64))).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, id, caption
except (PIL.UnidentifiedImageError, PIL.Image.DecompressionBombError):
print("UnidentifiedImageError")
return torch.zeros((3, 256, 256)), "not_a_image", "not_a_caption"
self.callback_fn = callback_fn
def __iter__(self):
def get_next():
if self.caption_only:
for line in self.tsvreader:
yield self.callback_fn(torch.zeros((3, 256, 256)), line[0], line[1])
else:
for line in self.tsvreader:
yield self.callback_fn(line[3], line[0], line[2])
return iter(get_next())
def __del__(self):
self.f.close()
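# A minimal usage sketch (not part of the original file): iterate the TSV reader
# above with a DataLoader. 'example.tsv' is a hypothetical file laid out the way
# __iter__ expects (column 0: id, column 2: caption, column 3: urlsafe
# base64-encoded image bytes).
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    dataset = TsvDataset('example.tsv', transform=transform)
    loader = DataLoader(dataset, batch_size=4)  # IterableDataset: no shuffling
    for imgs, ids, captions in loader:
        print(imgs.shape, ids[:2], captions[:2])
        break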
#!/usr/bin/env python3
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.
When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).
The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.
See --help to see further details.
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of TensorFlow.
Copyright 2018 Institute of Bioinformatics, JKU Linz
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from torchvision.models.inception import inception_v3
import torch
import numpy as np
from scipy import linalg
from torch.autograd import Variable
from torch.nn.functional import adaptive_avg_pool2d
import torchvision.transforms as transforms
from inception import InceptionV3
import torch.utils.data
from PIL import Image
from torch.utils import data
import img_data as img_data
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
#parser.add_argument('path', type=str, nargs=2,
# help=('Path to the generated images or '
# 'to .npz statistic files'))
parser.add_argument('--batch-size', type=int, default=64,
help='Batch size to use')
parser.add_argument('--dims', type=int, default=2048,
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
help=('Dimensionality of Inception features to use. '
'By default, uses pool3 features'))
parser.add_argument('-c', '--gpu', default='', type=str,
help='GPU to use (leave blank for CPU only)')
parser.add_argument('--path1', type=str, default=None)
parser.add_argument('--path2', type=str, default=None)
def get_activations(images, model, batch_size=64, dims=2048, cuda=False, verbose=True):
"""Calculates the activations of the pool_3 layer for all images.
Params:
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
must lie between 0 and 1.
-- model : Instance of inception model
-- batch_size : the images numpy array is split into batches with
batch size batch_size. A reasonable batch size depends
on the hardware.
-- dims : Dimensionality of features returned by Inception
-- cuda : If set to True, use GPU
    -- verbose     : If set to True, the progress of the calculation is
                     reported.
Returns:
-- A numpy array of dimension (num images, dims) that contains the
activations of the given tensor when feeding inception with the
query tensor.
"""
model.eval()
#d0 = images.shape[0]
d0 = images.__len__() * batch_size
if batch_size > d0:
print(('Warning: batch size is bigger than the data size. '
'Setting batch size to data size'))
batch_size = d0
n_batches = d0 // batch_size
n_used_imgs = n_batches * batch_size
pred_arr = np.empty((n_used_imgs, dims))
#for i in range(n_batches):
for i, batch in enumerate(images):
#batch = batch[0]
#if verbose:
#print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
#import ipdb
#ipdb.set_trace()
start = i * batch_size
end = start + batch_size
#batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
#batch = Variable(batch, volatile=True)
if cuda:
batch = batch.cuda()
pred = model(batch)[0]
# If model output is not scalar, apply global spatial average pooling.
# This happens if you choose a dimensionality not equal 2048.
if pred.shape[2] != 1 or pred.shape[3] != 1:
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
if verbose:
print(' done')
return pred_arr
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
Stable version by Dougal J. Sutherland.
Params:
-- mu1 : Numpy array containing the activations of a layer of the
inception net (like returned by the function 'get_predictions')
for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.
Returns:
-- : The Frechet Distance.
"""
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert mu1.shape == mu2.shape, \
'Training and test mean vectors have different lengths'
assert sigma1.shape == sigma2.shape, \
'Training and test covariances have different dimensions'
diff = mu1 - mu2
# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = ('fid calculation produces singular product; '
'adding %s to diagonal of cov estimates') % eps
print(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return (diff.dot(diff) + np.trace(sigma1) +
np.trace(sigma2) - 2 * tr_covmean)
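# A small sanity check (not part of the original file): with identical statistics
# the Frechet distance should be ~0, and shifting every mean by 1 should raise it
# by roughly the feature dimensionality. The dimension and sample count here are
# arbitrary.
def _frechet_distance_sanity_check(dim=8, n=1000):
    rng = np.random.RandomState(0)
    act = rng.randn(n, dim)
    mu, sigma = act.mean(axis=0), np.cov(act, rowvar=False)
    print(calculate_frechet_distance(mu, sigma, mu, sigma))        # ~0
    print(calculate_frechet_distance(mu, sigma, mu + 1.0, sigma))  # ~dim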
def calculate_activation_statistics(images, model, batch_size=64,
dims=2048, cuda=False, verbose=True):
"""Calculation of the statistics used by the FID.
Params:
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
must lie between 0 and 1.
-- model : Instance of inception model
-- batch_size : The images numpy array is split into batches with
batch size batch_size. A reasonable batch size
depends on the hardware.
-- dims : Dimensionality of features returned by Inception
-- cuda : If set to True, use GPU
    -- verbose     : If set to True, the progress of the calculation is
                     reported.
Returns:
-- mu : The mean over samples of the activations of the pool_3 layer of
the inception model.
-- sigma : The covariance matrix of the activations of the pool_3 layer of
the inception model.
"""
act = get_activations(images, model, batch_size, dims, cuda, verbose)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
if path.endswith('.npz'):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
f.close()
else:
dataset = img_data.Dataset(path, transforms.Compose([
transforms.Resize((299, 299)),
transforms.ToTensor(),
]))
print(dataset.__len__())
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)
m, s = calculate_activation_statistics(dataloader, model, batch_size, dims, cuda)
return m, s
def calculate_fid_given_dataset(dataset1, dataset2, batch_size, cuda=True, dims=2048):
"""Calculates the FID of two dataset"""
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx])
if cuda:
model.cuda()
loader1 = torch.utils.data.DataLoader(dataset=dataset1, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)
m1, s1 = calculate_activation_statistics(loader1, model, batch_size, dims, cuda)
loader2 = torch.utils.data.DataLoader(dataset=dataset2, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)
m2, s2 = calculate_activation_statistics(loader2, model, batch_size, dims, cuda)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
"""Calculates the FID of two paths"""
for p in paths:
if not os.path.exists(p):
raise RuntimeError('Invalid path: %s' % p)
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx])
if cuda:
model.cuda()
m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda)
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
if __name__ == '__main__':
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
paths = ["",""]
paths[0] = args.path1
paths[1] = args.path2
print(paths)
fid_value = calculate_fid_given_paths(paths, args.batch_size,args.gpu,args.dims)
print('FID: ', fid_value)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class InceptionV3(nn.Module):
"""Pretrained InceptionV3 network returning feature maps"""
# Index of default block of inception to return,
# corresponds to output of final average pooling
DEFAULT_BLOCK_INDEX = 3
# Maps feature dimensionality to their output blocks indices
BLOCK_INDEX_BY_DIM = {
64: 0, # First max pooling features
        192: 1,  # Second max pooling features
768: 2, # Pre-aux classifier features
2048: 3 # Final average pooling features
}
def __init__(self,
output_blocks=[DEFAULT_BLOCK_INDEX],
resize_input=True,
normalize_input=True,
requires_grad=False):
"""Build pretrained InceptionV3
Parameters
----------
output_blocks : list of int
Indices of blocks to return features of. Possible values are:
- 0: corresponds to output of first max pooling
- 1: corresponds to output of second max pooling
- 2: corresponds to output which is fed to aux classifier
- 3: corresponds to output of final average pooling
resize_input : bool
If true, bilinearly resizes input to width and height 299 before
feeding input to model. As the network without fully connected
layers is fully convolutional, it should be able to handle inputs
of arbitrary size, so resizing might not be strictly needed
normalize_input : bool
If true, normalizes the input to the statistics the pretrained
Inception network expects
requires_grad : bool
If true, parameters of the model require gradient. Possibly useful
for finetuning the network
"""
super(InceptionV3, self).__init__()
self.resize_input = resize_input
self.normalize_input = normalize_input
self.output_blocks = sorted(output_blocks)
self.last_needed_block = max(output_blocks)
assert self.last_needed_block <= 3, \
'Last possible output block index is 3'
self.blocks = nn.ModuleList()
inception = models.inception_v3(pretrained=True)
# Block 0: input to maxpool1
block0 = [
inception.Conv2d_1a_3x3,
inception.Conv2d_2a_3x3,
inception.Conv2d_2b_3x3,
nn.MaxPool2d(kernel_size=3, stride=2)
]
self.blocks.append(nn.Sequential(*block0))
# Block 1: maxpool1 to maxpool2
if self.last_needed_block >= 1:
block1 = [
inception.Conv2d_3b_1x1,
inception.Conv2d_4a_3x3,
nn.MaxPool2d(kernel_size=3, stride=2)
]
self.blocks.append(nn.Sequential(*block1))
# Block 2: maxpool2 to aux classifier
if self.last_needed_block >= 2:
block2 = [
inception.Mixed_5b,
inception.Mixed_5c,
inception.Mixed_5d,
inception.Mixed_6a,
inception.Mixed_6b,
inception.Mixed_6c,
inception.Mixed_6d,
inception.Mixed_6e,
]
self.blocks.append(nn.Sequential(*block2))
# Block 3: aux classifier to final avgpool
if self.last_needed_block >= 3:
block3 = [
inception.Mixed_7a,
inception.Mixed_7b,
inception.Mixed_7c,
nn.AdaptiveAvgPool2d(output_size=(1, 1))
]
self.blocks.append(nn.Sequential(*block3))
for param in self.parameters():
param.requires_grad = requires_grad
def forward(self, inp):
"""Get Inception feature maps
Parameters
----------
inp : torch.autograd.Variable
Input tensor of shape Bx3xHxW. Values are expected to be in
range (0, 1)
Returns
-------
List of torch.autograd.Variable, corresponding to the selected output
block, sorted ascending by index
"""
outp = []
x = inp
if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
if self.normalize_input:
x = x.clone()
x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
for idx, block in enumerate(self.blocks):
x = block(x)
if idx in self.output_blocks:
outp.append(x)
if idx == self.last_needed_block:
break
return outp
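# A minimal usage sketch (not part of the original file): extract pool3 features
# for a random batch. Inputs are assumed to lie in [0, 1] as the forward()
# docstring requires; the wrapper resizes them to 299x299 internally.
if __name__ == '__main__':
    import torch
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(2, 3, 256, 256))[0]
    print(feats.shape)  # expected: torch.Size([2, 2048, 1, 1])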
# Reference: https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py
import torch
from torch import nn
from torchvision.models.inception import inception_v3
from torch.nn import functional as F
from torch.autograd import Variable
import numpy as np
from scipy.stats import entropy
def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
"""Computes the inception score of the generated images imgs
imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    resize -- whether to bilinearly resize inputs to 299x299 before the model
    splits -- number of splits
"""
N = len(imgs)
assert batch_size > 0
assert N > batch_size
# Set up dtype
if cuda:
dtype = torch.cuda.FloatTensor
else:
if torch.cuda.is_available():
print("WARNING: You have a CUDA device, so you should probably set cuda=True")
dtype = torch.FloatTensor
# Set up dataloader
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
# Load inception model
inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
inception_model.eval()
up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
def get_pred(x):
if resize:
x = up(x)
x = inception_model(x)
        return F.softmax(x, dim=1).data.cpu().numpy()
# Get predictions
preds = np.zeros((N, 1000))
for i, batch in enumerate(dataloader, 0):
batch = batch.type(dtype)
batchv = Variable(batch)
batch_size_i = batch.size()[0]
preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
# Now compute the mean kl-div
split_scores = []
for k in range(splits):
part = preds[k * (N // splits): (k+1) * (N // splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
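# A minimal usage sketch (not part of the original file): the score below is
# computed over random tensors, so its value is meaningless; it only shows the
# expected input format (a Dataset of 3xHxW tensors in [-1, 1]) and the
# (mean, std) output.
if __name__ == '__main__':
    class _RandomImages(torch.utils.data.Dataset):
        def __len__(self):
            return 64
        def __getitem__(self, idx):
            return torch.rand(3, 64, 64) * 2 - 1
    print(inception_score(_RandomImages(), cuda=False, batch_size=16, resize=True, splits=2))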
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import torch
from torch.optim.lr_scheduler import _LRScheduler
import math
from utils import print_rank_0
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve."""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, decay_ratio=0.5, restart_iter=0):
self.restart_iter = restart_iter
assert warmup_iter <= num_iters
self.optimizer = optimizer
self.start_lr = start_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter + 1
self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
self.decay_ratio = 1 / decay_ratio
self.step(self.num_iters)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print(f'learning rate decaying style {self.decay_style}, ratio {self.decay_ratio}')
def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4
real_num_iters = self.num_iters - self.restart_iter
real_end_iter = self.end_iter - self.restart_iter
# print_rank_0(f'real_num_iters: {real_num_iters}')
if self.warmup_iter > 0 and real_num_iters <= self.warmup_iter:
return float(self.start_lr) * real_num_iters / self.warmup_iter
else:
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr*((real_end_iter-(real_num_iters-self.warmup_iter))/real_end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
decay_step_ratio = min(1.0, (real_num_iters - self.warmup_iter) / real_end_iter)
return self.start_lr / self.decay_ratio * (
(math.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
#TODO: implement exponential decay
return self.start_lr
else:
return self.start_lr
def step(self, step_num=None):
if step_num is None:
step_num = self.num_iters + 1
self.num_iters = step_num
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
sd = {
# 'start_lr': self.start_lr,
'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters,
'decay_style': self.decay_style,
'end_iter': self.end_iter,
'decay_ratio': self.decay_ratio
}
return sd
def load_state_dict(self, sd):
# self.start_lr = sd['start_lr']
# self.warmup_iter = sd['warmup_iter']
self.num_iters = sd['num_iters']
# self.end_iter = sd['end_iter']
self.decay_style = sd['decay_style']
if 'decay_ratio' in sd:
self.decay_ratio = sd['decay_ratio']
self.step(self.num_iters)
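# A minimal usage sketch (not part of the original file): step a cosine schedule
# with a 10-iteration warmup over 100 iterations and print a few of the resulting
# learning rates. The optimizer and base lr are arbitrary.
if __name__ == '__main__':
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched = AnnealingLR(opt, start_lr=0.1, warmup_iter=10, num_iters=100, decay_style='cosine')
    for it in range(1, 101):
        sched.step(it)
        if it % 25 == 0:
            print(it, opt.param_groups[0]['lr'])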
from .utils import show_recover_results
# -*- encoding: utf-8 -*-
'''
@File : preprocess_text_image_data.py
@Time : 2021/01/24 15:38:44
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
from tqdm import tqdm
import pickle
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import lmdb
from .pretokenized_data import make_text_image_batch, make_tuple_text_image_batch, make_super_resolution_batch
import PIL
import timeit
@torch.no_grad()
def extract_code(model, datasets, text_dict, name, device, txt_type):
index = 0
map_size = 1024 * 1024 * 1024 * 1024
lmdb_env = lmdb.open(f'/root/mnt/lmdb/{name}', map_size=map_size, writemap=True)
print(f'/root/mnt/lmdb/{name}')
with lmdb_env.begin(write=True) as txn:
for dataset in datasets:
loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=1)
print(dataset)
pbar = tqdm(loader)
for raw_imgs, raw_filenames in pbar:
imgs = []
filenames = []
for i, filename in enumerate(raw_filenames):
if filename != "not_a_image" and text_dict.__contains__(filename):
imgs.append(raw_imgs[i])
filenames.append(filename)
else:
print("warning: deleted damaged image")
imgs = torch.stack(imgs)
imgs = imgs.to(device)
try:
if txt_type == "h5":
filenames = filenames.numpy()
txts = [text_dict[filename] for filename in filenames]
if txt_type != "h5":
codes = make_text_image_batch(model, txts, imgs)
else:
codes = make_tuple_text_image_batch(model, txts, imgs)
for code in codes:
txn.put(str(index).encode('utf-8'), pickle.dumps(code))
index += 1
except KeyError:
print("warning: KeyError. The text cannot be find")
pass
txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))
@torch.no_grad()
def extract_code_super_resolution_patches(model, datasets, text_dict, name, device, txt_type):
index = 0
map_size = 1024 * 1024 * 1024 * 1024
lmdb_env = lmdb.open(f'/root/mnt/lmdb/{name}_super_resolution', map_size=map_size, writemap=True)
print(f'/root/mnt/lmdb/{name}_super_resolution')
with lmdb_env.begin(write=True) as txn:
for dataset in datasets:
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=1)
print(dataset)
pbar = tqdm(loader)
for raw_imgs, raw_filenames in pbar:
imgs = []
filenames = []
for i, filename in enumerate(raw_filenames):
if filename != "not_a_image" and text_dict.__contains__(filename):
imgs.append(raw_imgs[i])
filenames.append(filename)
else:
print("warning: deleted damaged image")
imgs = torch.stack(imgs)
imgs = imgs.to(device)
try:
if txt_type == "h5":
filenames = filenames.numpy()
txts = [text_dict[filename] for filename in filenames]
if txt_type != "h5":
codes = make_super_resolution_batch(model, txts, imgs)
else:
codes = make_tuple_text_image_batch(model, txts, imgs)
for code in codes:
txn.put(str(index).encode('utf-8'), pickle.dumps(code))
index += 1
except KeyError:
print("warning: KeyError. The text cannot be find")
pass
txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))
# -*- encoding: utf-8 -*-
'''
@File : preprocess_text_jsonformat_data.py
@Time : 2021/03/14 20:56:28
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
from tqdm import tqdm
import pickle
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import lmdb
from .pretokenized_data import make_cut_text_batch
import timeit
import ujson as json
def extract_code(datasets, name, seq_len):
'''
datasets: [json_name1, json_name2, ...]
'''
index = 0
map_size = 1024 * 1024 * 1024 * 1024
lmdb_env = lmdb.open(f'/root/mnt/lmdb/{name}', map_size=map_size, writemap=True)
with lmdb_env.begin(write=True) as txn:
for dataset in datasets:
with open(dataset, 'r') as fin:
print(f'Loading {dataset}...')
raw_json = json.load(fin)["RECORDS"]
bs = 512
for i in tqdm(range(0, len(raw_json), bs)):
txts = [t["content"] for t in raw_json[i: i + bs]]
txts = make_cut_text_batch(txts, seq_len)
for code in txts:
txn.put(str(index).encode('utf-8'), pickle.dumps(code))
index += 1
txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))
print(f'/root/mnt/lmdb/{name}, length={index}')
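# A minimal read-back sketch (not part of the original file): peek at a few of
# the pickled code arrays written by extract_code above. The lmdb path below is
# hypothetical.
def _peek_lmdb(path='/root/mnt/lmdb/example', k=3):
    env = lmdb.open(path, readonly=True, lock=False)
    with env.begin() as txn:
        length = int(txn.get('length'.encode('utf-8')))
        for i in range(min(length, k)):
            code = pickle.loads(txn.get(str(i).encode('utf-8')))
            print(i, getattr(code, 'shape', type(code)))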
# -*- encoding: utf-8 -*-
'''
@File : pretokenized_data.py
@Time : 2021/01/20 15:39:10
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from vqvae import *
from data_utils import Code2CodeTemplate, concat_codes
from torchvision.transforms.functional import resize
from torchvision import transforms
from data_utils import get_tokenizer
# def make_hierarchical_batch(model, txts, imgs):
# '''
# model: VQVAE
# txts: ['text1', 'text2', ...]
# imgs: [b, 3, s, s]
# '''
# s = img.shape[-1]
# assert img.shape[-2] == s # square
# codes_base = img2code(model, img)
# img_tiny = resize(img, size=s//4).numpy()
# codes_tiny = img2code(model, img_tiny).numpy()
# ret = []
# for i in range(len(txts)):
# text = '[ROI1] ' + txts[i]
# ret.append(
# Code2CodeTemplate(text, codes_tiny[i], codes_base[i])
# )
# return ret
def make_super_resolution_batch(model, txts, imgs):
'''
[text...small_img...base_img]
'''
tokenizer = get_tokenizer()
if not hasattr(make_super_resolution_batch, 'pos'):
pos = ['左上', '正上', '右上', '左侧', '中间', '右侧', '左下', '正下', '右下']
pos = [
tokenizer.parse_query('[ROI1] 是{}部分图'.format(p))
for p in pos
] # [[23, 354...], [232, ...]]
pw = [0, 64, 128] * 3
ph = [0, 0, 0, 64, 64, 64, 128, 128, 128]
make_super_resolution_batch.pos = list(zip(pos, ph, pw))
make_super_resolution_batch.weights = [1] * 9
make_super_resolution_batch.prefix = tokenizer.parse_query('[ROI2] 是 [ROI1] 的放大图')
s = imgs.shape[-1]
assert s == imgs.shape[-2] == 256
# Crop 128 * 128 patch
selected_poses = random.choices(range(9), weights=make_super_resolution_batch.weights)
pos = make_super_resolution_batch.pos
patches = [
imgs[i, :, pos[p][1]:pos[p][1] + 128, pos[p][2]: pos[p][2]+128]
for i, p in enumerate(selected_poses)
]
patches = torch.stack(patches)
small_patches = resize(patches, size=64)
codes_base = img2code(model, patches).cpu().numpy()
codes_small = img2code(model, small_patches).cpu().numpy()
ret = []
for i in range(len(txts)):
code_text = tokenizer(txts[i])
ret.append(
concat_codes(code_text + make_super_resolution_batch.prefix,
codes_small[i],
pos[selected_poses[i]][0],
codes_base[i])
)
return ret
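# NOTE: the definition below overrides the fixed-256 version above at import time;
# callers importing make_super_resolution_batch get the img_size-parameterized variant.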
def make_super_resolution_batch(model, txts, imgs, img_size=512, sampling_num=4):
'''
[text...small_img...base_img]
'''
tokenizer = get_tokenizer()
t0, t1 = img_size // 4, img_size // 2
if img_size == 512:
size_tk = tokenizer['[BASE]']
else:
raise NotImplementedError
pw = [0, t0, t1] * 3
ph = [0, 0, 0, t0, t0, t0, t1, t1, t1]
ptk = [[tokenizer['[EOI1]'], tokenizer['[ROI2]'], tokenizer[f'[POS{i}]'], size_tk, tokenizer['[BOI2]']]
for i in range(9)
]
pos = list(zip(ptk, ph, pw))
weights = [1] * 9
s = imgs.shape[-1]
assert s == imgs.shape[-2] == img_size
# Crop img_size/2 * img_size/2 patch
selected_poses = random.choices(range(9), weights=weights, k=sampling_num)
pos = pos
patches = [
imgs[i, :, pos[p][1]:pos[p][1] + t1, pos[p][2]: pos[p][2]+t1]
for i in range(imgs.shape[0])
for p in selected_poses
]
patch_prefix = [
pos[p][0]
for p in selected_poses
] * imgs.shape[0]
patches = torch.stack(patches)
overviews = torch.nn.functional.interpolate(imgs, size=(t1, t1), mode='bilinear')
codes_patches = img2code(model, patches).cpu().numpy()
codes_overviews = img2code(model, overviews).cpu().numpy()
ret = []
for i in range(len(txts)):
code_text = [tokenizer['[ROI1]']] + tokenizer(txts[i]) + [size_tk, tokenizer['[BOI1]']]
for j in range(sampling_num):
ret.append(
concat_codes(code_text,
codes_overviews[i],
patch_prefix[i* sampling_num + j],
codes_patches[i * sampling_num + j],
[tokenizer['[EOI2]']]
)
)
return ret
def make_text_image_batch(model, txts, imgs):
from data_utils import TextCodeTemplate
s = imgs.shape[-1]
assert s == imgs.shape[-2] == 256
tokenizer = get_tokenizer()
codes = img2code(model, imgs).cpu().numpy()
ret = []
for i in range(len(txts)):
ret.append(
TextCodeTemplate(txts[i], codes[i])
)
return ret
def make_tuple_text_image_batch(model, txts, imgs):
s = imgs.shape[-1]
assert s == imgs.shape[-2] == 256
codes = img2code(model, imgs).cpu().numpy()
ret = []
for i in range(len(txts)):
ret.append(
(txts[i], codes[i])
)
    return ret
import itertools
def make_cut_text_batch(txts, seq_len):
from data_utils import PureTextTemplate
tmp_list = np.array(list(
itertools.chain(*(PureTextTemplate(txt) for txt in txts))
))
ret = [
tmp_list[en - seq_len: en]
for en in range(seq_len, len(tmp_list), seq_len)
]
return ret
# -*- encoding: utf-8 -*-
'''
@File : raw_datasets.py
@Time : 2021/01/24 15:31:34
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
from tqdm import tqdm
import ctypes
import io
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, IterableDataset
from torchvision import datasets
import unrar
from PIL import Image
import timeit
from collections.abc import Iterable
class ImageFileDataset(datasets.ImageFolder):
def __getitem__(self, index):
sample, target = super().__getitem__(index)
path, _ = self.samples[index]
dirs, filename = os.path.split(path)
filename = filename.split('.')[0]
return sample, filename
class RarDataset(Dataset):
def __init__(self, path, transform=None):
from unrar import rarfile
self.rar = rarfile.RarFile(path)
self.infos = self.rar.infolist()
self.transform = transform
def __len__(self):
return len(self.infos)
def __getitem__(self, idx):
target_info = self.infos[idx]
img = Image.open(self.rar.open(target_info))
dirs, filename = os.path.split(self.infos[idx].filename)
filename = filename.split('.')[0]
if self.transform is not None:
img = self.transform(img)
return img, filename
from unrar import rarfile
from unrar import unrarlib
from unrar import constants
from unrar.rarfile import _ReadIntoMemory, BadRarFile
import zipfile
import PIL
class ZipDataset(Dataset):
def __init__(self, path, transform=None):
self.zip = zipfile.ZipFile(path)
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.members = [info for info in self.zip.infolist() if info.filename[-1] != os.sep]
else:
all_members = [info for info in self.zip.infolist() if info.filename[-1] != os.sep]
num_workers = worker_info.num_workers
worker_id = worker_info.id
self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id]
self.transform = transform
def __len__(self):
return len(self.members)
def __getitem__(self, idx):
target_info = self.members[idx]
img = Image.open(self.zip.open(target_info))
dirs, filename = os.path.split(self.members[idx].filename)
filename = filename.split('.')[0]
if self.transform is not None:
img = self.transform(img)
return img, filename
import h5py
class H5Dataset(Dataset):
def __init__(self, path, transform=None):
self.h5 = h5py.File(path, "r")
self.images = self.h5["input_image"]
self.members = None
self.transform = transform
def create_members(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.members = self.h5['index'][:]
else:
all_members = self.h5['index'][:]
num_workers = worker_info.num_workers
worker_id = worker_info.id
self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id]
def __len__(self):
if self.members is None:
self.create_members()
return len(self.members)
def __getitem__(self, idx):
if self.members is None:
self.create_members()
target_info = self.members[idx]
try:
img = Image.fromarray(self.images[target_info][0])
if self.transform is not None:
img = self.transform(img)
return img, int(target_info)
except(OSError, IndexError):
print("warning: OSError or IndexError")
return Image.new('RGB', (256, 256), (255, 255, 255)), -1
# class StreamingZipDataset(IterableDataset):
# def __init__(self, path, transform=None):
# self.zip = zipfile.ZipFile(path, "r")
# self.transform = transform
# def __len__(self):
# return len(self.zip.filelist)
# def __next__(self):
# img = Image.open(self.rar.open(target_info))
#
# pass
# def __iter__(self):
# worker_info = torch.utils.data.get_worker_info()
# if worker_info is None:
# self.members = self.zip.namelist()
# else:
# all_members = self.zip.namelist()
# num_workers = worker_info.num_workers
# worker_id = worker_info.id
# self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id]
# self.pointer = 0
# return self
# def __del__(self):
# self.zip.close()
class StreamingRarDataset(IterableDataset):
def __init__(self, path, transform=None, default_size=256):
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
print("begin open rar")
self.rar = rarfile.RarFile(path)
print("finish open rar")
self.transform = transform
def callback_fn(file_buffer, filename):
try:
img = Image.open(file_buffer.get_bytes()).convert('RGB')
dirs, filename = os.path.split(filename)
filename = filename.split('.')[0]
if self.transform is not None:
img = self.transform(img)
return img, filename
except PIL.UnidentifiedImageError:
print("UnidentifiedImageError")
return torch.zeros((3, default_size, default_size)), "not_a_image"
self.callback_fn = callback_fn
# new handle
self.handle = None
self.callback_fn = callback_fn
def __len__(self):
return len(self.rar.filelist)
def __next__(self):
if self.pointer >= len(self.members):
raise StopIteration()
if self.handle == None:
archive = unrarlib.RAROpenArchiveDataEx(
self.rar.filename, mode=constants.RAR_OM_EXTRACT)
self.handle = self.rar._open(archive)
# callback to memory
self.data_storage = _ReadIntoMemory()
c_callback = unrarlib.UNRARCALLBACK(self.data_storage._callback)
unrarlib.RARSetCallback(self.handle, c_callback, 0)
handle = self.handle
try:
rarinfo = self.rar._read_header(handle)
while rarinfo is not None:
if rarinfo.filename == self.members[self.pointer]:
self.rar._process_current(handle, constants.RAR_TEST)
break
else:
self.rar._process_current(handle, constants.RAR_SKIP)
rarinfo = self.rar._read_header(handle)
if rarinfo is None:
self.data_storage = None
except unrarlib.UnrarException:
raise BadRarFile("Bad RAR archive data.")
if self.data_storage is None:
raise KeyError('There is no item named %r in the archive' % self.members[self.pointer])
# return file-like object
ret = self.data_storage
if self.callback_fn is not None:
ret = self.callback_fn(ret, self.members[self.pointer])
self.pointer += 1
return ret
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.members = self.rar.namelist()
else:
all_members = self.rar.namelist()
num_workers = worker_info.num_workers
worker_id = worker_info.id
self.members = [x for i, x in enumerate(all_members) if i % num_workers == worker_id]
self.pointer = 0
return self
def __del__(self):
self.rar._close(self.handle)
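# A minimal usage sketch (not part of the original file): read images and file
# stems from a zip archive with ZipDataset. 'example_images.zip' is a hypothetical
# archive of RGB images.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    ds = ZipDataset('example_images.zip', transform=transform)
    loader = DataLoader(ds, batch_size=4, shuffle=False)
    for imgs, filenames in loader:
        print(imgs.shape, filenames)
        break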
# -*- encoding: utf-8 -*-
'''
@File : utils.py
@Time : 2021/01/24 16:35:43
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from vqvae import code2img, img2code
from torchvision.utils import save_image
def show_recover_results(model, imgs):
codes = img2code(model, imgs)
recovered = code2img(model, codes)
mean = torch.tensor([0.79093, 0.76271, 0.75340], device=recovered.device).view(-1, 1, 1)
std = torch.tensor([0.30379, 0.32279, 0.32800], device=recovered.device).view(-1, 1, 1)
recovered = (recovered * std + mean).clamp(0, 1)
imgs = (imgs * std + mean).clamp(0, 1)
out = torch.cat([imgs, recovered], dim=0)
save_image(out, 'samples/show_recover_results.jpg', normalize=False, nrow=len(imgs))
import os
import sys
import math
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="preprocess args")
parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--img_tokenizer_path", type=str, default='vqvae_hard_biggerset_011.pt')
parser.add_argument("--encode_size", type=int, default=32)
parser.add_argument("--device", type=int, default=0)
args = parser.parse_args()
print(args)
img_size = args.encode_size * 8
# args = argparse.Namespace()
# args.img_tokenizer_path = 'pretrained/vqvae/vqvae_hard_018.pt'#old path
# args.img_tokenizer_path = 'pretrained/vqvae/vqvae_hard_biggerset_011.pt'
# args.img_tokenizer_path = '/root/mnt/vqvae_1epoch_64x64.pt'
args.img_tokenizer_num_tokens = None
device = f'cuda:{args.device}'
torch.cuda.set_device(device)
name = args.dataset + "_" + args.img_tokenizer_path.split(".")[0] + ".lmdb"
args.img_tokenizer_path = f"pretrained/vqvae/{args.img_tokenizer_path}"
datasets = {}
datasets["ali"] = [
['/root/mnt/sq_gouhou_white_pict_title_word_256_fulltitle.tsv'],
['/root/mnt/dingming/ali_white_picts_256.zip'],
"tsv"
]
datasets["ks3"] = [
['/root/mnt/KS3/a_baidu_image_msg_data.json'],
['/root/mnt/KS3/downloadImages.rar'],
"json_ks"
]
datasets["zijian"] = [
['/root/mnt/zijian/zj_duomotai_clean_done_data_new.json',
'/root/mnt/zijian/zj_duomotai_local_server_last_surplus_120w.json'],
['/root/mnt/imageFolder_part01.rar',
'/root/mnt/zijian/imagesFolder_last_surplus_120w.rar'],
"json"
]
datasets["google"] = [
['/root/mnt/google/google_image_message_data.json'],
['/root/mnt/google/downloadImage_2020_12_16.rar'],
"json_ks"
]
datasets["zijian1"] = [
['/root/mnt/zijian/zj_duomotai_clean_done_data_new.json'],
['/root/cogview2/data/imageFolder_part01.rar'],
"json"
]
datasets["zijian2"] = [
['/root/mnt/zijian/zj_duomotai_local_server_last_surplus_120w.json'],
['/root/mnt/zijian/imagesFolder_last_surplus_120w.rar'],
"json"
]
txt_files, img_folders, txt_type = datasets[args.dataset]
os.environ['UNRAR_LIB_PATH'] = '/usr/local/lib/libunrar.so'
from data_utils import get_tokenizer
tokenizer = get_tokenizer(args)
model = tokenizer.img_tokenizer.model
print("finish init vqvae_model")
from preprocess.preprocess_text_image_data import extract_code,extract_code_super_resolution_patches
# ===================== Define Imgs ======================== #
from preprocess.raw_datasets import H5Dataset, StreamingRarDataset, ZipDataset
datasets = []
for img_folder in img_folders:
if img_folder[-3:] == "rar":
dataset = StreamingRarDataset(path=img_folder, transform=transforms.Compose([
transforms.Resize(img_size),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]),
]),
default_size=img_size)
elif img_folder[-3:] == "zip":
dataset = ZipDataset(path=img_folder, transform=transforms.Compose([
transforms.Resize(img_size),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]),
]))
else:
dataset = H5Dataset(path=img_folder, transform=transforms.Compose([
transforms.Resize(img_size),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]),
]))
datasets.append(dataset)
print('Finish reading meta-data of dataset.')
# ===================== END OF BLOCK ======================= #
# from preprocess import show_recover_results
# loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8)
# loader = iter(loader)
# samples = []
# for k in range(8):
# x = next(loader)
# print(x[1])
# x = x[0].to(device)
# samples.append(x)
# samples = torch.cat(samples, dim=0)
# show_recover_results(model, samples)
# ===================== Load Text ======================== #
if txt_type == "json":
import json
txt_list = []
for txt in txt_files:
with open(txt, 'r') as fin:
t = json.load(fin)
txt_list.extend(list(t.items()))
tmp = []
for k, v in tqdm(txt_list):
tmp.append((v['uniqueKey'], v['cnShortText']))
text_dict = dict(tmp)
elif txt_type == "json_ks":
import json
txt_list = []
for txt in txt_files:
with open(txt, 'r') as fin:
t = json.load(fin)
txt_list.extend(t["RECORDS"])
tmp = []
for v in tqdm(txt_list):
tmp.append((v['uniqueKey'], v['cnShortText']))
text_dict = dict(tmp)
elif txt_type == "tsv":
import pandas as pd
txt_list = []
for txt in txt_files:
t = pd.read_csv(txt, sep='\t')
txt_list.extend(list(t.values))
tmp = []
for k, v in tqdm(txt_list):
tmp.append((str(k), v))
text_dict = dict(tmp)
else:
des = dataset.h5["input_concat_description"]
txt_name = dataset.h5["input_name"]
tmp = []
for i in tqdm(range(len(des))):
tmp.append((i, des[i][0].decode("latin-1")+txt_name[i][0].decode("latin-1")))
text_dict = dict(tmp)
print('Finish reading texts of dataset.')
# ===================== END OF BLOCK ======================= #
# extract_code(model, datasets, text_dict, name, device, txt_type)
extract_code_super_resolution_patches(model, datasets, text_dict, name, device, txt_type)
from data_utils.datasets import BinaryDataset
from data_utils import get_tokenizer
import argparse
import os
import torch
import random
test_dir = 'tmp'
# bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin'
bin_dir = '/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata'
bin_ds = BinaryDataset(os.path.join(bin_dir), process_fn=lambda x:x, length_per_sample=16**2+64*64+32*32+64, dtype='int32', preload=False)
args = argparse.Namespace(img_tokenizer_path='pretrained/vqvae/vqvae_hard_biggerset_011.pt', img_tokenizer_num_tokens=None)
tokenizer = get_tokenizer(args)
bin_ds = [bin_ds[random.randint(0, len(bin_ds)-1)] for i in range(32)]
for x in bin_ds:
if x[63] != -1:
end = 64
else:
end = x.tolist().index(-1)
print(tokenizer.DecodeIds(x[:end])[0])
from torchvision.utils import save_image
imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64:64+16**2], dtype=torch.long, device='cuda')) for x in bin_ds], dim=0)
save_image(imgs, os.path.join(test_dir, 'testcase128.jpg'), normalize=True)
imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+16**2:64+16**2+32**2], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
save_image(imgs, os.path.join(test_dir, 'testcase256.jpg'), normalize=True)
imgs = torch.cat([tokenizer.img_tokenizer.DecodeIds(torch.tensor(x[64+16**2+32**2:], dtype=torch.long,device='cuda')) for x in bin_ds], dim=0)
save_image(imgs, os.path.join(test_dir, 'testcase512.jpg'), normalize=True)
<p align="center">
<img src="assets/logo.png"/>
</p>
<p align="center">
<b>Generate vivid Images for <i>Any</i> (Chinese) text</b>
</p>
![teaser](assets/cogviewcase.png)
CogView is a pretrained (4B-param) transformer for general-domain text-to-image generation.
* **Read** our paper [CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/pdf/2105.13290.pdf) on ArXiv for a formal introduction. The *PB-relax* and *Sandwich-LN* techniques can also help you train large and deep transformers stably (e.g. by eliminating NaN losses).
* **Visit** our demo at [Github Page](https://thudm.github.io/CogView/index.html) or [Wudao](https://wudao.aminer.cn/CogView/)! (The demo runs without post-selection or super-resolution and currently only supports simplified Chinese input, but you can translate text from other languages into Chinese first. Note: *Wudao* provides faster access for users from mainland China.)
* **Download** our pretrained models from [Project Wudao-Wenhui](https://resource.wudaoai.cn/home?ind=2&name=WuDao%20WenHui&id=1399364355975327744)(悟道-文汇).
* **Cite** our paper if you find our work helpful~
```
@article{ding2021cogview,
title={CogView: Mastering Text-to-Image Generation via Transformers},
author={Ding, Ming and Yang, Zhuoyi and Hong, Wenyi and Zheng, Wendi and Zhou, Chang and Yin, Da and Lin, Junyang and Zou, Xu and Shao, Zhou and Yang, Hongxia and Tang, Jie},
journal={arXiv preprint arXiv:2105.13290},
  year={2021}
}
```
* **Google Colab** Two contributors successfully set up CogView on Colab [![Links to Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/THUDM/CogView/issues/10)!
## Getting Started
### Setup
* Hardware: Linux servers with Nvidia V100s or A100s are recommended, but it is also okay to run the pretrained models with a smaller `--max-inference-batch-size`, or to train smaller models, on less powerful GPUs.
* Environment (Option 1): Please first install PyTorch (>=1.7.0) and [apex](https://github.com/NVIDIA/apex), and then install other dependencies via `pip install -r requirements.txt`.
* Environment (Option 2): We provide a docker image in case you have trouble setting up the environment. Pull the image, create a (background) container and get into it via:
```
docker pull cogview/cuda111_torch181_deepspeed040
./env/start_docker.sh && docker exec -it bg-cogview bash
cd /root/cogview # in the container
```
### Download
0. Download the image tokenizer `vqvae_hard_biggerset_011.pt` from [BAAI website](https://resource.wudaoai.cn/home?ind=2&name=WuDao%20WenHui&id=1399364355975327744) or Tsinghua Cloud. Place the file under `pretrained/vqvae`.
```
wget https://cloud.tsinghua.edu.cn/f/71607a5dca69417baa8c/?dl=1 -O pretrained/vqvae/vqvae_hard_biggerset_011.pt
```
1. Download models from [Project Wudao-Wenhui](https://resource.wudaoai.cn/home?ind=2&name=WuDao%20WenHui&id=1399364355975327744).
| Filename | Description |
| ---- | ---- |
| cogview-base.tar | The pretrained text-to-image model. |
| cogview-caption.tar | Finetuned image-to-text model, also used for reranking. |
| cogview-sr.tar | Finetuned super-resolution model. (Warning: it runs slowly.) |
Uncompress them into `pretrained/cogview/`. The following command should be modified based on the model name.
```
tar -xvf cogview-{base, sr, caption}.tar -C pretrained/cogview/
```
2. (Only for training tutorial, skip it for inference.) Download a small "bird-and-animal" example dataset from our link at Tsinghua Cloud.
```
wget https://cloud.tsinghua.edu.cn/f/1e4963ec8ac84941ba68/?dl=1 -O data/bird_animal.bin
```
### Run CogView! (Model Inference)
We encapsulate the generation functions into scripts. See `generate_samples.py` and `arguments.py` for details.
#### Text-to-Image Generation
Write text queries (one per line) into `input.txt` and run:
```
./scripts/text2image.sh --debug
```
The results will be in a new folder `samples_text2image/`.
The main arguments useful for inference are:
* `--input-source [path or "interactive"]`. The path of the input file, can also be "interactive", which will launch a CLI.
* `--output-path [path]`. The folder containing the results.
* `--batch-size [int]`. The number of samples generated per query.
* `--max-inference-batch-size [int]`. Maximum batch size per forward. Reduce it if OOM.
* `--debug`. Only save concatenated images for all generated samples, and name them by input text and date.
* `--with-id`. When toggled, you must specify an "id" before each input, e.g. `001\t一个漂亮的女孩`, where \t denotes a TAB (**NOT a space**). It will generate `batch-size` separate images in a folder named "id" for each input. Conflicts with `--debug`.
* `--device [int]`. Running on which GPU.
#### Super-resolution
Run the following script and input `text\t{image_path}`, where `{image_path}` means the path of a previously generated image.
```
./scripts/super_resolution.sh
```
Note: *It is only effective for generated images from our Image Tokenizer (due to the token distribution).*
#### Image-to-Text
The input is one image path per line; the results are printed to stdout.
```
./scripts/image2text.sh
```
Note: *The model is not optimized for this task, so it might not be very competitive (but okay). We may release a version finetuned for a longer period on this task in the future.* (*TODO*)
#### Post-selection
This application only takes file inputs, where each line is `{text}\t{image_path1}\t{image_path2}\t{image_path3}...`.
The output is `{output_path}/scores.txt`: each line is a list of scores corresponding to the same line of the input.
```
./scripts/post_selection.sh
```
Note: *In the released code, for simplicity, we did not expose the raw API, which supports some advanced generation modes, e.g. text with part of an image.*
## Training
Here we use a subset of our dataset from bird-and-animal for tutorial. The binary dataset is generated by our [cogdata toolkit](https://github.com/Sleepychord/cogdata). Please wait for a formal release with tutorials of cogdata (although it is available now).
### Single Node
After downloading the dataset, directly run
```
./scripts/pretrain_single_node.sh
```
### Multiple Nodes
If you want to train the models on multiple servers interconnected by InfiniBand without a shared file system (you may need `pdsh` to accelerate this process):
1. On **each** server, use `git clone` to download this repo, and make sure the data (LMDB format) are moved into the `data` subfolder.
2. On **each** server, `echo "ip1 ip2 <other IPs>" > ./docker/ip_list.txt`, and then start the docker by `./env/start_docker.sh`.
3. Get into **the docker on the first node** container via `docker exec -it bg-cogview bash`.
4. Get into `/root/cogview` and run `./scripts/pretrain_multiple_nodes.sh`. You may need to change the config (especially `OPTIONS_NCCL`) in the shell script.
See `arguments.py` for advanced training options.
*TODO*
## Gallery
![more_samples](assets/coco_new.png)