# test_lmdb.py
import os
import sys
import pickle
import argparse

import lmdb
import torch
from torchvision.utils import save_image

from data_utils import get_tokenizer


def initialize(file_name):
    # lmdb.open() takes no mode string: its second positional argument is
    # map_size, so lmdb.open(file_name, "r") would raise. Open the
    # environment read-write so the write helpers below can commit.
    env = lmdb.open(file_name)
    return env
    
def insert(env, sid, name):
    # put() stores the value under the stringified id; with default flags it
    # overwrites any existing entry, so update() below is the same operation.
    txn = env.begin(write=True)
    txn.put(str(sid).encode('utf-8'), name.encode('utf-8'))
    txn.commit()

def delete(env, sid):
    txn = env.begin(write=True)
    txn.delete(str(sid).encode('utf-8'))
    txn.commit()

def update(env, sid, name):
    txn = env.begin(write=True)
    txn.put(str(sid).encode('utf-8'), name.encode('utf-8'))
    txn.commit()
    
    
def search(env, sid):
    # Values in this dataset are stored pickled, so unpickle on read. Note
    # the asymmetry with insert()/update(), which write raw utf-8 bytes.
    txn = env.begin()
    data = pickle.loads(txn.get(str(sid).encode('utf-8')))
    return data
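
# A minimal round-trip sketch, assuming a hypothetical scratch path and a
# writer that pickles values to match search(); kept as a comment so the
# test below runs unchanged:
#
#     env = initialize("data/example.lmdb")   # hypothetical path
#     txn = env.begin(write=True)
#     txn.put(str(0).encode('utf-8'), pickle.dumps([1, 2, 3]))
#     txn.commit()
#     assert search(env, 0) == [1, 2, 3]
#     env.close()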
    
    
if __name__ == "__main__":
    # settings
    lmdb_path = "data/ali_vqvae_hard_biggerset_011.lmdb"
    output_path = f"test_lmdb_{lmdb_path.split('/')[-1]}.jpg"
    args = argparse.Namespace()
    args.img_tokenizer_path = 'pretrained/vqvae/vqvae_hard_biggerset_011.pt'
    args.img_tokenizer_num_tokens = None
    device = 'cuda:0'

    torch.cuda.set_device(device)
    tokenizer = get_tokenizer(args)
    with lmdb.open(lmdb_path, readonly=True, lock=False) as env:
        imgs = []
        txts = []
        for i in range(20, 50):
            # Each record decodes to a caption and a batch of image tensors.
            txt, images = tokenizer.DecodeIds(search(env, i))
            txts.append(txt)
            imgs.append(images[0])
        print(txts)
        # Stack the decoded samples and save them as one image grid.
        # `range=None` was the default (the keyword is `value_range` in newer
        # torchvision), so it is dropped here for compatibility.
        imgs = torch.cat(imgs, dim=0)
        save_image(imgs, output_path, normalize=True)
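
# Expected invocation (paths come from the settings above):
#     python test_lmdb.py
# This decodes records 20..49 from the LMDB and writes the grid to
# test_lmdb_ali_vqvae_hard_biggerset_011.lmdb.jpg.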