diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py index 5023ae629c98a9da4608b0705e25eb4501fc7249..4d44335db1d208289d3abbd01d59721db6acdc52 100755 --- a/data_utils/configure_data.py +++ b/data_utils/configure_data.py @@ -205,7 +205,7 @@ def split_ds(ds, split=[.8,.2,.0], block_size = 10000): residual_idx += proportion % 1 split_ = int(int(proportion) + residual_idx) rtn_ds[i] = BlockedRandomSplitDataset(ds, indices[range(start_idx, start_idx+max(split_, 1))], block_size) - rtn_ds[i] = EnlargedDataset(rtn_ds[i]) + rtn_ds[i] = RandomMappingDataset(rtn_ds[i]) start_idx += split_ residual_idx %= 1 return rtn_ds @@ -308,7 +308,7 @@ class BlockedRandomSplitDataset(data.Dataset): return self.len def __getitem__(self, index): - return self.wrapped_data[(index // len(self.indices)) * self.block_size + self.indices[index % self.block_size]] + return self.wrapped_data[(index // len(self.indices)) * self.block_size + self.indices[index % len(self.indices)]] class EnlargedDataset(data.Dataset):