diff --git a/arguments.py b/arguments.py
index bccbd28237c22e0f8f0777824b03bc1cf3d47712..bdb4de4a2bb309101a282b2d8d9f84357cfab991 100755
--- a/arguments.py
+++ b/arguments.py
@@ -194,6 +194,9 @@ def add_data_args(parser):
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
 
+    group.add_argument('--block-size', type=int, default=10000,
+                       help="""Block size used when splitting datasets, to reduce memory usage""")
+
     return parser
 
 def add_generation_api_args(parser):
diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py
index 95c3bd1be2c7fa1629d2d1f9fffff99b58d6aced..693c224e9fd2010a51ee62672afcdcfadddb4dcb 100755
--- a/data_utils/configure_data.py
+++ b/data_utils/configure_data.py
@@ -64,11 +64,12 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     for p in path:
         d = create_dataset_function(p, args)
         ds.append(d)
-
-    ds = RandomMappingDataset(ConcatDataset(ds))
-
+    ds = ConcatDataset(ds)
     if should_split(split):
-        ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
+        # Large dataset: cannot shuffle globally; shuffle within blocks.
+        ds = split_ds(ds, split, block_size=args.block_size)
+    else:
+        ds = RandomMappingDataset(ds)
     # FIXME this will merge valid set and train set.
     return ds
 
@@ -174,7 +175,7 @@ def should_split(split):
     """
     return max(split) / sum(split) != 1.
 
-def split_ds(ds, split=[.8,.2,.0]):
+def split_ds(ds, split=[.8,.2,.0], block_size=10000):
     """
     Split a dataset into subsets given proportions of how much to
     allocate per split. If a split is 0% returns None for that split.
@@ -189,18 +190,22 @@
         raise Exception('Split cannot sum to 0.')
     split = np.array(split)
     split /= split_sum
-    ds_len = len(ds)
+
+    assert block_size <= len(ds)
     start_idx = 0
     residual_idx = 0
     rtn_ds = [None]*len(split)
+    # One shared permutation of in-block positions; contiguous slices
+    # of it become the per-split index sets.
+    indices = np.random.permutation(block_size)
     for i, f in enumerate(split):
         if f != 0:
-            proportion = ds_len*split[i]
+            proportion = block_size*split[i]
             residual_idx += proportion % 1
             split_ = int(int(proportion) + residual_idx)
-            split_range = (start_idx, start_idx+max(split_, 1))
-            rtn_ds[i] = SplitDataset(ds, split_range)
+            rtn_ds[i] = BlockedRandomSplitDataset(ds, indices[start_idx:start_idx + max(split_, 1)], block_size)
+            rtn_ds[i] = RandomMappingDataset(rtn_ds[i])
             start_idx += split_
             residual_idx %= 1
     return rtn_ds
 
@@ -284,3 +289,59 @@ class RandomMappingDataset(data.Dataset):
         rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
         index = rng.randint(len(self.wrapped_data))
         return self.wrapped_data[index]
+
+class BlockedRandomSplitDataset(data.Dataset):
+    '''
+    Dataset wrapper to access a subset of another dataset.
+    Uses a blocked shuffle to reduce memory: a single permutation of
+    block_size in-block positions is reused across every block of the
+    wrapped dataset, instead of permuting all indices at once.
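+
+    Example (illustrative, assumed values): with a wrapped dataset of
+    length 10, block_size=4 and indices=[1, 3], the subset exposes
+    underlying indices [1, 3, 5, 7, 9]; the same in-block positions
+    repeat in every block, including the partial last one.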
+    '''
+    def __init__(self, ds, indices, block_size, **kwargs):
+        if type(indices) is not np.ndarray:
+            indices = np.array(indices)
+        indices = np.sort(indices)
+        self.block_size = block_size
+        self.wrapped_data = ds
+        self.wrapped_data_len = len(ds)
+        self.indices = indices
+        # Every full block contributes len(indices) samples; the partial
+        # last block contributes only the indices that fall inside it.
+        self.len = (len(indices) * (len(ds) // block_size)
+                    + np.sum(indices < (len(ds) % block_size)))
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+        # Map the flat index to a block number and an in-block position,
+        # then through the shared permutation into the wrapped dataset.
+        return self.wrapped_data[(index // len(self.indices)) * self.block_size
+                                 + self.indices[index % len(self.indices)]]
+
+
+class EnlargedDataset(data.Dataset):
+    '''
+    Dataset wrapper that virtually repeats the wrapped dataset `scale` times.
+    '''
+    def __init__(self, ds, scale=200, **kwargs):
+        self.wrapped_data = ds
+        self.wrapped_data_len = len(ds)
+        self.scale = scale
+
+    def __len__(self):
+        return self.wrapped_data_len * self.scale
+
+    def __getitem__(self, index):
+        return self.wrapped_data[index % self.wrapped_data_len]
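+
+# Usage sketch for EnlargedDataset (illustrative; `finetune_ds` is a
+# placeholder for any map-style dataset):
+#     repeated = EnlargedDataset(finetune_ds, scale=200)
+#     assert len(repeated) == 200 * len(finetune_ds)
+#     repeated[i] returns finetune_ds[i % len(finetune_ds)].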