Commit 69cbd3bd authored by danielgordon10, committed by Oleksandr

Merging get_uneven_splits into get_splits function (#74)

* merged uneven_splits stuff into split function
parent 57691b6a
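
In effect, the standalone get_uneven_splits method is removed and its behavior is folded into get_splits behind a new allow_uneven_splits flag; the max_episodes_per_split argument is also renamed to episodes_per_split. As a rough before/after sketch of the call sites (dataset here stands in for any Dataset instance; the call forms follow the tests updated below):

    # Before this commit: two separate entry points.
    even_splits = dataset.get_splits(9)           # each split gets len(episodes) // 9 episodes
    uneven_splits = dataset.get_uneven_splits(9)  # covers every episode; last split may be shorter

    # After this commit: a single entry point.
    even_splits = dataset.get_splits(9)                              # default behavior unchanged
    uneven_splits = dataset.get_splits(9, allow_uneven_splits=True)  # replaces get_uneven_splits
    sized_splits = dataset.get_splits(10, episodes_per_split=10)     # renamed keyword argument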
@@ -128,10 +128,11 @@ class Dataset(Generic[T]):
     def get_splits(
         self,
         num_splits: int,
-        max_episodes_per_split: Optional[int] = None,
+        episodes_per_split: Optional[int] = None,
         remove_unused_episodes: bool = False,
         collate_scene_ids: bool = True,
         sort_by_episode_id: bool = False,
+        allow_uneven_splits: bool = False,
     ) -> List["Dataset"]:
         """
         Returns a list of new datasets, each with a subset of the original
@@ -139,7 +140,7 @@ class Dataset(Generic[T]):
         episodes will be duplicated.
         Args:
             num_splits: The number of splits to create.
-            max_episodes_per_split: If provided, each split will have up to
+            episodes_per_split: If provided, each split will have up to
                 this many episodes. If it is not provided, each dataset will
                 have len(original_dataset.episodes) // num_splits episodes. If
                 max_episodes_per_split is provided and is larger than this
@@ -153,24 +154,42 @@ class Dataset(Generic[T]):
                 to each other because they will be in the same scene.
             sort_by_episode_id: If true, sequences are sorted by their episode
                 ID in the returned splits.
+            allow_uneven_splits: If true, the last split can be shorter than
+                the others. This is especially useful for splitting over
+                validation/test datasets in order to make sure that all
+                episodes are copied but none are duplicated.
         Returns:
             A list of new datasets, each with their own subset of episodes.
         """
         assert (
             len(self.episodes) >= num_splits
         ), "Not enough episodes to create this many splits."
+        if episodes_per_split is not None:
+            assert not allow_uneven_splits, (
+                "You probably don't want to specify allow_uneven_splits"
+                " and episodes_per_split."
+            )
+            assert num_splits * episodes_per_split <= len(self.episodes)
         new_datasets = []
-        if max_episodes_per_split is None:
-            max_episodes_per_split = len(self.episodes) // num_splits
-        max_episodes_per_split = min(
-            max_episodes_per_split, (len(self.episodes) // num_splits)
-        )
+        if allow_uneven_splits:
+            stride = int(np.ceil(len(self.episodes) * 1.0 / num_splits))
+            split_lengths = [stride] * (num_splits - 1)
+            split_lengths.append(
+                (len(self.episodes) - stride * (num_splits - 1))
+            )
+        else:
+            if episodes_per_split is not None:
+                stride = episodes_per_split
+            else:
+                stride = len(self.episodes) // num_splits
+            split_lengths = [stride] * num_splits
+        num_episodes = sum(split_lengths)
         rand_items = np.random.choice(
-            len(self.episodes),
-            num_splits * max_episodes_per_split,
-            replace=False,
+            len(self.episodes), num_episodes, replace=False
         )
         if collate_scene_ids:
             scene_ids = {}
@@ -187,7 +206,7 @@ class Dataset(Generic[T]):
             new_dataset = copy.copy(self)  # Creates a shallow copy
             new_dataset.episodes = []
             new_datasets.append(new_dataset)
-            for ii in range(max_episodes_per_split):
+            for ii in range(split_lengths[nn]):
                 new_dataset.episodes.append(self.episodes[rand_items[ep_ind]])
                 ep_ind += 1
         if sort_by_episode_id:
@@ -196,35 +215,3 @@ class Dataset(Generic[T]):
         if remove_unused_episodes:
             self.episodes = new_episodes
         return new_datasets
-
-    def get_uneven_splits(self, num_splits):
-        """
-        Returns a list of new datasets, each with a subset of the original
-        episodes. The last dataset may have fewer episodes than the others.
-        This is especially useful for splitting over validation/test datasets
-        in order to make sure that all episodes are copied but none are
-        duplicated.
-        Args:
-            num_splits: The number of splits to create.
-        Returns:
-            A list of new datasets, each with their own subset of episodes.
-        """
-        assert (
-            len(self.episodes) >= num_splits
-        ), "Not enough episodes to create this many splits."
-        new_datasets = []
-        num_episodes = len(self.episodes)
-        stride = int(np.ceil(num_episodes * 1.0 / num_splits))
-        for ii, split in enumerate(
-            range(0, num_episodes, stride)[:num_splits]
-        ):
-            new_dataset = copy.copy(self)  # Creates a shallow copy
-            new_dataset.episodes = new_dataset.episodes[
-                split : min(split + stride, num_episodes)
-            ].copy()
-            new_datasets.append(new_dataset)
-        assert (
-            sum([len(new_dataset.episodes) for new_dataset in new_datasets])
-            == num_episodes
-        )
-        return new_datasets
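
The uneven-split arithmetic above is worth spelling out: the stride is the ceiling of len(self.episodes) / num_splits, the first num_splits - 1 splits each take stride episodes, and the last split takes whatever remains. A self-contained sketch of that computation, using the 10000-episode / 9-split sizes from the updated tests (np is numpy, as in the diff):

    import numpy as np

    num_episodes = 10000  # dataset size used in the updated tests
    num_splits = 9
    stride = int(np.ceil(num_episodes * 1.0 / num_splits))  # ceil(10000 / 9) = 1112
    split_lengths = [stride] * (num_splits - 1)  # eight splits of 1112 episodes each
    split_lengths.append(num_episodes - stride * (num_splits - 1))  # last split: 10000 - 8896 = 1104
    assert sum(split_lengths) == num_episodes  # every episode lands in exactly one split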
@@ -64,7 +64,7 @@ def test_get_splits_with_remainder():
         assert len(split.episodes) == 9


-def test_get_splits_max_episodes_specified():
+def test_get_splits_num_episodes_specified():
     dataset = _construct_dataset(100)
     splits = dataset.get_splits(10, 3, False)
     assert len(splits) == 10
@@ -73,7 +73,7 @@ def test_get_splits_max_episodes_specified():
     assert len(dataset.episodes) == 100

     dataset = _construct_dataset(100)
-    splits = dataset.get_splits(10, 11, False)
+    splits = dataset.get_splits(10, 10)
     assert len(splits) == 10
     for split in splits:
         assert len(split.episodes) == 10
@@ -86,6 +86,13 @@ def test_get_splits_max_episodes_specified():
         assert len(split.episodes) == 3
     assert len(dataset.episodes) == 30

+    dataset = _construct_dataset(100)
+    try:
+        splits = dataset.get_splits(10, 20)
+        assert False
+    except AssertionError:
+        pass
+

 def test_get_splits_collate_scenes():
     dataset = _construct_dataset(10000)
@@ -155,7 +162,17 @@ def test_get_splits_sort_by_episode_id():
 def test_get_uneven_splits():
-    dataset = _construct_dataset(100)
-    splits = dataset.get_uneven_splits(9)
-    assert len(splits) == 9
-    assert sum([len(split.episodes) for split in splits]) == 100
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(9, allow_uneven_splits=False)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == (10000 // 9) * 9
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(9, allow_uneven_splits=True)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == 10000
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, allow_uneven_splits=True)
+    assert len(splits) == 10
+    assert sum([len(split.episodes) for split in splits]) == 10000
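
As a sanity check on the first assertion above: with allow_uneven_splits=False each of the 9 splits receives 10000 // 9 = 1111 episodes, so the splits together cover only 1111 * 9 = 9999 of the 10000 episodes and one episode goes unused, which is exactly the gap the allow_uneven_splits flag closes.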