# -*- encoding: utf-8 -*- ''' @File : utils.py @Time : 2021/01/24 16:35:43 @Author : Ming Ding @Contact : dm18@mails.tsinghua.edu.cn ''' # here put the import lib import os import sys import math import random from tqdm import tqdm import numpy as np import torch import torch.nn.functional as F from vqvae import code2img, img2code from torchvision.utils import save_image def show_recover_results(model, imgs): codes = img2code(model, imgs) recovered = code2img(model, codes) mean = torch.tensor([0.79093, 0.76271, 0.75340], device=recovered.device).view(-1, 1, 1) std = torch.tensor([0.30379, 0.32279, 0.32800], device=recovered.device).view(-1, 1, 1) recovered = (recovered * std + mean).clamp(0, 1) imgs = (imgs * std + mean).clamp(0, 1) out = torch.cat([imgs, recovered], dim=0) save_image(out, 'samples/show_recover_results.jpg', normalize=False, nrow=len(imgs))