From 7c5a12da328f6bd44ff98cdc9c2aa392db516fe6 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Fri, 8 Oct 2021 18:15:08 +0000
Subject: [PATCH] tmp finish naive ar sampling

---
 generation/__init__.py                        |   1 -
 generation/autoregressive_sampling.py         |  99 +++++++++++
 generation/cuda2d_sampling.py                 | 157 ++++++++++++++++++
 generation/local_sampling.py                  | 147 ----------------
 generation/sampling_strategies/__init__.py    |   1 +
 .../sampling_strategies/base_strategy.py      |  46 +++++
 model/cached_autoregressive_model.py          |   3 +-
 model/mixins.py                               |   4 +-
 mpu/transformer.py                            |   5 +-
 pretrain_cogview2.py                          |   2 +-
 scripts/finetune_into_cogview2.sh             |  59 +++++++
 training/model_io.py                          |   2 +-
 12 files changed, 370 insertions(+), 156 deletions(-)
 create mode 100644 generation/autoregressive_sampling.py
 create mode 100644 generation/cuda2d_sampling.py
 delete mode 100644 generation/local_sampling.py
 create mode 100644 generation/sampling_strategies/__init__.py
 create mode 100644 generation/sampling_strategies/base_strategy.py
 create mode 100755 scripts/finetune_into_cogview2.sh

diff --git a/generation/__init__.py b/generation/__init__.py
index 94f6ea1..90c43c6 100755
--- a/generation/__init__.py
+++ b/generation/__init__.py
@@ -1,4 +1,3 @@
 from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
 from .magnify import magnify
-from .local_sampling import filling_sequence_local
 from .cuda_2d_sampling import filling_sequence_cuda_2d
\ No newline at end of file
diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py
new file mode 100644
index 0000000..f1a047c
--- /dev/null
+++ b/generation/autoregressive_sampling.py
@@ -0,0 +1,99 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   autoregressive_sampling.py
+@Time    :   2021/10/08 15:43:59
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+from .sampling_strategies import BaseStrategy
+
+def get_masks_and_position_ids(seq):
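+    # single 1-D prompt -> batch of one: tokens [1, n], causal mask [1, 1, n, n], position_ids [1, n]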
+    tokens = seq.unsqueeze(0)
+
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask.unsqueeze_(1)
+
+    position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
+    position_ids = position_ids.unsqueeze(0)
+    return tokens, attention_mask, position_ids
+
+def update_mems(hiddens, mems, max_memory_length):
+    if hiddens is None:
+        return []
+    memory_length = mems[0].size(1) if mems else 0
+    query_length = hiddens[0].size(1)
+    new_memory_length = min(max_memory_length, memory_length + query_length)
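+    # keep at most max_memory_length cached positions per layer; if the fresh hidden
+    # states alone exceed that budget, the old mems are dropped entirely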
+    new_mems = []
+    with torch.no_grad():
+        for i in range(len(hiddens)):
+            if new_memory_length <= query_length:
+                new_mems.append(hiddens[i][:, -new_memory_length:])
+            else:
+                new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
+    return new_mems
+
+
+def filling_sequence(
+        model, 
+        seq, 
+        batch_size,
+        max_memory_length=100000,
+        strategy=BaseStrategy()
+        ):
+    '''
+        seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
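+
+        Usage sketch (illustrative only; `context` and `n_gen` are placeholders):
+            seq = torch.cat((context, -torch.ones(n_gen, dtype=torch.long, device=context.device)))
+            output = filling_sequence(model, seq, batch_size=4, strategy=BaseStrategy(topk=100))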
+    '''
+    assert len(seq.shape) == 1
+
+    # building the initial tokens, attention_mask, and position_ids
+    context_length = 0
+    while seq[context_length] >= 0:
+        context_length += 1 # [0, context_length-1] are given
+    assert context_length > 0
+    tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
+    tokens = tokens[..., :context_length]
+
+    # initialize generation
+    counter = context_length - 1 # the last fixed index is `counter`
+    index = 0 # next forward starting index, also the current length of the cache
+    mems = [] # mems are first-class citizens here, but we make no assumption about what is memorized
+    
+    # step-by-step generation
+    while counter < len(seq) - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+
+        if seq[counter + 1] >= 0: # provided
+            tokens = torch.cat(
+                (
+                    tokens, 
+                    seq[counter+1: counter+2].expand(tokens.shape[0], 1)
+                ), dim=1
+            )
+            counter += 1
+            continue
+
+        # forward
+        logits, *mem_kv = model(
+            tokens[:, index:], 
+            position_ids[..., index: counter+1],
+            attention_mask[..., index: counter+1, :counter+1], # TODO mem
+            *mems
+        )
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        counter += 1
+        index = counter
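+        # the key/value cache (mems) now covers every position before `index`,
+        # so the next forward pass only needs to feed the newly appended tokens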
+
+        # sampling
+        logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
+        tokens, mems = strategy.forward(logits, tokens, mems)
+        
+    return tokens
\ No newline at end of file
diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py
new file mode 100644
index 0000000..bd75552
--- /dev/null
+++ b/generation/cuda2d_sampling.py
@@ -0,0 +1,157 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   cuda2d_sampling.py
+@Time    :   2021/10/09 00:46:04
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+from .sampling_strategies import BaseStrategy
+
+def filling_sequence(
+        model, 
+        seq0,
+        seq1, 
+        warmup_steps=3,
+        block_hw=(4, 4),
+        strategy=BaseStrategy(topk=10)
+        ):
+    '''
+        seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
+            4095 {layout[2]} final_token
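+        layout: cumulative boundaries of the three segments (text+padding, first-level
+            image tokens, second-level image tokens), taken from model.layout; the
+            numbers above are illustrative of the CogView2 setting.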
+    '''
+    assert hasattr(model, 'layout')
+    layout = model.layout
+    assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \
+        and seq0.shape[0] == seq1.shape[0]
+    assert len(layout) == 3
+    assert seq1.shape[1] == layout[-1] - layout[-2]
+    assert (seq1 >= 0).all() and (seq0 >= 0).all()
+    device = seq0.device
+
+    # concat and pad sequences
+    batch_size = seq0.shape[0]
+    n_pad = layout[1] + 1 - seq0.shape[1] # +1 for [EOI1]; seq0 is [batch, len0]
+    assert n_pad > 0, "You should truncate long input before filling."
+    seq = torch.cat((
+        torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
+            .unsqueeze(0).expand(batch_size, n_pad),
+        seq0, seq1), dim=1) # [b, layout[-1]+1]
+    assert seq.shape[1] == layout[-1] + 1
+
+    # build initial tokens, attention_mask, and position_ids
+    tokens = seq[:, :-1].clone()
+    attention_mask = torch.ones(layout[1], layout[1]).tril().to(device)
+    attention_mask[n_pad:, :n_pad] = 0
+    position_ids = torch.cat((
+        torch.zeros(n_pad, dtype=torch.long),
+        torch.arange(0, layout[1] - n_pad), 
+        torch.arange(0, layout[2]-layout[1]))).to(device)
+
+    # the refinement mask was missing here; assume only the second-level (seq1) image
+    # tokens are re-sampled, everything before them stays fixed
+    unfixed = torch.zeros(seq.shape, dtype=torch.bool, device=device)
+    unfixed[:, -seq1.shape[1]:] = True
+
+    # iterative refining
+    ll, rr = block_hw
+    num_steps = warmup_steps + ll + rr - 2
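+    # after the warmup passes, positions inside each ll x rr block are frozen along
+    # anti-diagonals, one diagonal per refinement step (see the `choice == 1` branch below)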
+    for step_cnt in range(num_steps):
+        logits, *_dump = model(tokens, position_ids, attention_mask)
+
+        # temperature schedule: very low during warmup, then looser for refinement
+        real_topk = 10
+        iterative_step = warmup_steps + 6
+        if step_cnt <= warmup_steps:
+            real_temp = 0.1
+        elif step_cnt == warmup_steps + 1:
+            real_temp = 0.55
+        else:
+            real_temp = 0.45
+
+        # sampling: forbid generating tokens from invalid slices
+        for invalid_slice in strategy.invalid_slices:
+            logits[..., invalid_slice] = -float('Inf')
+
+        # probs0 = F.softmax(logits/real_temp, dim=-1)
+        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
+        ent = -(topraw * topraw.log()).sum(dim=-1)
+        # topsum = topraw.sum(dim=-1)
+        if step_cnt > warmup_steps:
+            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
+        else:
+            real_temp2 = real_temp
+        probs = F.softmax(logits/real_temp2, dim=-1)
+        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
+        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
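+        # if a sampled token's probability is below the k-th largest (the top-k "edge"),
+        # snap it to that k-th best token rather than keeping the low-probability draw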
+        edge_idx = tk_idx[:, :, -1:]
+        edge_value = tk_value[:, :, -1:]
+        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
+        prev[edge_mask] = edge_idx[edge_mask]
+        prev.squeeze_(-1)
+        # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
+        # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
+        # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
+        # update unfixed
+        choice = 1
+        if choice == 0 and warmup_steps < step_cnt:
+            mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
+            dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
+
+            new_fixed = unfixed.clone()
+            moved_new_fixed = new_fixed[:, 2:]
+            moved_new_fixed &= dprob
+            moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
+            moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
+            # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
+            moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
+            moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
+            # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
+        elif choice == 1 and warmup_steps < step_cnt:
+            new_fixed = unfixed & False
+            ll, rr = 4, 4
+            for x in range(min(ll, step_cnt - warmup_steps)):
+                y = step_cnt - warmup_steps - x - 1
+                if y < rr:
+                    new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
+            new_fixed &= unfixed
+        else:
+            new_fixed = unfixed & False # TODO
+        new_fixed[:, -1] = True
+
+
+        unfixed &= new_fixed.logical_not()
+        # update seq and tokens
+        seq[new_fixed] = prev[new_fixed[:, 1:]]
+        tokens = seq[:, :-1].clone()
+        tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
+
+        if step_cnt == iterative_step: 
+            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
+            n_unfixed = unfixed.sum(dim=-1).tolist()
+            print(f'Exit with {n_unfixed} unfixed tokens.')
+            break
+        if args.debug:
+            from torchvision.utils import save_image
+            seqt = seq.clone()
+            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
+            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
+    if args.debug:
+        imgs = torch.cat(imgs, dim=0)
+        save_image(imgs, f'steps{device}.jpg', normalize=True)
+    model.module.transformer.max_memory_length = args.max_memory_length
+
+    return seq
\ No newline at end of file
diff --git a/generation/local_sampling.py b/generation/local_sampling.py
deleted file mode 100644
index 9fba71a..0000000
--- a/generation/local_sampling.py
+++ /dev/null
@@ -1,147 +0,0 @@
-from .sampling import *
-import math
-import sys
-from copy import deepcopy
-
-def make_local_mask(sparse_config):
-    layout = sparse_config.layout
-    k1, k2 = sparse_config.kernel_size, sparse_config.kernel_size2
-    k1h = k1*2-1
-    h = w = int(math.sqrt(layout[2] - layout[1]) + 1e-3)
-    m = torch.zeros(layout[-1]+1, layout[-1], dtype=torch.bool, device='cuda')
-    for i in range(layout[1]):
-        m[i, :i] = True
-    m[layout[1]:, :layout[0]] = True
-    for i in tqdm(range(layout[1], layout[2])):
-        # m[i, layout[1]:i] = True
-        x = (i - layout[1]) // w
-        y = (i - layout[1]) % w
-        lx = max(0, x - k1h // 2)
-        ly = max(0, y - k1 // 2)
-        rx = min(h-1, x + k1h // 2)
-        ry = min(w-1, y + k1 // 2)
-        m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = True
-        m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = True
-        m[i, i] = False
-    for i in tqdm(range(layout[2], layout[3])):
-        x = (i - layout[2]) // (2*w)
-        y = (i - layout[2]) % (2*w)
-        lx = max(0, x - k1h // 2)
-        ly = max(0, y - k1 // 2)
-        rx = min(2*h-1, x + k1h // 2)
-        ry = min(2*w-1, y + k1 // 2)
-        m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = True
-        m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = True
-        x = x // 2
-        y = y // 2
-        lx = max(0, x - k2 // 2)
-        ly = max(0, y - k2 // 2)
-        rx = min(h-1, x + k2 // 2)
-        ry = min(w-1, y + k2 // 2)
-        m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = True
-        m[i, i] = False
-    return m.unsqueeze(-1)
-
-def update_mems_local(hiddens, mems, start, end, mem_bag, mask):
-    if isinstance(hiddens, list):
-        hiddens = torch.stack(hiddens)
-        mem_bag[:, :, start:end] = hiddens.to('cuda')
-     # first level
-    del mems
-    mems = mem_bag.masked_select(mask[end]).view(*mem_bag.shape[:2], -1, mem_bag.shape[3]).to(hiddens.device)
-    return mems
-
-
-def filling_sequence_local(
-        model, 
-        seq, 
-        args, 
-        mems=None, 
-        invalid_slices=[], 
-        **kwargs):
-    '''
-        seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1]
-        context_length: first non(-1)s
-    '''
-    loss_sum, loss_n = 0, 0.0001
-    tokenizer = get_tokenizer()
-    device = seq.device
-    assert len(seq.shape) == 1
-    assert args.sparse_config.sparse_type == 'standard'
-    sparse_config = deepcopy(args.sparse_config)
-    # with open('tmp_save.bin', 'rb') as fout:
-    #     seq = torch.load(fout)[1]
-    #     seq = torch.cat((seq, torch.tensor([-1], device=seq.device)))
-    # import pdb; pdb.set_trace()
-    # sparse_config.layout[0] = seq.tolist().index(-1)
-    sparse_config.layout[0] = seq.tolist().index(tokenizer['[BASE]'])
-    n_pad = sparse_config.layout[1] - sparse_config.layout[0]
-    assert n_pad > 0 # TODO long trunc
-    seq = torch.cat((seq[:sparse_config.layout[0]], torch.tensor([tokenizer['[POS8]']]* n_pad, device=seq.device, dtype=seq.dtype), seq[sparse_config.layout[0]:]))
-    out_seq_length = len(seq)
-    # building the initial tokens, attention_mask, and position_ids
-    context_length = sparse_config.layout[1] + 1
-
-    tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
-    tokens = tokens.expand(-min(seq), *tokens.shape[1:])
-
-    counter = context_length - 1 # == len(tokens) - 1
-    index = 0 # len(mems)
-    if mems is None:
-        mems = []
-    mem_bag = torch.zeros(args.num_layers, tokens.shape[0], out_seq_length-1, 2*args.hidden_size, device='cuda')
-    local_mask = make_local_mask(sparse_config)
-    
-    while counter < (out_seq_length - 1):
-        if counter % 100 == 0:
-            print(counter, loss_sum / loss_n, file=sys.stderr)
-        # Now, we want to generate seq[counter + 1]
-        # token[:, index: counter+1] are just added.
-        # import pdb;pdb.set_trace()
-
-        if index == 0: # first
-            logits, *qkv = model(tokens, position_ids, attention_mask, *mems)
-            mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask)
-            index = counter
-        else:
-            assert tokens.shape[1] == counter + 1 
-            position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0)
-            logits, *qkv = model(tokens[:, index: ], 
-                position_ids,
-                0, # rebuild in transformers (sep version)
-                *mems
-                )
-            mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask)
-            index = counter
-        counter += 1
-        index += 1
-
-        if seq[counter] >= 0: # provided
-            tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1)
-            loss_n +=1
-            loss_this = F.log_softmax(logits, dim=-1)[:, -1, seq[counter]].mean()
-            print(counter-64, loss_this.item())
-            loss_sum -= loss_this
-            continue
-
-        logits = logits[:, -1] # [batch size, vocab size]
-        temp = args.temperature
-        logits /= temp
-        for invalid_slice in invalid_slices: # forbide to generate other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
-        log_probs = F.softmax(logits, dim=-1)
-
-        # expand beams
-        prev = torch.multinomial(log_probs, num_samples=1)
-        tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1)
-
-    output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous()
-    output_tokens_list = torch.cat(
-        (
-            output_tokens_list[:, :sparse_config.layout[0]],
-            output_tokens_list[:, sparse_config.layout[1]+1:sparse_config.layout[2]+1],
-            torch.tensor([[tokenizer['[EOI1]']]], dtype=tokens.dtype, device=tokens.device).expand(output_tokens_list.shape[0], 1),
-            output_tokens_list[:, sparse_config.layout[2]+1:]
-        ), dim=1)
-    return output_tokens_list
\ No newline at end of file
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
new file mode 100644
index 0000000..f109603
--- /dev/null
+++ b/generation/sampling_strategies/__init__.py
@@ -0,0 +1 @@
+from .base_strategy import BaseStrategy
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
new file mode 100644
index 0000000..99a0a74
--- /dev/null
+++ b/generation/sampling_strategies/base_strategy.py
@@ -0,0 +1,46 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   base_strategy.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding 
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import torch.nn.functional as F
+
+def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
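+    # in-place top-k filter (trailing underscore): logits below the k-th largest are masked out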
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits[indices_to_remove] = filter_value     
+    return logits
+
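+# Usage sketch (assumed call pattern; `logits` are last-position scores of shape [batch, vocab]):
+#   strategy = BaseStrategy(temperature=0.9, topk=100)
+#   tokens, mems = strategy.forward(logits, tokens, mems)
+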
+class BaseStrategy:
+    def __init__(self, invalid_slices=[], temperature=1., topk=200, debias=False):
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = topk
+        self.debias = debias
+    def forward(self, logits, tokens, mems, temperature=None):
+        if temperature is None:
+            temperature = self.temperature 
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -float('Inf')
+        if self.debias:
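+            # debias: sample from the full softmax, then push any draw whose probability
+            # falls below the k-th largest back into the tail of the top-k candidates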
+            probs = F.softmax(logits, dim=-1)
+            tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
+            pred = torch.multinomial(probs, num_samples=1)
+            for j in range(0, pred.shape[0]):
+                if probs[j, pred[j,-1]] < tk_value[j, -1]:
+                    pred[j, -1] = tk_idx[j, torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))] # resample from the last 100 top-k entries; the cutoff 100 is heuristic
+        else:
+            logits = top_k_logits_(logits, self.topk)
+            probs = F.softmax(logits, dim=-1)
+            pred = torch.multinomial(probs, num_samples=1)
+        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
+        return tokens, mems
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
index ec2dd60..225ff3e 100755
--- a/model/cached_autoregressive_model.py
+++ b/model/cached_autoregressive_model.py
@@ -31,7 +31,8 @@ class CachedAutoregressiveModel(BaseModel):
             mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
         
         if mem is not None: # the first time, mem is None
-            memk, memv = split_tensor_along_last_dim(mem, 2)
+            b = mixed_key_layer.shape[0] # might change batch_size
+            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
             mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
             mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
 
diff --git a/model/mixins.py b/model/mixins.py
index 0d304cf..78befb8 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -34,9 +34,9 @@ class PositionEmbeddingMixin(BaseMixin):
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
     def reinit(self, transformer, *pre_mixins):
         old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
-        old_len, hidden_size = old_weights.shape[0]
+        old_len, hidden_size = old_weights.shape
         assert hidden_size == self.position_embeddings.weight.shape[-1]
-        self.position_embeddings_plus.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
+        self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
 
 class AttentionMixin(BaseMixin):
     def __init__(self, num_layers,
diff --git a/mpu/transformer.py b/mpu/transformer.py
index 099dcb5..61f658e 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -1,4 +1,4 @@
-# coding=utf-
+# coding=utf-8
 # rewritten, Copyright (c) 2021, Ming Ding.  All rights reserved.
 # Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
 #
@@ -26,7 +26,6 @@ from .initialize import get_model_parallel_world_size
 from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
 from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
 
-import deepspeed
 from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker
 
 from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
@@ -248,7 +247,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
         # MLP.
-        mlp_output = self.mlp(layernorm_output,  *other_tensors)
+        mlp_output = self.mlp(layernorm_output, *other_tensors)
 
         # Fourth LayerNorm
         if self.sandwich_ln:
diff --git a/pretrain_cogview2.py b/pretrain_cogview2.py
index 032fade..5fdb431 100755
--- a/pretrain_cogview2.py
+++ b/pretrain_cogview2.py
@@ -1,6 +1,6 @@
 # -*- encoding: utf-8 -*-
 '''
-@File    :   pretrain_gpt2.py
+@File    :   pretrain_cogview2.py
 @Time    :   2021/10/06 00:58:32
 @Author  :   Ming Ding 
 @Contact :   dm18@mail.tsinghua.edu.cn
diff --git a/scripts/finetune_into_cogview2.sh b/scripts/finetune_into_cogview2.sh
new file mode 100755
index 0000000..3bdf5ea
--- /dev/null
+++ b/scripts/finetune_into_cogview2.sh
@@ -0,0 +1,59 @@
+#! /bin/bash
+
+# Change for multinode config
+
+NUM_WORKERS=1
+NUM_GPUS_PER_WORKER=8
+MP_SIZE=1
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+main_dir=$(dirname $script_dir)
+
+OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
+HOST_FILE_PATH="hostfile"
+HOST_FILE_PATH="hostfile_single"
+
+full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin"
+small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata"
+
+config_json="$script_dir/ds_config_zero.json"
+gpt_options=" \
+       --experiment-name finetune-cogview2-test \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --model-parallel-size ${MP_SIZE} \
+       --mode finetune \
+       --num-layers 48 \
+       --hidden-size 2560 \
+       --num-attention-heads 40 \
+       --train-iters 200000 \
+       --resume-dataloader \
+       --train-data ${full_data} \
+       --split 949,50,1 \
+       --distributed-backend nccl \
+       --lr-decay-style cosine \
+       --warmup .1 \
+       --checkpoint-activations \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --save-interval 2000 \
+       --eval-interval 1000 \
+       --save $main_dir/checkpoints \
+       --load pretrained/cogview/cogview-base
+"
+       # --load pretrained/cogview/cogview-base
+
+
+gpt_options="${gpt_options}
+       --deepspeed \
+       --deepspeed_config ${config_json} \
+"
+              
+
+run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}"
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
diff --git a/training/model_io.py b/training/model_io.py
index f7527f9..ee2e8f1 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -135,7 +135,7 @@ def load_checkpoint(model, args):
         if not args.do_train:
             raise ValueError(f'Missing keys for inference: {missing_keys}.')
         else: # new params
-            assert all(name.find('mixins')>0 for name in missing_keys)
+            assert all(name.find('mixins')>=0 for name in missing_keys)
             module.reinit() # initialize mixins
     model.optimizer.refresh_fp32_params() # restore fp32 weights from module
 
-- 
GitLab