From 69cbd3bdc81f43a17bb785247b34ee5c5bfa24e7 Mon Sep 17 00:00:00 2001 From: danielgordon10 <danielgordon10@gmail.com> Date: Wed, 24 Apr 2019 19:22:26 -0700 Subject: [PATCH] Merging get_uneven_splits into get_splits function (#74) * merged uneven_splits stuff into split function --- habitat/core/dataset.py | 75 +++++++++++++++++------------------------ test/test_dataset.py | 27 ++++++++++++--- 2 files changed, 53 insertions(+), 49 deletions(-) diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py index 7e02927e2..a7841dc07 100644 --- a/habitat/core/dataset.py +++ b/habitat/core/dataset.py @@ -128,10 +128,11 @@ class Dataset(Generic[T]): def get_splits( self, num_splits: int, - max_episodes_per_split: Optional[int] = None, + episodes_per_split: Optional[int] = None, remove_unused_episodes: bool = False, collate_scene_ids: bool = True, sort_by_episode_id: bool = False, + allow_uneven_splits: bool = False, ) -> List["Dataset"]: """ Returns a list of new datasets, each with a subset of the original @@ -139,7 +140,7 @@ class Dataset(Generic[T]): episodes will be duplicated. Args: num_splits: The number of splits to create. - max_episodes_per_split: If provided, each split will have up to + episodes_per_split: If provided, each split will have up to this many episodes. If it is not provided, each dataset will have len(original_dataset.episodes) // num_splits episodes. If max_episodes_per_split is provided and is larger than this @@ -153,24 +154,42 @@ class Dataset(Generic[T]): to each other because they will be in the same scene. sort_by_episode_id: If true, sequences are sorted by their episode ID in the returned splits. + allow_uneven_splits: If true, the last split can be shorter than + the others. This is especially useful for splitting over + validation/test datasets in order to make sure that all + episodes are copied but none are duplicated. Returns: A list of new datasets, each with their own subset of episodes. """ - assert ( len(self.episodes) >= num_splits ), "Not enough episodes to create this many splits." + if episodes_per_split is not None: + assert not allow_uneven_splits, ( + "You probably don't want to specify allow_uneven_splits" + " and episodes_per_split." + ) + assert num_splits * episodes_per_split <= len(self.episodes) new_datasets = [] - if max_episodes_per_split is None: - max_episodes_per_split = len(self.episodes) // num_splits - max_episodes_per_split = min( - max_episodes_per_split, (len(self.episodes) // num_splits) - ) + + if allow_uneven_splits: + stride = int(np.ceil(len(self.episodes) * 1.0 / num_splits)) + split_lengths = [stride] * (num_splits - 1) + split_lengths.append( + (len(self.episodes) - stride * (num_splits - 1)) + ) + else: + if episodes_per_split is not None: + stride = episodes_per_split + else: + stride = len(self.episodes) // num_splits + split_lengths = [stride] * num_splits + + num_episodes = sum(split_lengths) + rand_items = np.random.choice( - len(self.episodes), - num_splits * max_episodes_per_split, - replace=False, + len(self.episodes), num_episodes, replace=False ) if collate_scene_ids: scene_ids = {} @@ -187,7 +206,7 @@ class Dataset(Generic[T]): new_dataset = copy.copy(self) # Creates a shallow copy new_dataset.episodes = [] new_datasets.append(new_dataset) - for ii in range(max_episodes_per_split): + for ii in range(split_lengths[nn]): new_dataset.episodes.append(self.episodes[rand_items[ep_ind]]) ep_ind += 1 if sort_by_episode_id: @@ -196,35 +215,3 @@ class Dataset(Generic[T]): if remove_unused_episodes: self.episodes = new_episodes return new_datasets - - def get_uneven_splits(self, num_splits): - """ - Returns a list of new datasets, each with a subset of the original - episodes. The last dataset may have fewer episodes than the others. - This is especially useful for splitting over validation/test datasets - in order to make sure that all episodes are copied but none are - duplicated. - Args: - num_splits: The number of splits to create. - Returns: - A list of new datasets, each with their own subset of episodes. - """ - assert ( - len(self.episodes) >= num_splits - ), "Not enough episodes to create this many splits." - new_datasets = [] - num_episodes = len(self.episodes) - stride = int(np.ceil(num_episodes * 1.0 / num_splits)) - for ii, split in enumerate( - range(0, num_episodes, stride)[:num_splits] - ): - new_dataset = copy.copy(self) # Creates a shallow copy - new_dataset.episodes = new_dataset.episodes[ - split : min(split + stride, num_episodes) - ].copy() - new_datasets.append(new_dataset) - assert ( - sum([len(new_dataset.episodes) for new_dataset in new_datasets]) - == num_episodes - ) - return new_datasets diff --git a/test/test_dataset.py b/test/test_dataset.py index 69b2bf69f..975f6ada6 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -64,7 +64,7 @@ def test_get_splits_with_remainder(): assert len(split.episodes) == 9 -def test_get_splits_max_episodes_specified(): +def test_get_splits_num_episodes_specified(): dataset = _construct_dataset(100) splits = dataset.get_splits(10, 3, False) assert len(splits) == 10 @@ -73,7 +73,7 @@ def test_get_splits_max_episodes_specified(): assert len(dataset.episodes) == 100 dataset = _construct_dataset(100) - splits = dataset.get_splits(10, 11, False) + splits = dataset.get_splits(10, 10) assert len(splits) == 10 for split in splits: assert len(split.episodes) == 10 @@ -86,6 +86,13 @@ def test_get_splits_max_episodes_specified(): assert len(split.episodes) == 3 assert len(dataset.episodes) == 30 + dataset = _construct_dataset(100) + try: + splits = dataset.get_splits(10, 20) + assert False + except AssertionError: + pass + def test_get_splits_collate_scenes(): dataset = _construct_dataset(10000) @@ -155,7 +162,17 @@ def test_get_splits_sort_by_episode_id(): def test_get_uneven_splits(): - dataset = _construct_dataset(100) - splits = dataset.get_uneven_splits(9) + dataset = _construct_dataset(10000) + splits = dataset.get_splits(9, allow_uneven_splits=False) + assert len(splits) == 9 + assert sum([len(split.episodes) for split in splits]) == (10000 // 9) * 9 + + dataset = _construct_dataset(10000) + splits = dataset.get_splits(9, allow_uneven_splits=True) assert len(splits) == 9 - assert sum([len(split.episodes) for split in splits]) == 100 + assert sum([len(split.episodes) for split in splits]) == 10000 + + dataset = _construct_dataset(10000) + splits = dataset.get_splits(10, allow_uneven_splits=True) + assert len(splits) == 10 + assert sum([len(split.episodes) for split in splits]) == 10000 -- GitLab