diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index dfa75c910489d592262a7483c6b0cfa2d901660b..85f701611615755885485fd57799ca4422791691 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -146,7 +146,8 @@ def add_evaluation_args(parser):
                             'validation/test for')
     group.add_argument('--eval-interval', type=int, default=1000,
                        help='interval between running evaluation on validation set')
-
+    group.add_argument('--strict-eval', action='store_true',
+                       help='do not enlarge or randomly map the eval data; evaluate on the full eval set.')
     return parser
 
 
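A minimal sketch (not part of the patch) of the semantics this flag is meant to have, as wired up in configure_data.py and deepspeed_training.py below: with --strict-eval, the eval split is not wrapped in RandomMappingDataset and each evaluation walks the whole split once instead of sampling args.eval_iters batches.

    # sketch, assuming `args` comes from this parser and `val_data` is the
    # validation DataLoader built in make_loaders()
    eval_iters = len(val_data) if args.strict_eval else args.eval_iters
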
diff --git a/SwissArmyTransformer/data_utils/configure_data.py b/SwissArmyTransformer/data_utils/configure_data.py
index b7d6299c69ef5d14030876fb6462b89b2daec654..5d53cded92c2264a82021ffb16912d8be8a45ee9 100755
--- a/SwissArmyTransformer/data_utils/configure_data.py
+++ b/SwissArmyTransformer/data_utils/configure_data.py
@@ -24,14 +24,17 @@ from .samplers import DistributedBatchSampler
 from SwissArmyTransformer import mpu
 
 
-def make_data_loader(dataset, batch_size, num_iters, args):
+def make_data_loader(dataset, batch_size, args):
     world_size = torch.distributed.get_world_size(
         group=mpu.get_data_parallel_group())
     rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
     distributed = world_size > 1
 
     sampler = torch.utils.data.SequentialSampler(dataset)
-    drop_last = distributed
+    # drop_last was previously tied to `distributed`; always drop the last
+    # incomplete batch now so that batch sizes stay consistent.
+    # TODO: how should the last (partial) batch be averaged during evaluation?
+    drop_last = True
     # the GPUs in the same model parallel group receive the same data
     if distributed: # TODO reformat this, but it is not urgent
         gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
@@ -52,7 +55,7 @@ def make_data_loader(dataset, batch_size, num_iters, args):
     return data_loader
 
 
-def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
+def make_dataset_full(path, split, args, create_dataset_function, random_mapping=True, **kwargs):
     """function to create datasets+tokenizers for common options"""
     print('make dataset ...', path)
     if split is None:
@@ -67,12 +70,8 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     ds = ConcatDataset(ds)
     if should_split(split):
         ds = split_ds(ds, split, block_size=args.block_size)
-    else:
+    elif random_mapping:
         ds = RandomMappingDataset(ds)
-
-    # if should_split(split):
-    #     ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
-    # FIXME this will merge valid set and train set.
     return ds
 
 def make_loaders(args, create_dataset_function):
@@ -115,25 +114,25 @@ def make_loaders(args, create_dataset_function):
     # make training and val dataset if necessary
     if valid is None and args.valid_data is not None:
         eval_set_args['path'] = args.valid_data
-        valid = make_dataset(**eval_set_args, args=args)
+        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
     if test is None and args.test_data is not None:
         eval_set_args['path'] = args.test_data
-        test = make_dataset(**eval_set_args, args=args)
+        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
 
     # wrap datasets with data loader
     if train is not None and args.batch_size > 0:
-        train = make_data_loader(train, batch_size, args.train_iters, args)
+        train = make_data_loader(train, batch_size, args)
         args.do_train = True
     else:
         args.do_train = False
     eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
     if valid is not None:
-        valid = make_data_loader(valid, eval_batch_size, args.train_iters, args)
+        valid = make_data_loader(valid, eval_batch_size, args)
         args.do_valid = True
     else:
         args.do_valid = False
     if test is not None:
-        test = make_data_loader(test, eval_batch_size, len(test) // eval_batch_size + 1, args)
+        test = make_data_loader(test, eval_batch_size, args)
         args.do_test = True
     else:
         args.do_test = False
diff --git a/SwissArmyTransformer/data_utils/datasets.py b/SwissArmyTransformer/data_utils/datasets.py
index bec1207112194e70221c8f52852f2c9855ed951e..ef46e2da7bc185465320792a89b6c2804dce1ab6 100755
--- a/SwissArmyTransformer/data_utils/datasets.py
+++ b/SwissArmyTransformer/data_utils/datasets.py
@@ -65,3 +65,18 @@ class BinaryDataset(Dataset):
     def __getitem__(self, index):
         return self.process_fn(self.bin[index])
 
+class TSVDataset(Dataset):
+    def __init__(self, path, process_fn, with_heads=True, **kwargs):
+        self.process_fn = process_fn
+        with open(path, 'r') as fin:
+            if with_heads:
+                self.heads = fin.readline().rstrip('\n').split('\t')
+            else:
+                self.heads = None
+            self.items = [line.rstrip('\n').split('\t') for line in fin]
+
+    def __len__(self):
+        return len(self.items)
+    
+    def __getitem__(self, index):
+        return self.process_fn(self.items[index])
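A minimal usage sketch for the new TSVDataset (not part of the patch); the file path and process_fn below are hypothetical:

    from SwissArmyTransformer.data_utils.datasets import TSVDataset

    def process_fn(fields):
        # `fields` is the list of tab-separated column strings for one row
        return {'text': fields[0], 'label': int(fields[1])}

    ds = TSVDataset('data/train.tsv', process_fn, with_heads=True)
    print(len(ds), ds.heads)   # number of data rows and the header columns
    sample = ds[0]             # process_fn applied to the first data row
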
diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py
index 37a24a88a004bfff63a87ebdcf3c909b29256a20..4eada99a05d6fe244fe8f6c342322e6f6c7c485f 100644
--- a/SwissArmyTransformer/generation/autoregressive_sampling.py
+++ b/SwissArmyTransformer/generation/autoregressive_sampling.py
@@ -56,7 +56,8 @@ def filling_sequence(
         max_memory_length=100000,
         log_attention_weights=None,
         get_masks_and_position_ids=get_masks_and_position_ids_default,
-        mems=None
+        mems=None,
+        **kw_args
         ):
     '''
         seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
@@ -104,7 +105,8 @@ def filling_sequence(
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
             mems=mems,
-            log_attention_weights=log_attention_weights_part
+            log_attention_weights=log_attention_weights_part,
+            **kw_args
         )
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
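A minimal sketch (not part of the patch) of what the new **kw_args pass-through enables: extra keyword arguments given to filling_sequence are forwarded to every model forward call. The call assumes the usual model/seq/batch_size/strategy arguments; the extra kwarg name is hypothetical and the model's forward() would have to accept it.

    from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence

    output = filling_sequence(model, seq, batch_size=batch_size,
                              strategy=strategy,
                              extra_embedding=extra_embedding)  # hypothetical kwarg
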
diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index fa93f3fd575bb4049c0b3bbbddd7d66e196d1e2b..7e0485274deeae3f095456f73a0d9c842a026e46 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -51,7 +51,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             from .cogview import UnifiedTokenizer
             get_tokenizer.tokenizer = UnifiedTokenizer(
                 args.img_tokenizer_path,
-                # txt_tokenizer_type=args.tokenizer_type,
+                txt_tokenizer_type='cogview',
                 device=torch.cuda.current_device()
             )
         elif args.tokenizer_type.startswith('glm'):
diff --git a/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py b/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py
index ee8c907742dea9df7a715d1f44fb3ad067ddadb0..c9e731e6fdcc7fc52c6a0cb90d1c14f3ca0306fb 100755
--- a/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/cogview/sp_tokenizer.py
@@ -22,6 +22,8 @@ python setup.py install
 
 PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
      'embed_assets', 'chinese_sentencepiece/cog-pretrain.model')
+PRETRAINED_MODEL_FILE_ICE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
+     'embed_assets', 'chinese_sentencepiece/ice.model') # cog-pretrain vocab merged with 32,000 English tokens from XLNet
 
 
 def get_pairs(word):
@@ -148,5 +150,10 @@ def get_encoder(encoder_file, bpe_file):
         )
 
 
-def from_pretrained():
-    return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
+def from_pretrained(tokenizer_type='cogview'):
+    if tokenizer_type == 'cogview_ICE':
+        return get_encoder(PRETRAINED_MODEL_FILE_ICE, "")
+    elif tokenizer_type == 'cogview':
+        return get_encoder(PRETRAINED_MODEL_FILE, "")
+    else:
+        raise ValueError('Unknown cogview tokenizer.')
\ No newline at end of file
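A minimal usage sketch (not part of the patch) for the extended from_pretrained, assuming the ice.model asset is present under embed_assets/chinese_sentencepiece:

    from SwissArmyTransformer.tokenization.cogview.sp_tokenizer import from_pretrained

    txt_tokenizer = from_pretrained('cogview_ICE')  # merged Chinese+English vocab
    # from_pretrained('cogview') keeps the original cog-pretrain model
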
diff --git a/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py b/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py
index 8a06004d332955aa4b2f1fde712b52f188bc2a59..7b17f751aa8ca1d583f4e2641f43270c8667fe13 100755
--- a/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py
+++ b/SwissArmyTransformer/tokenization/cogview/unified_tokenizer.py
@@ -20,10 +20,10 @@ from .sp_tokenizer import from_pretrained
 from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
 
 class UnifiedTokenizer(object):
-    def __init__(self, img_tokenizer_path, device):
+    def __init__(self, img_tokenizer_path, txt_tokenizer_type, device):
         self.device = device
         self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
-        self.txt_tokenizer = from_pretrained()
+        self.txt_tokenizer = from_pretrained(txt_tokenizer_type)
         self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens
         self.raw_command_tokens = [
             ('[PAD]', 0),
diff --git a/SwissArmyTransformer/training/deepspeed_training.py b/SwissArmyTransformer/training/deepspeed_training.py
index 335fc1919d2abb048127dc0bc59e5c4badc33d1e..caeb340e4ff9121196072ca286a25713ef35c480 100644
--- a/SwissArmyTransformer/training/deepspeed_training.py
+++ b/SwissArmyTransformer/training/deepspeed_training.py
@@ -100,15 +100,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
         if val_data is not None:
             start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
             val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
-    if train_data is not None:
-        train_data_iterator = iter(train_data)
-    else:
-        train_data_iterator = None
-    if val_data is not None:
-        val_data_iterator = iter(val_data)
-    else:
-        val_data_iterator = None
-        
+
     # init hook before training
     if hooks['init_function'] is not None:
         hooks['init_function'](args, model, optimizer)
@@ -122,15 +114,11 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
                     save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
                 iteration, skipped = train(model, optimizer,
                     lr_scheduler,
-                    train_data_iterator,
-                    val_data_iterator,
+                    train_data,
+                    val_data,
                     timers, args, summary_writer=summary_writer,
                     hooks=hooks
                     )
-        if args.do_valid:
-            prefix = 'the end of training for val data'
-            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
-                model, args, timers, False, hooks=hooks)
 
     # final save
     if args.save and iteration != 0: # TODO save
@@ -140,7 +128,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     if args.do_test and test_data is not None:
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, iter(test_data),
-            model, args, timers, True, hooks=hooks)
+            model, len(test_data) if args.strict_eval else args.eval_iters, args, timers, True, hooks=hooks)
 
 def get_model(args, model_cls):
     """Build the model."""
@@ -254,11 +242,21 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
 
 
 def train(model, optimizer, lr_scheduler,
-        train_data_iterator, val_data_iterator, timers, args, 
+        train_data, val_data, timers, args, 
         summary_writer=None, hooks={}):
     """Train the model."""
+    if train_data is not None:
+        train_data_iterator = iter(train_data)
+    else:
+        train_data_iterator = None
+    if val_data is not None:
+        val_data_iterator = iter(val_data)
+    else:
+        val_data_iterator = None
+        
     # Turn on training mode which enables dropout.
     model.train()
+    
 
     # Tracking loss.
     total_lm_loss = 0.0
@@ -312,9 +310,14 @@ def train(model, optimizer, lr_scheduler,
 
         # Evaluation
         if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
+            if args.strict_eval:
+                val_data_iterator = iter(val_data)
+                eval_iters = len(val_data)
+            else:
+                eval_iters = args.eval_iters
             prefix = 'iteration {}'.format(args.iteration)
             evaluate_and_print_results(
-                prefix, val_data_iterator, model, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)
+                prefix, val_data_iterator, model, eval_iters, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)
 
         if args.exit_interval and args.iteration % args.exit_interval == 0:
             torch.distributed.barrier()
@@ -406,20 +409,19 @@ def backward_step(optimizer, model, loss, args, timers):
 
     return
 
-def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
+def evaluate(data_iterator, model, eval_iters, args, timers, verbose=False, hooks={}):
     """Evaluation."""
     forward_step = hooks['forward_step']
-
     # Turn on evaluation mode which disables dropout.
     model.eval()
 
-    total_lm_loss = 0
+    total_lm_loss, metrics_total = 0, {}
     with torch.no_grad():
         iteration = 0
-        while iteration < args.eval_iters:
+        while iteration < eval_iters:
             iteration += 1
             if verbose and iteration % args.log_interval == 0:
-                print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
+                print_rank_0('Evaluating iter {}/{}'.format(iteration, eval_iters))
             # Forward evaluation.
             lm_loss, metrics = forward_step(data_iterator, model, args, timers)
             '''when contiguous memory optimizations are enabled, the buffers
@@ -429,26 +431,31 @@ def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
             if args.deepspeed and args.deepspeed_activation_checkpointing:
                 deepspeed.checkpointing.reset()
             total_lm_loss += lm_loss.data.detach().float().item()
+            for name in metrics:
+                if name not in metrics_total:
+                    metrics_total[name] = 0.0
+                metrics_total[name] += metrics[name]
 
     # Move model back to the train mode.
     model.train()
 
-    total_lm_loss /= args.eval_iters
-    return total_lm_loss
+    total_lm_loss /= eval_iters
+    metrics_avg = {key: value / eval_iters for key, value in metrics_total.items()}
+    return total_lm_loss, metrics_avg
 
-def evaluate_and_print_results(prefix, data_iterator, model,
+def evaluate_and_print_results(prefix, data_iterator, model, eval_iters,
                             args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
     """Helper function to evaluate and dump results on screen."""
     # import line_profiler
     # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
     # profile.enable()
     # torch.cuda.empty_cache()
-    lm_loss = evaluate(data_iterator, model, args, timers, verbose, hooks=hooks)
+    lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, verbose, hooks=hooks)
     # profile.disable()
     # import sys
     # profile.print_stats(sys.stdout)
     lm_ppl = math.exp(min(20, lm_loss))
-    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
+    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step, metrics)
 
     return lm_loss
 
@@ -471,10 +478,12 @@ def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time,
             summary_writer.add_scalar('Train/'+key, avg_metrics[key], step)
 
 
-def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
+def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step, avg_metrics):
     string = ' validation loss at {} | '.format(prefix)
     string += 'LM loss: {:.6E} | '.format(loss)
     string += 'LM PPL: {:.6E}'.format(ppl)
+    for key in avg_metrics:
+        string += ' | {}: {:.6E}'.format(key, avg_metrics[key])
     length = len(string) + 1
     print_rank_0('-' * 100)
     print_rank_0('-' * length)
@@ -483,6 +492,8 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
     if summary_writer is not None:
         summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
         summary_writer.add_scalar(f'Train/valid_loss', loss, step)
+        for key in avg_metrics:
+            summary_writer.add_scalar('Train/valid_'+key, avg_metrics[key], step)
         
 
 '''