diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py index 7e02927e2d92217b0b5e67b717eb2c6d9906471c..a7841dc07dddb3af2b23566e412853278ee79b9f 100644 --- a/habitat/core/dataset.py +++ b/habitat/core/dataset.py @@ -128,10 +128,11 @@ class Dataset(Generic[T]): def get_splits( self, num_splits: int, - max_episodes_per_split: Optional[int] = None, + episodes_per_split: Optional[int] = None, remove_unused_episodes: bool = False, collate_scene_ids: bool = True, sort_by_episode_id: bool = False, + allow_uneven_splits: bool = False, ) -> List["Dataset"]: """ Returns a list of new datasets, each with a subset of the original @@ -139,7 +140,7 @@ class Dataset(Generic[T]): episodes will be duplicated. Args: num_splits: The number of splits to create. - max_episodes_per_split: If provided, each split will have up to + episodes_per_split: If provided, each split will have up to this many episodes. If it is not provided, each dataset will have len(original_dataset.episodes) // num_splits episodes. - If max_episodes_per_split is provided and is larger than this + If episodes_per_split is provided and is larger than this @@ -153,24 +154,42 @@ class Dataset(Generic[T]): to each other because they will be in the same scene. sort_by_episode_id: If true, sequences are sorted by their episode ID in the returned splits. + allow_uneven_splits: If true, the last split can be shorter than + the others. This is especially useful for splitting over + validation/test datasets in order to make sure that all + episodes are copied but none are duplicated. Returns: A list of new datasets, each with their own subset of episodes. """ - assert ( len(self.episodes) >= num_splits ), "Not enough episodes to create this many splits." + if episodes_per_split is not None: + assert not allow_uneven_splits, ( + "You probably don't want to specify allow_uneven_splits" + " and episodes_per_split." 
+ ) + assert num_splits * episodes_per_split <= len(self.episodes) new_datasets = [] - if max_episodes_per_split is None: - max_episodes_per_split = len(self.episodes) // num_splits - max_episodes_per_split = min( - max_episodes_per_split, (len(self.episodes) // num_splits) - ) + + if allow_uneven_splits: + stride = int(np.ceil(len(self.episodes) * 1.0 / num_splits)) + split_lengths = [stride] * (num_splits - 1) + split_lengths.append( + (len(self.episodes) - stride * (num_splits - 1)) + ) + else: + if episodes_per_split is not None: + stride = episodes_per_split + else: + stride = len(self.episodes) // num_splits + split_lengths = [stride] * num_splits + + num_episodes = sum(split_lengths) + rand_items = np.random.choice( - len(self.episodes), - num_splits * max_episodes_per_split, - replace=False, + len(self.episodes), num_episodes, replace=False ) if collate_scene_ids: scene_ids = {} @@ -187,7 +206,7 @@ class Dataset(Generic[T]): new_dataset = copy.copy(self) # Creates a shallow copy new_dataset.episodes = [] new_datasets.append(new_dataset) - for ii in range(max_episodes_per_split): + for ii in range(split_lengths[nn]): new_dataset.episodes.append(self.episodes[rand_items[ep_ind]]) ep_ind += 1 if sort_by_episode_id: @@ -196,35 +215,3 @@ class Dataset(Generic[T]): if remove_unused_episodes: self.episodes = new_episodes return new_datasets - - def get_uneven_splits(self, num_splits): - """ - Returns a list of new datasets, each with a subset of the original - episodes. The last dataset may have fewer episodes than the others. - This is especially useful for splitting over validation/test datasets - in order to make sure that all episodes are copied but none are - duplicated. - Args: - num_splits: The number of splits to create. - Returns: - A list of new datasets, each with their own subset of episodes. - """ - assert ( - len(self.episodes) >= num_splits - ), "Not enough episodes to create this many splits." 
- new_datasets = [] - num_episodes = len(self.episodes) - stride = int(np.ceil(num_episodes * 1.0 / num_splits)) - for ii, split in enumerate( - range(0, num_episodes, stride)[:num_splits] - ): - new_dataset = copy.copy(self) # Creates a shallow copy - new_dataset.episodes = new_dataset.episodes[ - split : min(split + stride, num_episodes) - ].copy() - new_datasets.append(new_dataset) - assert ( - sum([len(new_dataset.episodes) for new_dataset in new_datasets]) - == num_episodes - ) - return new_datasets diff --git a/test/test_dataset.py b/test/test_dataset.py index 69b2bf69f83ae38d1460b9934753f2481767eb9d..975f6ada66300e81ac41ed6b5b5cbe315d219339 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -64,7 +64,7 @@ def test_get_splits_with_remainder(): assert len(split.episodes) == 9 -def test_get_splits_max_episodes_specified(): +def test_get_splits_num_episodes_specified(): dataset = _construct_dataset(100) splits = dataset.get_splits(10, 3, False) assert len(splits) == 10 @@ -73,7 +73,7 @@ def test_get_splits_max_episodes_specified(): assert len(dataset.episodes) == 100 dataset = _construct_dataset(100) - splits = dataset.get_splits(10, 11, False) + splits = dataset.get_splits(10, 10) assert len(splits) == 10 for split in splits: assert len(split.episodes) == 10 @@ -86,6 +86,14 @@ assert len(split.episodes) == 3 assert len(dataset.episodes) == 30 + dataset = _construct_dataset(100) + try: + splits = dataset.get_splits(10, 20) + raised = False + except AssertionError: + raised = True + assert raised + def test_get_splits_collate_scenes(): dataset = _construct_dataset(10000) @@ -155,7 +162,17 @@ def test_get_splits_sort_by_episode_id(): def test_get_uneven_splits(): - dataset = _construct_dataset(100) - splits = dataset.get_uneven_splits(9) + dataset = _construct_dataset(10000) + splits = dataset.get_splits(9, allow_uneven_splits=False) + assert len(splits) == 9 + assert sum([len(split.episodes) for split in splits]) == (10000 // 9) * 
9 + + dataset = _construct_dataset(10000) + splits = dataset.get_splits(9, allow_uneven_splits=True) assert len(splits) == 9 - assert sum([len(split.episodes) for split in splits]) == 100 + assert sum([len(split.episodes) for split in splits]) == 10000 + + dataset = _construct_dataset(10000) + splits = dataset.get_splits(10, allow_uneven_splits=True) + assert len(splits) == 10 + assert sum([len(split.episodes) for split in splits]) == 10000