Commit a36b9ca3 authored by zhuoyiyang

fix valid dataset bug

parent 465c146b
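Context for the fix: RandomMappingDataset remaps every index to a pseudo-random index over the whole wrapped dataset, so applying it *before* split_ds (as the old code did) makes every split draw from the entire dataset, and valid samples leak into train. A minimal standalone sketch of that failure mode, using a toy list in place of the real dataset classes:

import numpy as np

# Toy stand-in for RandomMappingDataset: every position maps to a
# pseudo-random position over the *whole* underlying dataset.
underlying = list(range(100))
rng = np.random.RandomState(0)
mapped = [underlying[rng.randint(len(underlying))] for _ in range(len(underlying))]

# Splitting *after* the mapping: both parts sample from everything.
train, valid = mapped[:80], mapped[80:]
print(len(set(train) & set(valid)))  # > 0: the splits share underlying samples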
@@ -194,6 +194,9 @@ def add_data_args(parser):
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
+    group.add_argument('--block-size', type=int, default=10000,
+                       help="""Size of block to reduce memory in dataset""")
     return parser
 
 def add_generation_api_args(parser):
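A quick sanity check of the new flag (hypothetical snippet; it assumes only the argparse wiring shown above):

import argparse

parser = argparse.ArgumentParser()
parser = add_data_args(parser)                       # from this commit
args = parser.parse_args(['--block-size', '20000'])
assert args.block_size == 20000                      # argparse exposes --block-size as args.block_size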
@@ -64,11 +64,14 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     for p in path:
         d = create_dataset_function(p, args)
         ds.append(d)
-    ds = RandomMappingDataset(ConcatDataset(ds))
+    ds = ConcatDataset(ds)
     if should_split(split):
-        ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
+        ds = split_ds(ds, split, block_size=args.block_size)
+    else:
+        ds = RandomMappingDataset(ds)
-    # if should_split(split):
-    #     ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
-    # FIXME this will merge valid set and train set.
     return ds
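With the reordering, a held-out split is carved from the plain ConcatDataset, and RandomMappingDataset (which samples with replacement) is only used when nothing is held out. A hypothetical caller, to show the shapes involved (the paths and loader function are placeholders):

train_ds, valid_ds, test_ds = make_dataset_full(
    ['data/part1.bin', 'data/part2.bin'],   # placeholder paths
    split=[.8, .2, .0],
    args=args,                              # must carry args.block_size
    create_dataset_function=my_create_fn,   # user-supplied loader
)
assert test_ds is None                      # a 0% split yields None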
@@ -174,7 +177,7 @@ def should_split(split):
     """
     return max(split) / sum(split) != 1.
 
-def split_ds(ds, split=[.8,.2,.0]):
+def split_ds(ds, split=[.8,.2,.0], block_size=10000):
     """
     Split a dataset into subsets given proportions of how
     much to allocate per split. If a split is 0% returns None for that split.
@@ -189,18 +192,20 @@ def split_ds(ds, split=[.8,.2,.0]):
         raise Exception('Split cannot sum to 0.')
     split = np.array(split)
     split /= split_sum
-    ds_len = len(ds)
+    assert block_size <= len(ds)
     start_idx = 0
     residual_idx = 0
     rtn_ds = [None]*len(split)
+    indices = np.random.permutation(np.array(range(block_size)))
     for i, f in enumerate(split):
         if f != 0:
-            proportion = ds_len*split[i]
+            proportion = block_size*split[i]
             residual_idx += proportion % 1
             split_ = int(int(proportion) + residual_idx)
-            split_range = (start_idx, start_idx+max(split_, 1))
-            rtn_ds[i] = SplitDataset(ds, split_range)
+            rtn_ds[i] = BlockedRandomSplitDataset(ds, indices[range(start_idx, start_idx+max(split_, 1))], block_size)
+            rtn_ds[i] = EnlargedDataset(rtn_ds[i])
             start_idx += split_
             residual_idx %= 1
     return rtn_ds
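Worked example of the new arithmetic under the defaults (split=[.8,.2,.0], block_size=10000): a single permutation of range(10000) is drawn once; the first int(10000*.8) = 8000 permuted positions become the per-block train pattern and the next 2000 the valid pattern, so the two patterns are disjoint inside every block by construction:

import numpy as np

block_size = 10000
indices = np.random.permutation(block_size)

train_pattern = indices[:8000]       # positions 0..7999 of the permutation
valid_pattern = indices[8000:10000]  # positions 8000..9999
assert len(np.intersect1d(train_pattern, valid_pattern)) == 0  # disjoint in every block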
@@ -284,3 +289,41 @@ class RandomMappingDataset(data.Dataset):
         rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
         index = rng.randint(len(self.wrapped_data))
         return self.wrapped_data[index]
+class BlockedRandomSplitDataset(data.Dataset):
+    '''
+    Dataset wrapper to access a subset of another dataset.
+    Uses a block algorithm to reduce memory: the same index pattern
+    is reused inside every block of block_size samples.
+    '''
+    def __init__(self, ds, indices, block_size, **kwargs):
+        if type(indices) is not np.ndarray:
+            indices = np.array(indices)
+        indices = np.sort(indices)  # valid entries must come first for the trailing partial block
+        self.block_size = block_size
+        self.wrapped_data = ds
+        self.wrapped_data_len = len(ds)
+        self.indices = indices
+        # each full block contributes len(indices) samples; the trailing
+        # partial block only contributes pattern entries that fall inside it
+        self.len = len(indices) * (len(ds) // block_size) + np.sum(indices < (len(ds) % block_size))
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, index):
+        # block number * block_size + permuted offset within that block
+        return self.wrapped_data[(index // len(self.indices)) * self.block_size + self.indices[index % len(self.indices)]]
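To make the indexing concrete, take block_size=5, a 12-element dataset, and pattern indices=[1, 3]: blocks 0 and 1 are full and contribute underlying positions {1, 3} and {6, 8}; the trailing partial block has only 12 % 5 = 2 elements, so just the pattern entries < 2 count (here index 1, i.e. element 11), giving len = 2*2 + 1 = 5. The sort in __init__ is what keeps those valid entries at the front of the pattern. A toy check against the class above:

ds = BlockedRandomSplitDataset(list(range(12)), [1, 3], block_size=5)
assert len(ds) == 5
assert [ds[i] for i in range(len(ds))] == [1, 3, 6, 8, 11]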
+class EnlargedDataset(data.Dataset):
+    '''
+    Dataset wrapper to enlarge the dataset by cycling through it scale times.
+    '''
+    def __init__(self, ds, scale=200, **kwargs):
+        self.wrapped_data = ds
+        self.wrapped_data_len = len(ds)
+        self.scale = scale
+
+    def __len__(self):
+        return self.wrapped_data_len * self.scale
+
+    def __getitem__(self, index):
+        return self.wrapped_data[index % self.wrapped_data_len]
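EnlargedDataset reports scale times the wrapped length and wraps indices around, so iteration simply cycles through the underlying split; presumably (my reading, not stated in the commit) this keeps even a tiny valid split large enough for a fixed number of iterations per epoch. Toy demonstration:

ds = EnlargedDataset(list(range(3)), scale=2)
assert len(ds) == 6
assert [ds[i] for i in range(len(ds))] == [0, 1, 2, 0, 1, 2]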