Commit 69cbd3bd authored by danielgordon10, committed by Oleksandr

Merging get_uneven_splits into get_splits function (#74)

* merged uneven_splits stuff into split function
parent 57691b6a
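For callers of this API, the diff below removes get_uneven_splits entirely; the same behavior is reachable through the merged get_splits signature. A minimal migration sketch, assuming an existing Dataset instance named dataset (the instance itself is not part of this commit):

# Before this commit: uneven splits had a dedicated method.
splits = dataset.get_uneven_splits(9)

# After this commit: the same behavior via the merged function.
splits = dataset.get_splits(9, allow_uneven_splits=True)

# Fixed-size even splits; note the parameter rename from
# max_episodes_per_split to episodes_per_split.
splits = dataset.get_splits(10, episodes_per_split=10)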
@@ -128,10 +128,11 @@ class Dataset(Generic[T]):
     def get_splits(
         self,
         num_splits: int,
-        max_episodes_per_split: Optional[int] = None,
+        episodes_per_split: Optional[int] = None,
         remove_unused_episodes: bool = False,
         collate_scene_ids: bool = True,
         sort_by_episode_id: bool = False,
+        allow_uneven_splits: bool = False,
     ) -> List["Dataset"]:
         """
         Returns a list of new datasets, each with a subset of the original
@@ -139,7 +140,7 @@ class Dataset(Generic[T]):
         episodes will be duplicated.
         Args:
             num_splits: The number of splits to create.
-            max_episodes_per_split: If provided, each split will have up to
+            episodes_per_split: If provided, each split will have up to
                 this many episodes. If it is not provided, each dataset will
                 have len(original_dataset.episodes) // num_splits episodes. If
                 max_episodes_per_split is provided and is larger than this
@@ -153,24 +154,42 @@ class Dataset(Generic[T]):
                 to each other because they will be in the same scene.
             sort_by_episode_id: If true, sequences are sorted by their episode
                 ID in the returned splits.
+            allow_uneven_splits: If true, the last split can be shorter than
+                the others. This is especially useful for splitting over
+                validation/test datasets in order to make sure that all
+                episodes are copied but none are duplicated.
         Returns:
             A list of new datasets, each with their own subset of episodes.
         """
         assert (
             len(self.episodes) >= num_splits
         ), "Not enough episodes to create this many splits."
+        if episodes_per_split is not None:
+            assert not allow_uneven_splits, (
+                "You probably don't want to specify allow_uneven_splits"
+                " and episodes_per_split."
+            )
+            assert num_splits * episodes_per_split <= len(self.episodes)
         new_datasets = []
-        if max_episodes_per_split is None:
-            max_episodes_per_split = len(self.episodes) // num_splits
-        max_episodes_per_split = min(
-            max_episodes_per_split, (len(self.episodes) // num_splits)
-        )
+        if allow_uneven_splits:
+            stride = int(np.ceil(len(self.episodes) * 1.0 / num_splits))
+            split_lengths = [stride] * (num_splits - 1)
+            split_lengths.append(
+                (len(self.episodes) - stride * (num_splits - 1))
+            )
+        else:
+            if episodes_per_split is not None:
+                stride = episodes_per_split
+            else:
+                stride = len(self.episodes) // num_splits
+            split_lengths = [stride] * num_splits
+        num_episodes = sum(split_lengths)
         rand_items = np.random.choice(
-            len(self.episodes),
-            num_splits * max_episodes_per_split,
-            replace=False,
+            len(self.episodes), num_episodes, replace=False
         )
         if collate_scene_ids:
             scene_ids = {}
@@ -187,7 +206,7 @@ class Dataset(Generic[T]):
             new_dataset = copy.copy(self)  # Creates a shallow copy
             new_dataset.episodes = []
             new_datasets.append(new_dataset)
-            for ii in range(max_episodes_per_split):
+            for ii in range(split_lengths[nn]):
                 new_dataset.episodes.append(self.episodes[rand_items[ep_ind]])
                 ep_ind += 1
             if sort_by_episode_id:
@@ -196,35 +215,3 @@ class Dataset(Generic[T]):
         if remove_unused_episodes:
             self.episodes = new_episodes
         return new_datasets
-
-    def get_uneven_splits(self, num_splits):
-        """
-        Returns a list of new datasets, each with a subset of the original
-        episodes. The last dataset may have fewer episodes than the others.
-        This is especially useful for splitting over validation/test datasets
-        in order to make sure that all episodes are copied but none are
-        duplicated.
-        Args:
-            num_splits: The number of splits to create.
-        Returns:
-            A list of new datasets, each with their own subset of episodes.
-        """
-        assert (
-            len(self.episodes) >= num_splits
-        ), "Not enough episodes to create this many splits."
-        new_datasets = []
-        num_episodes = len(self.episodes)
-        stride = int(np.ceil(num_episodes * 1.0 / num_splits))
-        for ii, split in enumerate(
-            range(0, num_episodes, stride)[:num_splits]
-        ):
-            new_dataset = copy.copy(self)  # Creates a shallow copy
-            new_dataset.episodes = new_dataset.episodes[
-                split : min(split + stride, num_episodes)
-            ].copy()
-            new_datasets.append(new_dataset)
-        assert (
-            sum([len(new_dataset.episodes) for new_dataset in new_datasets])
-            == num_episodes
-        )
-        return new_datasets
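The allow_uneven_splits branch added above sizes the first num_splits - 1 splits with a ceiling stride and gives whatever remains to the last split, so the lengths always sum back to len(self.episodes). A standalone sketch of that arithmetic; the helper name uneven_split_lengths is illustrative, not part of the commit:

import numpy as np

def uneven_split_lengths(num_episodes, num_splits):
    # Ceiling stride for all but the last split, remainder for the last,
    # mirroring the allow_uneven_splits branch in get_splits.
    stride = int(np.ceil(num_episodes * 1.0 / num_splits))
    split_lengths = [stride] * (num_splits - 1)
    split_lengths.append(num_episodes - stride * (num_splits - 1))
    return split_lengths

# 10000 episodes over 9 splits: eight splits of 1112 plus one of 1104,
# which is what the updated test below asserts in aggregate.
assert uneven_split_lengths(10000, 9) == [1112] * 8 + [1104]
assert sum(uneven_split_lengths(10000, 9)) == 10000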
@@ -64,7 +64,7 @@ def test_get_splits_with_remainder():
         assert len(split.episodes) == 9


-def test_get_splits_max_episodes_specified():
+def test_get_splits_num_episodes_specified():
     dataset = _construct_dataset(100)
     splits = dataset.get_splits(10, 3, False)
     assert len(splits) == 10
@@ -73,7 +73,7 @@ def test_get_splits_max_episodes_specified():
     assert len(dataset.episodes) == 100

     dataset = _construct_dataset(100)
-    splits = dataset.get_splits(10, 11, False)
+    splits = dataset.get_splits(10, 10)
     assert len(splits) == 10
     for split in splits:
         assert len(split.episodes) == 10
@@ -86,6 +86,13 @@ def test_get_splits_max_episodes_specified():
         assert len(split.episodes) == 3
     assert len(dataset.episodes) == 30

+    dataset = _construct_dataset(100)
+    try:
+        splits = dataset.get_splits(10, 20)
+        assert False
+    except AssertionError:
+        pass
+

 def test_get_splits_collate_scenes():
     dataset = _construct_dataset(10000)
@@ -155,7 +162,17 @@ def test_get_splits_sort_by_episode_id():


 def test_get_uneven_splits():
-    dataset = _construct_dataset(100)
-    splits = dataset.get_uneven_splits(9)
-    assert len(splits) == 9
-    assert sum([len(split.episodes) for split in splits]) == 100
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(9, allow_uneven_splits=False)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == (10000 // 9) * 9
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(9, allow_uneven_splits=True)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == 10000
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, allow_uneven_splits=True)
+    assert len(splits) == 10
+    assert sum([len(split.episodes) for split in splits]) == 10000
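The negative test added earlier in this file pins down the new fail-fast guard: requesting num_splits * episodes_per_split episodes when the dataset holds fewer is now an assertion error rather than a silent cap. A usage sketch of that behavior, reusing the suite's _construct_dataset helper:

dataset = _construct_dataset(100)

# 10 splits x 20 episodes per split = 200 requested, but only 100 exist,
# so the up-front assertion in get_splits fires.
try:
    dataset.get_splits(10, episodes_per_split=20)
except AssertionError:
    print("over-requesting episodes is rejected up front")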