From 1847f79df21b8609fc28e40e1b7731a5a3a3e16f Mon Sep 17 00:00:00 2001 From: JasonJiazhiZhang <21229070+JasonJiazhiZhang@users.noreply.github.com> Date: Tue, 16 Jul 2019 21:52:19 -0700 Subject: [PATCH] Make env access episode through iterator (#156) Make env access episode through iterator --- habitat/core/dataset.py | 11 +++++ habitat/core/env.py | 33 +++++++++------ test/test_dataset.py | 8 ++++ test/test_pointnav_dataset.py | 4 +- test/test_sensors.py | 80 +++++++++++++++++++---------------- 5 files changed, 85 insertions(+), 51 deletions(-) diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py index 713cb63ab..b0bc35ae4 100644 --- a/habitat/core/dataset.py +++ b/habitat/core/dataset.py @@ -10,6 +10,7 @@ of a ``habitat.Agent`` inside ``habitat.Env``. """ import copy import json +from itertools import cycle from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar import attr @@ -89,6 +90,16 @@ class Dataset(Generic[T]): """ return [self.episodes[episode_id] for episode_id in indexes] + def get_episode_iterator(self): + r""" + Creates and returns an iterator that iterates through self.episodes + in the desirable way specified. + Returns: + iterator for episodes + """ + # TODO: support shuffling between epoch and scene switching + return cycle(self.episodes) + def to_json(self) -> str: class DatasetJSONEncoder(json.JSONEncoder): def default(self, object): diff --git a/habitat/core/env.py b/habitat/core/env.py index 721adb34c..7141050b7 100644 --- a/habitat/core/env.py +++ b/habitat/core/env.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import time -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type import gym import numpy as np @@ -46,6 +46,8 @@ class Env: _dataset: Optional[Dataset] _episodes: List[Type[Episode]] _current_episode_index: Optional[int] + _current_episode: Optional[Type[Episode]] + _episode_iterator: Optional[Iterator] _sim: Simulator _task: EmbodiedTask _max_episode_seconds: int @@ -69,6 +71,8 @@ class Env: id_dataset=config.DATASET.TYPE, config=config.DATASET ) self._episodes = self._dataset.episodes if self._dataset else [] + self._current_episode = None + self._episode_iterator = self._dataset.get_episode_iterator() # load the first scene if dataset is present if self._dataset: @@ -105,11 +109,20 @@ class Env: @property def current_episode(self) -> Type[Episode]: - assert ( - self._current_episode_index is not None - and self._current_episode_index < len(self._episodes) - ) - return self._episodes[self._current_episode_index] + assert self._current_episode is not None + return self._current_episode + + @current_episode.setter + def current_episode(self, episode: Type[Episode]) -> None: + self._current_episode = episode + + @property + def episode_iterator(self) -> Iterator: + return self._episode_iterator + + @episode_iterator.setter + def episode_iterator(self, new_iter: Iterator) -> None: + self._episode_iterator = new_iter @property def episodes(self) -> List[Type[Episode]]: @@ -176,13 +189,7 @@ class Env: assert len(self.episodes) > 0, "Episodes list is empty" - # Switch to next episode in a loop - if self._current_episode_index is None: - self._current_episode_index = 0 - else: - self._current_episode_index = ( - self._current_episode_index + 1 - ) % len(self._episodes) + self.current_episode = next(self._episode_iterator) self.reconfigure(self._config) observations = self._sim.reset() diff --git a/test/test_dataset.py b/test/test_dataset.py index 01d59f8da..20792a264 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -200,3 +200,11 @@ def test_sample_episodes(): dataset = _construct_dataset(10000) with pytest.raises(Exception): dataset.sample_episodes(10001) + + +def test_iterator_looping(): + dataset = _construct_dataset(100) + episode_iter = dataset.get_episode_iterator() + for i in range(200): + episode = next(episode_iter) + assert episode.episode_id == dataset.episodes[i % 100].episode_id diff --git a/test/test_pointnav_dataset.py b/test/test_pointnav_dataset.py index 7a3954c64..c655fc062 100644 --- a/test/test_pointnav_dataset.py +++ b/test/test_pointnav_dataset.py @@ -124,7 +124,7 @@ def check_shortest_path(env, episode): len(episode.shortest_paths) == 1 ), "Episode has no shortest paths or more than one." - env.episodes = [episode] + env.episode_iterator = iter([episode]) env.reset() start_state = env.sim.get_agent_state() check_state(start_state, episode.start_position, episode.start_rotation) @@ -165,7 +165,7 @@ def test_pointnav_episode_generator(): ): episodes.append(episode) assert len(episodes) == 2 * NUM_EPISODES - env.episodes = episodes + env.episode_iterator = iter(episodes) for episode in episodes: check_shortest_path(env, episode) diff --git a/test/test_sensors.py b/test/test_sensors.py index 2b4d5971b..9372dacf9 100644 --- a/test/test_sensors.py +++ b/test/test_sensors.py @@ -25,15 +25,17 @@ def _random_episode(env, config): 0, np.cos(random_heading / 2), ] - env.episodes = [ - NavigationEpisode( - episode_id="0", - scene_id=config.SIMULATOR.SCENE, - start_position=random_location, - start_rotation=random_rotation, - goals=[], - ) - ] + env.episode_iterator = iter( + [ + NavigationEpisode( + episode_id="0", + scene_id=config.SIMULATOR.SCENE, + start_position=random_location, + start_rotation=random_rotation, + goals=[], + ) + ] + ) def test_heading_sensor(): @@ -56,15 +58,17 @@ def test_heading_sensor(): 0, np.cos(random_heading / 2), ] - env.episodes = [ - NavigationEpisode( - episode_id="0", - scene_id=config.SIMULATOR.SCENE, - start_position=[03.00611, 0.072447, -2.67867], - start_rotation=random_rotation, - goals=[], - ) - ] + env.episode_iterator = iter( + [ + NavigationEpisode( + episode_id="0", + scene_id=config.SIMULATOR.SCENE, + start_position=[03.00611, 0.072447, -2.67867], + start_rotation=random_rotation, + goals=[], + ) + ] + ) obs = env.reset() heading = obs["heading"] @@ -167,15 +171,17 @@ def test_static_pointgoal_sensor(): # corresponds to simulator using z-negative as forward action start_rotation = [0, 0, 0, 1] - env.episodes = [ - NavigationEpisode( - episode_id="0", - scene_id=config.SIMULATOR.SCENE, - start_position=valid_start_position, - start_rotation=start_rotation, - goals=[NavigationGoal(position=goal_position)], - ) - ] + env.episode_iterator = iter( + [ + NavigationEpisode( + episode_id="0", + scene_id=config.SIMULATOR.SCENE, + start_position=valid_start_position, + start_rotation=start_rotation, + goals=[NavigationGoal(position=goal_position)], + ) + ] + ) non_stop_actions = [ act @@ -211,15 +217,17 @@ def test_get_observations_at(): # corresponds to simulator using z-negative as forward action start_rotation = [0, 0, 0, 1] - env.episodes = [ - NavigationEpisode( - episode_id="0", - scene_id=config.SIMULATOR.SCENE, - start_position=valid_start_position, - start_rotation=start_rotation, - goals=[NavigationGoal(position=goal_position)], - ) - ] + env.episode_iterator = iter( + [ + NavigationEpisode( + episode_id="0", + scene_id=config.SIMULATOR.SCENE, + start_position=valid_start_position, + start_rotation=start_rotation, + goals=[NavigationGoal(position=goal_position)], + ) + ] + ) non_stop_actions = [ act for act in range(env.action_space.n) -- GitLab