diff --git a/habitat/config/default.py b/habitat/config/default.py index 8fd60e74baaa3f7ce6b2bbabfa2ec5c120d812f6..a433e997ccaef07e1f2d8bdd93e77504c145d32e 100644 --- a/habitat/config/default.py +++ b/habitat/config/default.py @@ -22,6 +22,12 @@ _C.SEED = 100 _C.ENVIRONMENT = CN() _C.ENVIRONMENT.MAX_EPISODE_STEPS = 1000 _C.ENVIRONMENT.MAX_EPISODE_SECONDS = 10000000 +_C.ENVIRONMENT.ITERATOR_OPTIONS = CN() +_C.ENVIRONMENT.ITERATOR_OPTIONS.CYCLE = True +_C.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False +_C.ENVIRONMENT.ITERATOR_OPTIONS.GROUP_BY_SCENE = True +_C.ENVIRONMENT.ITERATOR_OPTIONS.NUM_EPISODE_SAMPLE = -1 +_C.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT = -1 # ----------------------------------------------------------------------------- # TASK # ----------------------------------------------------------------------------- @@ -150,7 +156,6 @@ _C.DATASET = CN() _C.DATASET.TYPE = "PointNav-v1" _C.DATASET.SPLIT = "train" _C.DATASET.SCENES_DIR = "data/scene_datasets" -_C.DATASET.NUM_EPISODE_SAMPLE = -1 _C.DATASET.CONTENT_SCENES = ["*"] _C.DATASET.DATA_PATH = ( "data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz" diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py index b0bc35ae4d924384f5d545ce6aab400e3cc43fda..c94736ea39398e2a4a723c30f0ed9188a5c5b479 100644 --- a/habitat/core/dataset.py +++ b/habitat/core/dataset.py @@ -10,8 +10,18 @@ of a ``habitat.Agent`` inside ``habitat.Env``. """ import copy import json -from itertools import cycle -from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar +import random +from itertools import groupby +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterator, + List, + Optional, + TypeVar, +) import attr import numpy as np @@ -48,16 +58,12 @@ class Episode: info: Optional[Dict[str, str]] = None -T = TypeVar("T", Episode, Type[Episode]) +T = TypeVar("T", bound=Episode) class Dataset(Generic[T]): r"""Base class for dataset specification. - - Attributes: - episodes: list of episodes containing instance information. """ - episodes: List[T] @property @@ -90,15 +96,20 @@ class Dataset(Generic[T]): """ return [self.episodes[episode_id] for episode_id in indexes] - def get_episode_iterator(self): - r""" - Creates and returns an iterator that iterates through self.episodes - in the desirable way specified. + def get_episode_iterator(self, *args: Any, **kwargs: Any) -> Iterator: + r"""Gets episode iterator with options. Options are specified in + EpisodeIterator documentation. To further customize iterator behavior + for your Dataset subclass, create a customized iterator class like + EpisodeIterator and override this method. + + Args: + *args: positional args for iterator constructor + **kwargs: keyword args for iterator constructor + Returns: - iterator for episodes + Iterator: episode iterator with specified behavior """ - # TODO: support shuffling between epoch and scene switching - return cycle(self.episodes) + return EpisodeIterator(self.episodes, *args, **kwargs) def to_json(self) -> str: class DatasetJSONEncoder(json.JSONEncoder): @@ -111,8 +122,7 @@ class Dataset(Generic[T]): def from_json( self, json_str: str, scenes_dir: Optional[str] = None ) -> None: - r""" - Creates dataset from ``json_str``. Directory containing relevant + r"""Creates dataset from ``json_str``. Directory containing relevant graphical assets of scenes is passed through ``scenes_dir``. Args: @@ -122,11 +132,8 @@ class Dataset(Generic[T]): """ raise NotImplementedError - def filter_episodes( - self, filter_fn: Callable[[Episode], bool] - ) -> "Dataset": - r""" - Returns a new dataset with only the filtered episodes from the + def filter_episodes(self, filter_fn: Callable[[T], bool]) -> "Dataset": + r"""Returns a new dataset with only the filtered episodes from the original dataset. Args: @@ -236,20 +243,118 @@ class Dataset(Generic[T]): self.episodes = new_episodes return new_datasets - def sample_episodes(self, num_episodes: int) -> None: - """ - Sample from existing episodes a list of episodes of size num_episodes, - and replace self.episodes with the list of sampled episodes. + +class EpisodeIterator(Iterator): + r"""Episode Iterator class that gives options for how a list of episodes + should be iterated. Some of those options are desirable for the internal + simulator to get higher performance. More context: simulator suffers + overhead when switching between scenes, therefore episodes of the same + scene should be loaded consecutively. However, if too many consecutive + episodes from same scene are feed into RL model, the model will risk to + overfit that scene. Therefore it's better to load same scene consecutively + and switch once a number threshold is reached. + + Currently supports the following features: + Cycling: when all episodes are iterated, cycle back to start instead of + throwing StopIteration. + Cycling with shuffle: when cycling back, shuffle episodes groups + grouped by scene. + Group by scene: episodes of same scene will be grouped and loaded + consecutively. + Set max scene repeat: set a number threshold on how many episodes from + the same scene can be loaded consecutively. + Sample episodes: sample the specified number of episodes. + """ + + def __init__( + self, + episodes: List[T], + cycle: bool = True, + shuffle: bool = False, + group_by_scene: bool = True, + max_scene_repeat: int = -1, + num_episode_sample: int = -1, + ): + r""" Args: - num_episodes: number of episodes to sample, input -1 to use - whole episodes + episodes: list of episodes. + cycle: if true, cycle back to first episodes when StopIteration. + shuffle: if true, shuffle scene groups when cycle. + No effect if cycle is set to false. Will shuffle grouped + scenes if group_by_scene is true. + group_by_scene: if true, group episodes from same scene. + max_scene_repeat: threshold of how many episodes from the same + scene can be loaded consecutively. -1 for no limit + num_episode_sample: number of episodes to be sampled. + -1 for no sampling. """ - if num_episodes == -1: - return - if num_episodes < -1: - raise ValueError( - f"Invalid number for episodes to sample: {num_episodes}" + # sample episodes + if num_episode_sample >= 0: + episodes = np.random.choice( + episodes, num_episode_sample, replace=False ) - self.episodes = np.random.choice( - self.episodes, num_episodes, replace=False - ) + self.episodes = episodes + self.cycle = cycle + self.group_by_scene = group_by_scene + if group_by_scene: + num_scene_groups = len( + list(groupby(episodes, key=lambda x: x.scene_id)) + ) + num_unique_scenes = len(set([e.scene_id for e in episodes])) + if num_scene_groups >= num_unique_scenes: + self.episodes = sorted(self.episodes, key=lambda x: x.scene_id) + self.max_scene_repetition = max_scene_repeat + self.shuffle = shuffle + self._rep_count = 0 + self._prev_scene_id = None + self._iterator = iter(self.episodes) + + def __iter__(self): + return self + + def __next__(self): + r"""The main logic for handling how episodes will be iterated. + + Returns: + next episode. + """ + + next_episode = next(self._iterator, None) + if next_episode is None: + if not self.cycle: + raise StopIteration + self._iterator = iter(self.episodes) + if self.shuffle: + self._shuffle_iterator() + next_episode = next(self._iterator) + + if self._prev_scene_id == next_episode.scene_id: + self._rep_count += 1 + if ( + self.max_scene_repetition > 0 + and self._rep_count >= self.max_scene_repetition - 1 + ): + self._shuffle_iterator() + self._rep_count = 0 + + self._prev_scene_id = next_episode.scene_id + return next_episode + + def _shuffle_iterator(self) -> None: + r"""Internal method that shuffles the remaining episodes. + If self.group_by_scene is true, then shuffle groups of scenes. + + Returns: + None. + """ + if self.group_by_scene: + grouped_episodes = [ + list(g) + for k, g in groupby(self._iterator, key=lambda x: x.scene_id) + ] + random.shuffle(grouped_episodes) + self._iterator = iter(sum(grouped_episodes, [])) + else: + episodes = list(self._iterator) + random.shuffle(episodes) + self._iterator = iter(episodes) diff --git a/habitat/core/env.py b/habitat/core/env.py index 926763306c928233107676e224214a353d6ddc88..d948ece59c98bfece847e720db83533fb6bc8b01 100644 --- a/habitat/core/env.py +++ b/habitat/core/env.py @@ -72,7 +72,13 @@ class Env: ) self._episodes = self._dataset.episodes if self._dataset else [] self._current_episode = None - self._episode_iterator = self._dataset.get_episode_iterator() + iter_option_dict = { + k.lower(): v + for k, v in config.ENVIRONMENT.ITERATOR_OPTIONS.items() + } + self._episode_iterator = self._dataset.get_episode_iterator( + **iter_option_dict + ) # load the first scene if dataset is present if self._dataset: diff --git a/habitat/datasets/eqa/mp3d_eqa_dataset.py b/habitat/datasets/eqa/mp3d_eqa_dataset.py index 4194fb58aefe230abc2ca33ee4da515125a79ee0..f256a5164f32aa004b574875bf4b69c6d37a41fa 100644 --- a/habitat/datasets/eqa/mp3d_eqa_dataset.py +++ b/habitat/datasets/eqa/mp3d_eqa_dataset.py @@ -52,8 +52,6 @@ class Matterport3dDatasetV1(Dataset): with gzip.open(config.DATA_PATH.format(split=config.SPLIT), "rt") as f: self.from_json(f.read()) - self.sample_episodes(config.NUM_EPISODE_SAMPLE) - def from_json( self, json_str: str, scenes_dir: Optional[str] = None ) -> None: diff --git a/habitat/datasets/pointnav/pointnav_dataset.py b/habitat/datasets/pointnav/pointnav_dataset.py index 70caad5e9bb3078f9c86b5e01182dfdbb868ca84..9e9cdf516f6cd28cd5edd5d0f187bcb6921966aa 100644 --- a/habitat/datasets/pointnav/pointnav_dataset.py +++ b/habitat/datasets/pointnav/pointnav_dataset.py @@ -98,8 +98,6 @@ class PointNavDatasetV1(Dataset): with gzip.open(scene_filename, "rt") as f: self.from_json(f.read(), scenes_dir=config.SCENES_DIR) - self.sample_episodes(config.NUM_EPISODE_SAMPLE) - def from_json( self, json_str: str, scenes_dir: Optional[str] = None ) -> None: diff --git a/test/test_dataset.py b/test/test_dataset.py index 20792a264fbd041c19e123397805eb375883c031..ea5b03b712f68b39c1a2cd5b0af0dd68c4366165 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -4,17 +4,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from itertools import groupby, islice + import pytest from habitat.core.dataset import Dataset, Episode -def _construct_dataset(num_episodes): +def _construct_dataset(num_episodes, num_groups=10): episodes = [] - for ii in range(num_episodes): + for i in range(num_episodes): episode = Episode( - episode_id=str(ii), - scene_id="scene_id_" + str(ii % 10), + episode_id=str(i), + scene_id="scene_id_" + str(i % num_groups), start_position=[0, 0, 0], start_rotation=[0, 0, 0, 1], ) @@ -181,30 +183,84 @@ def test_get_uneven_splits(): def test_sample_episodes(): - dataset = _construct_dataset(10000) - dataset.sample_episodes(-1) - assert len(dataset.episodes) == 10000 + dataset = _construct_dataset(1000) + ep_iter = dataset.get_episode_iterator( + num_episode_sample=1000, cycle=False + ) + assert len(list(ep_iter)) == 1000 - dataset = _construct_dataset(10000) - dataset.sample_episodes(0) - assert len(dataset.episodes) == 0 + ep_iter = dataset.get_episode_iterator(num_episode_sample=0, cycle=False) + assert len(list(ep_iter)) == 0 - dataset = _construct_dataset(10000) - dataset.sample_episodes(1) - assert len(dataset.episodes) == 1 + with pytest.raises(ValueError): + dataset.get_episode_iterator(num_episode_sample=1001, cycle=False) - dataset = _construct_dataset(10000) - dataset.sample_episodes(10000) - assert len(dataset.episodes) == 10000 + ep_iter = dataset.get_episode_iterator(num_episode_sample=100, cycle=True) + ep_id_list = [e.episode_id for e in list(islice(ep_iter, 100))] + assert len(set(ep_id_list)) == 100 + next_episode = next(ep_iter) + assert next_episode.episode_id in ep_id_list - dataset = _construct_dataset(10000) - with pytest.raises(Exception): - dataset.sample_episodes(10001) + ep_iter = dataset.get_episode_iterator(num_episode_sample=0, cycle=False) + with pytest.raises(StopIteration): + next(ep_iter) -def test_iterator_looping(): +def test_iterator_cycle(): dataset = _construct_dataset(100) - episode_iter = dataset.get_episode_iterator() + ep_iter = dataset.get_episode_iterator( + cycle=True, shuffle=False, group_by_scene=False + ) for i in range(200): - episode = next(episode_iter) + episode = next(ep_iter) assert episode.episode_id == dataset.episodes[i % 100].episode_id + + ep_iter = dataset.get_episode_iterator(cycle=True, num_episode_sample=20) + episodes = list(islice(ep_iter, 20)) + for i in range(200): + episode = next(ep_iter) + assert episode.episode_id == episodes[i % 20].episode_id + + +def test_iterator_shuffle(): + dataset = _construct_dataset(100) + episode_iter = dataset.get_episode_iterator(shuffle=True) + first_round_episodes = list(islice(episode_iter, 100)) + second_round_episodes = list(islice(episode_iter, 100)) + + # both rounds should have same episodes but in different order + assert sorted(first_round_episodes) == sorted(second_round_episodes) + assert first_round_episodes != second_round_episodes + + # both rounds should be grouped by scenes + first_round_scene_groups = [ + k for k, g in groupby(first_round_episodes, key=lambda x: x.scene_id) + ] + second_round_scene_groups = [ + k for k, g in groupby(second_round_episodes, key=lambda x: x.scene_id) + ] + assert len(first_round_scene_groups) == len(second_round_scene_groups) + assert len(first_round_scene_groups) == len(set(first_round_scene_groups)) + + +def test_iterator_scene_switching(): + total_ep = 1000 + max_repeat = 25 + dataset = _construct_dataset(total_ep) + + episode_iter = dataset.get_episode_iterator(max_scene_repeat=max_repeat) + episodes = sorted(dataset.episodes, key=lambda x: x.scene_id) + + # episodes before max_repeat reached should be identical + for i in range(max_repeat): + episode = next(episode_iter) + assert episode.episode_id == episodes.pop(0).episode_id + + remaining_episodes = list(islice(episode_iter, total_ep - max_repeat)) + # remaining episodes should be same but in different order + assert len(remaining_episodes) == len(episodes) + assert remaining_episodes != episodes + assert sorted(remaining_episodes) == sorted(episodes) + + # next episodes should still be grouped by scene (before next switching) + assert len(set([e.scene_id for e in remaining_episodes[:max_repeat]])) == 1