From 1847f79df21b8609fc28e40e1b7731a5a3a3e16f Mon Sep 17 00:00:00 2001
From: JasonJiazhiZhang <21229070+JasonJiazhiZhang@users.noreply.github.com>
Date: Tue, 16 Jul 2019 21:52:19 -0700
Subject: [PATCH] Make env access episode through iterator (#156)

Make env access episode through iterator
---
 habitat/core/dataset.py       | 11 +++++
 habitat/core/env.py           | 33 +++++++++------
 test/test_dataset.py          |  8 ++++
 test/test_pointnav_dataset.py |  4 +-
 test/test_sensors.py          | 80 +++++++++++++++++++----------------
 5 files changed, 85 insertions(+), 51 deletions(-)

diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py
index 713cb63ab..b0bc35ae4 100644
--- a/habitat/core/dataset.py
+++ b/habitat/core/dataset.py
@@ -10,6 +10,7 @@ of a ``habitat.Agent`` inside ``habitat.Env``.
 """
 import copy
 import json
+from itertools import cycle
 from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
 
 import attr
@@ -89,6 +90,16 @@ class Dataset(Generic[T]):
         """
         return [self.episodes[episode_id] for episode_id in indexes]
 
+    def get_episode_iterator(self):
+        r"""
+        Creates and returns an iterator that iterates through self.episodes
+        in the desirable way specified.
+        Returns:
+            iterator for episodes
+        """
+        # TODO: support shuffling between epoch and  scene switching
+        return cycle(self.episodes)
+
     def to_json(self) -> str:
         class DatasetJSONEncoder(json.JSONEncoder):
             def default(self, object):
diff --git a/habitat/core/env.py b/habitat/core/env.py
index 721adb34c..7141050b7 100644
--- a/habitat/core/env.py
+++ b/habitat/core/env.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import time
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
 
 import gym
 import numpy as np
@@ -46,6 +46,8 @@ class Env:
     _dataset: Optional[Dataset]
     _episodes: List[Type[Episode]]
     _current_episode_index: Optional[int]
+    _current_episode: Optional[Type[Episode]]
+    _episode_iterator: Optional[Iterator]
     _sim: Simulator
     _task: EmbodiedTask
     _max_episode_seconds: int
@@ -69,6 +71,8 @@ class Env:
                 id_dataset=config.DATASET.TYPE, config=config.DATASET
             )
         self._episodes = self._dataset.episodes if self._dataset else []
+        self._current_episode = None
+        self._episode_iterator = self._dataset.get_episode_iterator()
 
         # load the first scene if dataset is present
         if self._dataset:
@@ -105,11 +109,20 @@ class Env:
 
     @property
     def current_episode(self) -> Type[Episode]:
-        assert (
-            self._current_episode_index is not None
-            and self._current_episode_index < len(self._episodes)
-        )
-        return self._episodes[self._current_episode_index]
+        assert self._current_episode is not None
+        return self._current_episode
+
+    @current_episode.setter
+    def current_episode(self, episode: Type[Episode]) -> None:
+        self._current_episode = episode
+
+    @property
+    def episode_iterator(self) -> Iterator:
+        return self._episode_iterator
+
+    @episode_iterator.setter
+    def episode_iterator(self, new_iter: Iterator) -> None:
+        self._episode_iterator = new_iter
 
     @property
     def episodes(self) -> List[Type[Episode]]:
@@ -176,13 +189,7 @@ class Env:
 
         assert len(self.episodes) > 0, "Episodes list is empty"
 
-        # Switch to next episode in a loop
-        if self._current_episode_index is None:
-            self._current_episode_index = 0
-        else:
-            self._current_episode_index = (
-                self._current_episode_index + 1
-            ) % len(self._episodes)
+        self.current_episode = next(self._episode_iterator)
         self.reconfigure(self._config)
 
         observations = self._sim.reset()
diff --git a/test/test_dataset.py b/test/test_dataset.py
index 01d59f8da..20792a264 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -200,3 +200,11 @@ def test_sample_episodes():
     dataset = _construct_dataset(10000)
     with pytest.raises(Exception):
         dataset.sample_episodes(10001)
