From 8584bf91059e00341bca0df1ed5f932f4ae1e415 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 14 Oct 2021 07:22:04 +0000
Subject: [PATCH] prepare for large-scale generation

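Split the second-stage cuda2d refinement into chunks of 4 samples so that
large sampling batches fit in GPU memory, parse --with-id input lines as
tab-separated "<id>\t<text>" pairs (so texts may contain spaces), bind each
process to the GPU given by its local rank, raise the attention boost for
long texts from 1.4 to 1.6, and add a multi-node DeepSpeed launch script.
Example run (the input line below is hypothetical; the script reads
./coco30k.txt):

    printf '0001\ta tiger wearing a suit\n' > coco30k.txt
    bash scripts/large_scale_text2image_cogview2.sh
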
---
 .gitignore                                 |  3 +-
 arguments.py                               |  3 ++
 generation/cuda2d_sampling.py              |  1 -
 inference_cogview.py                       |  2 +-
 inference_cogview2.py                      | 11 +++--
 scripts/large_scale_text2image_cogview2.sh | 51 +++++++++++++++++++++
 6 files changed, 64 insertions(+), 7 deletions(-)
 create mode 100755 scripts/large_scale_text2image_cogview2.sh

diff --git a/.gitignore b/.gitignore
index 45ed55c..831cfe2 100755
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,5 @@ pretrained/
 *.jpg
 *.jpeg
 input*.txt
-samples*
\ No newline at end of file
+*samples*/
+
diff --git a/arguments.py b/arguments.py
index f87d809..026e406 100755
--- a/arguments.py
+++ b/arguments.py
@@ -246,6 +246,9 @@ def get_args(args_list=None):
     args.rank = int(os.getenv('RANK', '0'))
     args.world_size = int(os.getenv("WORLD_SIZE", '1'))
     
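+    # map this process to the GPU given by its local rank (set by the launcher)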
+    if args.local_rank is not None:
+        args.device = args.local_rank
 
     args.model_parallel_size = min(args.model_parallel_size, args.world_size)
     if args.rank == 0:
diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py
index c7ebf31..c191540 100644
--- a/generation/cuda2d_sampling.py
+++ b/generation/cuda2d_sampling.py
@@ -84,7 +84,6 @@ def filling_sequence_cuda2d(
             for x in range(min(ll, step_cnt - warmup_steps)):
                 y = step_cnt - warmup_steps - x - 1
                 if y < rr:
-                    print(x,y)
                     unfixed[..., -(layout[-1] - layout[-2]):].view(
                         batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
 
diff --git a/inference_cogview.py b/inference_cogview.py
index 88311ff..dd4c90f 100644
--- a/inference_cogview.py
+++ b/inference_cogview.py
@@ -42,7 +42,7 @@ def main(args):
     
     def process(raw_text):
         if args.with_id:
-            query_id, raw_text = raw_text.split()
+            query_id, raw_text = raw_text.split('\t')
         print('raw text: ', raw_text)
         text = query_template.format(raw_text)
         seq = tokenizer.parse_query(text, img_size=args.img_size)
diff --git a/inference_cogview2.py b/inference_cogview2.py
index 6895d36..75ad1c6 100644
--- a/inference_cogview2.py
+++ b/inference_cogview2.py
@@ -52,7 +52,7 @@ def main(args):
     
     def process(raw_text):
         if args.with_id:
-            query_id, raw_text = raw_text.split()
+            query_id, raw_text = raw_text.split('\t')
         print('raw text: ', raw_text)
         text = query_template.format(raw_text)
         seq = tokenizer.parse_query(text, img_size=args.img_size)
@@ -62,7 +62,7 @@ def main(args):
         txt_len = seq.index(tokenizer['[BASE]'])
         log_attention_weights = torch.zeros(len(seq), len(seq), 
             device=args.device, dtype=torch.half if args.fp16 else torch.float32)
-        log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
+        log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.6 # TODO args
 
         # generation
         seq = torch.cuda.LongTensor(seq, device=args.device)
@@ -77,11 +77,14 @@ def main(args):
                     )
             imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0]
             blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096]
-            output1 = filling_sequence_cuda2d(model, output0, blur64, 
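+            # run the second-stage (cuda2d) refinement in chunks of 4 samples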
+            len_tim = output0.shape[0]
+            for tim2 in range(0, len_tim, 4):
+                output1 = filling_sequence_cuda2d(model, output0[tim2: tim2+4], blur64[tim2: tim2+4], 
                     warmup_steps=3, block_hw=(4, 4),
                     strategy=strategy1
                     )
-            output_list.append(output1)
+                output_list.append(output1)
         output_tokens = torch.cat(output_list, dim=0)
         # decoding
         imgs, txts = [], []
diff --git a/scripts/large_scale_text2image_cogview2.sh b/scripts/large_scale_text2image_cogview2.sh
new file mode 100755
index 0000000..fb1eb17
--- /dev/null
+++ b/scripts/large_scale_text2image_cogview2.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+NUM_WORKERS=4
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+# HOST_FILE_PATH="hostfile_single"
+
+CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+NLAYERS=48
+NHIDDEN=2560
+NATT=40
+MAXSEQLEN=1089
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=1.03
+TOPK=200
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+gpt_options=" \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --mode inference \
+       --distributed-backend nccl \
+       --max-sequence-length $MAXSEQLEN \
+       --sandwich-ln \
+       --fp16 \
+       --model-parallel-size $MP_SIZE \
+       --num-layers $NLAYERS \
+       --hidden-size $NHIDDEN \
+       --load $CHECKPOINT_PATH \
+       --num-attention-heads $NATT \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --input-source ./coco30k.txt \
+       --output-path coco_samples \
+       --batch-size 60 \
+       --max-inference-batch-size 12 \
+       --with-id \
+    "
+
+run_cmd="${OPTIONS_NCCL} deepspeed --master_port ${MASTER_PORT} --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} inference_cogview2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
\ No newline at end of file
-- 
GitLab