diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py
index 4bbb5384cb69add17e918a82ada25a5fe173accc..7e02927e2d92217b0b5e67b717eb2c6d9906471c 100644
--- a/habitat/core/dataset.py
+++ b/habitat/core/dataset.py
@@ -4,8 +4,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import json
-from typing import Dict, List, Type, TypeVar, Generic, Optional
+import random
+from typing import Dict, List, Type, TypeVar, Generic, Optional, Callable
+
+import numpy as np
 
 
 class Episode:
@@ -67,7 +71,7 @@ class Dataset(Generic[T]):
         Returns:
             unique scene ids present in the dataset
         """
-        return list({episode.scene_id for episode in self.episodes})
+        return sorted(list({episode.scene_id for episode in self.episodes}))
 
     def get_scene_episodes(self, scene_id: str) -> List[T]:
         """
@@ -101,3 +105,126 @@ class Dataset(Generic[T]):
 
     def from_json(self, json_str: str) -> None:
         raise NotImplementedError
+
+    def filter_episodes(
+        self, filter_fn: Callable[[Episode], bool]
+    ) -> "Dataset":
+        """
+        Returns a new dataset with only the filtered episodes from the original
+        dataset.
+        Args:
+            filter_fn: Function used to filter the episodes.
+        Returns:
+            The new dataset.
+        """
+        new_episodes = []
+        for episode in self.episodes:
+            if filter_fn(episode):
+                new_episodes.append(episode)
+        new_dataset = copy.copy(self)
+        new_dataset.episodes = new_episodes
+        return new_dataset
+
+    def get_splits(
+        self,
+        num_splits: int,
+        max_episodes_per_split: Optional[int] = None,
+        remove_unused_episodes: bool = False,
+        collate_scene_ids: bool = True,
+        sort_by_episode_id: bool = False,
+    ) -> List["Dataset"]:
+        """
+        Returns a list of new datasets, each with a subset of the original
+        episodes. All splits will have the same number of episodes, but no
+        episodes will be duplicated.
+        Args:
+            num_splits: The number of splits to create.
+            max_episodes_per_split: If provided, each split will have up to
+                this many episodes. If it is not provided, each dataset will
+                have len(original_dataset.episodes) // num_splits episodes. If
+                max_episodes_per_split is provided and is larger than this
+                value, it will be capped to this value.
+            remove_unused_episodes: Once the splits are created, the extra
+                episodes will be destroyed from the original dataset. This
+                saves memory for large datasets.
+            collate_scene_ids: If true, episodes with the same scene id are
+                next to each other. This saves on overhead of switching between
+                scenes, but means multiple sequential episodes will be related
+                to each other because they will be in the same scene.
+            sort_by_episode_id: If true, sequences are sorted by their episode
+                ID in the returned splits.
+        Returns:
+            A list of new datasets, each with their own subset of episodes.
+        """
+
+        assert (
+            len(self.episodes) >= num_splits
+        ), "Not enough episodes to create this many splits."
+
+        new_datasets = []
+        if max_episodes_per_split is None:
+            max_episodes_per_split = len(self.episodes) // num_splits
+        max_episodes_per_split = min(
+            max_episodes_per_split, (len(self.episodes) // num_splits)
+        )
+        rand_items = np.random.choice(
+            len(self.episodes),
+            num_splits * max_episodes_per_split,
+            replace=False,
+        )
+        if collate_scene_ids:
+            scene_ids = {}
+            for rand_ind in rand_items:
+                scene = self.episodes[rand_ind].scene_id
+                if scene not in scene_ids:
+                    scene_ids[scene] = []
+                scene_ids[scene].append(rand_ind)
+            rand_items = []
+            list(map(rand_items.extend, scene_ids.values()))
+        ep_ind = 0
+        new_episodes = []
+        for nn in range(num_splits):
+            new_dataset = copy.copy(self)  # Creates a shallow copy
+            new_dataset.episodes = []
+            new_datasets.append(new_dataset)
+            for ii in range(max_episodes_per_split):
+                new_dataset.episodes.append(self.episodes[rand_items[ep_ind]])
+                ep_ind += 1
+            if sort_by_episode_id:
+                new_dataset.episodes.sort(key=lambda ep: ep.episode_id)
+            new_episodes.extend(new_dataset.episodes)
+        if remove_unused_episodes:
+            self.episodes = new_episodes
+        return new_datasets
+
+    def get_uneven_splits(self, num_splits):
+        """
+        Returns a list of new datasets, each with a subset of the original
+        episodes. The last dataset may have fewer episodes than the others.
+        This is especially useful for splitting over validation/test datasets
+        in order to make sure that all episodes are copied but none are
+        duplicated.
+        Args:
+            num_splits: The number of splits to create.
+        Returns:
+            A list of new datasets, each with their own subset of episodes.
+        """
+        assert (
+            len(self.episodes) >= num_splits
+        ), "Not enough episodes to create this many splits."
+        new_datasets = []
+        num_episodes = len(self.episodes)
+        stride = int(np.ceil(num_episodes * 1.0 / num_splits))
+        for ii, split in enumerate(
+            range(0, num_episodes, stride)[:num_splits]
+        ):
+            new_dataset = copy.copy(self)  # Creates a shallow copy
+            new_dataset.episodes = new_dataset.episodes[
+                split : min(split + stride, num_episodes)
+            ].copy()
+            new_datasets.append(new_dataset)
+        assert (
+            sum([len(new_dataset.episodes) for new_dataset in new_datasets])
+            == num_episodes
+        )
+        return new_datasets
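A brief usage sketch of the helpers added above (illustrative only, not part of the patch). Here `dataset` stands for any loaded habitat Dataset subclass whose `episodes` list is already populated, e.g. via from_json, and the scene id used in the filter is just a placeholder borrowed from the test fixtures below:

    # Keep only the episodes that match a predicate.
    office_only = dataset.filter_episodes(
        lambda ep: ep.scene_id == "scene_id_0"  # placeholder scene id
    )

    # Ten equally sized training splits; grouping episodes by scene id keeps
    # same-scene episodes adjacent, which reduces scene-switching overhead
    # when each split is consumed by a separate worker.
    train_splits = dataset.get_splits(10, collate_scene_ids=True)

    # For validation/test, keep every episode exactly once even when the
    # episode count does not divide evenly by the number of splits.
    eval_splits = dataset.get_uneven_splits(3)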
diff --git a/test/test_dataset.py b/test/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..69b2bf69f83ae38d1460b9934753f2481767eb9d
--- /dev/null
+++ b/test/test_dataset.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from habitat.core.dataset import Dataset, Episode
+
+
+def _construct_dataset(num_episodes):
+    episodes = []
+    for ii in range(num_episodes):
+        episode = Episode(
+            episode_id=str(ii),
+            scene_id="scene_id_" + str(ii % 10),
+            start_position=[0, 0, 0],
+            start_rotation=[0, 0, 0, 1],
+        )
+        episodes.append(episode)
+    dataset = Dataset()
+    dataset.episodes = episodes
+    return dataset
+
+
+def test_scene_ids():
+    dataset = _construct_dataset(100)
+    assert dataset.scene_ids == ["scene_id_" + str(ii) for ii in range(10)]
+
+
+def test_get_scene_episodes():
+    dataset = _construct_dataset(100)
+    scene = "scene_id_0"
+    scene_episodes = dataset.get_scene_episodes(scene)
+    assert len(scene_episodes) == 10
+    for ep in scene_episodes:
+        assert ep.scene_id == scene
+
+
+def test_filter_episodes():
+    dataset = _construct_dataset(100)
+
+    def filter_fn(episode: Episode) -> bool:
+        return int(episode.episode_id) % 2 == 0
+
+    filtered_dataset = dataset.filter_episodes(filter_fn)
+    assert len(filtered_dataset.episodes) == 50
+    for ep in filtered_dataset.episodes:
+        assert filter_fn(ep)
+
+
+def test_get_splits_even_split_possible():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 10
+
+
+def test_get_splits_with_remainder():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(11)
+    assert len(splits) == 11
+    for split in splits:
+        assert len(split.episodes) == 9
+
+
+def test_get_splits_max_episodes_specified():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10, 3, False)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 3
+    assert len(dataset.episodes) == 100
+
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10, 11, False)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 10
+    assert len(dataset.episodes) == 100
+
+    dataset = _construct_dataset(100)
+    splits = dataset.get_splits(10, 3, True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 3
+    assert len(dataset.episodes) == 30
+
+
+def test_get_splits_collate_scenes():
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, 23, collate_scene_ids=True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 23
+        prev_ids = set()
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                assert split.episodes[ii - 1].scene_id == ep.scene_id
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, 200, collate_scene_ids=False)
+    assert len(splits) == 10
+    for split in splits:
+        prev_ids = set()
+        found_not_collated = False
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                if split.episodes[ii - 1].scene_id != ep.scene_id:
+                    found_not_collated = True
+                    break
+        assert found_not_collated
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, collate_scene_ids=True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 1000
+        prev_ids = set()
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                assert split.episodes[ii - 1].scene_id == ep.scene_id
+
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, collate_scene_ids=False)
+    assert len(splits) == 10
+    for split in splits:
+        prev_ids = set()
+        found_not_collated = False
+        for ii, ep in enumerate(split.episodes):
+            if ep.scene_id not in prev_ids:
+                prev_ids.add(ep.scene_id)
+            else:
+                if split.episodes[ii - 1].scene_id != ep.scene_id:
+                    found_not_collated = True
+                    break
+        assert found_not_collated
+
+
+def test_get_splits_sort_by_episode_id():
+    dataset = _construct_dataset(10000)
+    splits = dataset.get_splits(10, 23, sort_by_episode_id=True)
+    assert len(splits) == 10
+    for split in splits:
+        assert len(split.episodes) == 23
+        for ii, ep in enumerate(split.episodes):
+            if ii > 0:
+                assert ep.episode_id >= split.episodes[ii - 1].episode_id
+
+
+def test_get_uneven_splits():
+    dataset = _construct_dataset(100)
+    splits = dataset.get_uneven_splits(9)
+    assert len(splits) == 9
+    assert sum([len(split.episodes) for split in splits]) == 100
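For reference, a short sketch of the invariants these tests exercise (illustrative only, not part of the patch; it reuses the _construct_dataset helper defined in test/test_dataset.py above):

    dataset = _construct_dataset(100)

    even = dataset.get_splits(4)  # 4 splits of 100 // 4 = 25 episodes each
    ids = [ep.episode_id for split in even for ep in split.episodes]
    assert len(ids) == len(set(ids))  # no episode appears in two splits

    # stride = ceil(100 / 3) = 34, so the split sizes are 34, 34 and 32
    uneven = dataset.get_uneven_splits(3)
    assert sum(len(split.episodes) for split in uneven) == 100  # every episode kept once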