+
+
+def test_iterator_looping():
+    dataset = _construct_dataset(100)
+    episode_iter = dataset.get_episode_iterator()
+    for i in range(200):
+        episode = next(episode_iter)
+        assert episode.episode_id == dataset.episodes[i % 100].episode_id
diff --git a/test/test_pointnav_dataset.py b/test/test_pointnav_dataset.py
index 7a3954c64..c655fc062 100644
--- a/test/test_pointnav_dataset.py
+++ b/test/test_pointnav_dataset.py
@@ -124,7 +124,7 @@ def check_shortest_path(env, episode):
         len(episode.shortest_paths) == 1
     ), "Episode has no shortest paths or more than one."
 
-    env.episodes = [episode]
+    env.episode_iterator = iter([episode])
     env.reset()
     start_state = env.sim.get_agent_state()
     check_state(start_state, episode.start_position, episode.start_rotation)
@@ -165,7 +165,7 @@ def test_pointnav_episode_generator():
     ):
         episodes.append(episode)
     assert len(episodes) == 2 * NUM_EPISODES
-    env.episodes = episodes
+    env.episode_iterator = iter(episodes)
 
     for episode in episodes:
         check_shortest_path(env, episode)
diff --git a/test/test_sensors.py b/test/test_sensors.py
index 2b4d5971b..9372dacf9 100644
--- a/test/test_sensors.py
+++ b/test/test_sensors.py
@@ -25,15 +25,17 @@ def _random_episode(env, config):
         0,
         np.cos(random_heading / 2),
     ]
-    env.episodes = [
-        NavigationEpisode(
-            episode_id="0",
-            scene_id=config.SIMULATOR.SCENE,
-            start_position=random_location,
-            start_rotation=random_rotation,
-            goals=[],
-        )
-    ]
+    env.episode_iterator = iter(
+        [
+            NavigationEpisode(
+                episode_id="0",
+                scene_id=config.SIMULATOR.SCENE,
+                start_position=random_location,
+                start_rotation=random_rotation,
+                goals=[],
+            )
+        ]
+    )
 
 
 def test_heading_sensor():
@@ -56,15 +58,17 @@ def test_heading_sensor():
             0,
             np.cos(random_heading / 2),
         ]
-        env.episodes = [
-            NavigationEpisode(
-                episode_id="0",
-                scene_id=config.SIMULATOR.SCENE,
-                start_position=[03.00611, 0.072447, -2.67867],
-                start_rotation=random_rotation,
-                goals=[],
-            )
-        ]
+        env.episode_iterator = iter(
+            [
+                NavigationEpisode(
+                    episode_id="0",
+                    scene_id=config.SIMULATOR.SCENE,
+                    start_position=[03.00611, 0.072447, -2.67867],
+                    start_rotation=random_rotation,
+                    goals=[],
+                )
+            ]
+        )
 
         obs = env.reset()
         heading = obs["heading"]
@@ -167,15 +171,17 @@ def test_static_pointgoal_sensor():
     # corresponds to simulator using z-negative as forward action
     start_rotation = [0, 0, 0, 1]
 
-    env.episodes = [
-        NavigationEpisode(
-            episode_id="0",
-            scene_id=config.SIMULATOR.SCENE,
-            start_position=valid_start_position,
-            start_rotation=start_rotation,
-            goals=[NavigationGoal(position=goal_position)],
-        )
-    ]
+    env.episode_iterator = iter(
+        [
+            NavigationEpisode(
+                episode_id="0",
+                scene_id=config.SIMULATOR.SCENE,
+                start_position=valid_start_position,
+                start_rotation=start_rotation,
+                goals=[NavigationGoal(position=goal_position)],
+            )
+        ]
+    )
 
     non_stop_actions = [
         act
@@ -211,15 +217,17 @@ def test_get_observations_at():
     # corresponds to simulator using z-negative as forward action
     start_rotation = [0, 0, 0, 1]
 
-    env.episodes = [
-        NavigationEpisode(
-            episode_id="0",
-            scene_id=config.SIMULATOR.SCENE,
-            start_position=valid_start_position,
-            start_rotation=start_rotation,
-            goals=[NavigationGoal(position=goal_position)],
-        )
-    ]
+    env.episode_iterator = iter(
+        [
+            NavigationEpisode(
+                episode_id="0",
+                scene_id=config.SIMULATOR.SCENE,
+                start_position=valid_start_position,
+                start_rotation=start_rotation,
+                goals=[NavigationGoal(position=goal_position)],
+            )
+        ]
+    )
     non_stop_actions = [
         act
         for act in range(env.action_space.n)
-- 
GitLab