diff --git a/arguments.py b/arguments.py
index bccbd28237c22e0f8f0777824b03bc1cf3d47712..bdb4de4a2bb309101a282b2d8d9f84357cfab991 100755
--- a/arguments.py
+++ b/arguments.py
@@ -194,6 +194,9 @@ def add_data_args(parser):
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
 
+    group.add_argument('--block-size', type=int, default=10000,
+                       help="""Block size used when splitting datasets, to reduce memory usage""")
+
     return parser
 
 def add_generation_api_args(parser):
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index 95c3bd1be2c7fa1629d2d1f9fffff99b58d6aced..693c224e9fd2010a51ee62672afcdcfadddb4dcb 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -64,11 +64,12 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     for p in path:
         d = create_dataset_function(p, args)
         ds.append(d)
-
-    ds = RandomMappingDataset(ConcatDataset(ds))
-
+    ds = ConcatDataset(ds)
     if should_split(split):
-        ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
+        # Large dataset: cannot shuffle globally; shuffle within blocks.
+        ds = split_ds(ds, split, block_size=args.block_size)
+    else:
+        ds = RandomMappingDataset(ds)
     # FIXME this will merge valid set and train set.
     return ds
 
@@ -174,7 +175,7 @@ def should_split(split):
     """
     return max(split) / sum(split) != 1.
 
-def split_ds(ds, split=[.8,.2,.0]):
+def split_ds(ds, split=[.8,.2,.0], block_size=10000):
     """
     Split a dataset into subsets given proportions of how much to
     allocate per split. If a split is 0% returns None for that split.
@@ -189,18 +190,22 @@
         raise Exception('Split cannot sum to 0.')
     split = np.array(split)
     split /= split_sum
-    ds_len = len(ds)
+
+    assert block_size <= len(ds)
     start_idx = 0
     residual_idx = 0
     rtn_ds = [None]*len(split)
+    # One shared permutation of in-block positions; contiguous slices
+    # of it become the per-split index sets.
+    indices = np.random.permutation(block_size)
     for i, f in enumerate(split):
         if f != 0:
-            proportion = ds_len*split[i]
+            proportion = block_size*split[i]
             residual_idx += proportion % 1
             split_ = int(int(proportion) + residual_idx)
-            split_range = (start_idx, start_idx+max(split_, 1))
-            rtn_ds[i] = SplitDataset(ds, split_range)
+            rtn_ds[i] = BlockedRandomSplitDataset(ds, indices[start_idx:start_idx + max(split_, 1)], block_size)
+            rtn_ds[i] = RandomMappingDataset(rtn_ds[i])
             start_idx += split_
             residual_idx %= 1
     return rtn_ds
 
@@ -284,3 +289,59 @@ class RandomMappingDataset(data.Dataset):
         rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
         index = rng.randint(len(self.wrapped_data))
         return self.wrapped_data[index]
+
+class BlockedRandomSplitDataset(data.Dataset):
+    '''
+    Dataset wrapper to access a subset of another dataset.
+    Uses a blocked shuffle to reduce memory: a single permutation of
+    block_size in-block positions is reused across every block of the
+    wrapped dataset, instead of permuting all indices at once.
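+
+    Example (illustrative, assumed values): with a wrapped dataset of
+    length 10, block_size=4 and indices=[1, 3], the subset exposes
+    underlying indices [1, 3, 5, 7, 9]; the same in-block positions
+    repeat in every block, including the partial last one.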
+    '''
+    def __init__(self, ds, indices, block_size, **kwargs):
+        if type(indices) is not np.ndarray:
+            indices = np.array(indices)
+        indices = np.sort(indices)
+        self.block_size = block_size
+        self.wrapped_data = ds
+        self.wrapped_data_len = len(ds)
+        self.indices = indices
+        # Every full block contributes len(indices) samples; the partial
+        # last block contributes only the indices that fall inside it.
+        self.len = (len(indices) * (len(ds) // block_size)
+                    + np.sum(indices < (len(ds) % block_size)))
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+        # Map the flat index to a block number and an in-block position,
+        # then through the shared permutation into the wrapped dataset.
+        return self.wrapped_data[(index // len(self.indices)) * self.block_size
+                                 + self.indices[index % len(self.indices)]]
+
+
+class EnlargedDataset(data.Dataset):
+    '''
+    Dataset wrapper that virtually repeats the wrapped dataset `scale` times.
+    '''
+    def __init__(self, ds, scale=200, **kwargs):
+        self.wrapped_data = ds
+        self.wrapped_data_len = len(ds)
+        self.scale = scale
+
+    def __len__(self):
+        return self.wrapped_data_len * self.scale
+
+    def __getitem__(self, index):
+        return self.wrapped_data[index % self.wrapped_data_len]
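+
+# Usage sketch for EnlargedDataset (illustrative; `finetune_ds` is a
+# placeholder for any map-style dataset):
+#     repeated = EnlargedDataset(finetune_ds, scale=200)
+#     assert len(repeated) == 200 * len(finetune_ds)
+#     repeated[i] returns finetune_ds[i % len(finetune_ds)].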