From 1d392504d9239dbdecdeac27da1d205d9074a147 Mon Sep 17 00:00:00 2001
From: Erik Wijmans <ewijmans2@gmail.com>
Date: Sat, 15 Jun 2019 01:53:48 -0700
Subject: [PATCH] Extendable action space (#110)

SimulatorActions is no longer an enum. It is now an extendable singleton: it allows the action space to be extended, but is strict about how. Actions cannot be remapped to different ints, and existing actions cannot be removed.
There is now a registry for action space configurations.
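
A minimal usage sketch, mirroring the new examples/new_actions.py added in
this patch (the habitat_sim strafe controls and the "NoNoiseStrafe" action
space configuration it selects are defined and registered there):

    import habitat

    # New action names are appended at the next free integer; existing
    # mappings never change.
    habitat.SimulatorActions.extend_action_space("STRAFE_LEFT")
    habitat.SimulatorActions.extend_action_space("STRAFE_RIGHT")

    config = habitat.get_config(config_paths="configs/tasks/pointnav.yaml")
    config.defrost()
    # Select a registered action space configuration by name.
    config.SIMULATOR.ACTION_SPACE_CONFIG = "NoNoiseStrafe"
    config.freeze()

    env = habitat.Env(config=config)
    env.reset()
    env.step(habitat.SimulatorActions.STRAFE_LEFT)
    env.step(habitat.SimulatorActions.STRAFE_RIGHT)
    env.close()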
---
 examples/benchmark.py                         |   4 +-
 examples/new_actions.py                       | 155 ++++++++++++++++++
 examples/register_new_sensors_and_measures.py |   6 +-
 examples/shortest_path_follower_example.py    |   2 +-
 habitat/__init__.py                           |   8 +-
 habitat/config/default.py                     |   1 +
 habitat/core/registry.py                      |  29 +++-
 habitat/core/simulator.py                     |  94 +++++++++++
 habitat/core/utils.py                         |  11 ++
 habitat/datasets/utils.py                     |   9 +-
 habitat/sims/habitat_simulator/__init__.py    |  24 +++
 .../sims/habitat_simulator/action_spaces.py   |  55 +++++++
 .../habitat_simulator.py                      |  49 +-----
 habitat/sims/registration.py                  |  20 +--
 habitat/tasks/nav/shortest_path_follower.py   |  23 ++-
 habitat_baselines/agents/simple_agents.py     |  38 ++---
 habitat_baselines/agents/slam_agents.py       |  43 +++--
 habitat_baselines/train_ppo.py                |   5 +-
 notebooks/habitat-api-demo.ipynb              |  10 +-
 ...era_views_transform_and_warping_demo.ipynb |   2 +-
 test/test_demo_notebook.py                    |   6 +
 test/test_habitat_env.py                      |  41 +++--
 test/test_habitat_example.py                  |  10 ++
 test/test_pointnav_dataset.py                 |   4 +-
 test/test_relative_camera.py                  |   6 +
 test/test_sensors.py                          |  32 ++--
 test/test_trajectory_sim.py                   |   4 +-
 27 files changed, 507 insertions(+), 184 deletions(-)
 create mode 100644 examples/new_actions.py
 create mode 100644 habitat/sims/habitat_simulator/__init__.py
 create mode 100644 habitat/sims/habitat_simulator/action_spaces.py
 rename habitat/sims/{ => habitat_simulator}/habitat_simulator.py (91%)

diff --git a/examples/benchmark.py b/examples/benchmark.py
index 7a0da5bcd..5edfd6afc 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -7,7 +7,7 @@
 import argparse
 
 import habitat
-from habitat.sims.habitat_simulator import SimulatorActions
+from habitat import SimulatorActions
 
 
 class ForwardOnlyAgent(habitat.Agent):
@@ -15,7 +15,7 @@ class ForwardOnlyAgent(habitat.Agent):
         pass
 
     def act(self, observations):
-        action = SimulatorActions.MOVE_FORWARD.value
+        action = SimulatorActions.MOVE_FORWARD
         return action
 
 
diff --git a/examples/new_actions.py b/examples/new_actions.py
new file mode 100644
index 000000000..f03b1656b
--- /dev/null
+++ b/examples/new_actions.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+r"""
+This is an example of how to add new actions to habitat-api.
+
+
+We will use the strafe action outlined in the habitat_sim example.
+"""
+
+import attr
+import numpy as np
+
+import habitat
+import habitat_sim
+import habitat_sim.utils
+from habitat.sims.habitat_simulator.action_spaces import (
+    HabitatSimV1ActionSpaceConfiguration,
+)
+from habitat_sim.agent.controls import register_move_fn
+
+
+@attr.s(auto_attribs=True, slots=True)
+class NoisyStrafeActuationSpec:
+    move_amount: float
+    # Classic strafing is to move perpendicular (90 deg) to the forward direction
+    strafe_angle: float = 90.0
+    noise_amount: float = 0.05
+
+
+def _strafe_impl(
+    scene_node: habitat_sim.SceneNode,
+    move_amount: float,
+    strafe_angle: float,
+    noise_amount: float,
+):
+    forward_ax = (
+        scene_node.absolute_transformation()[0:3, 0:3] @ habitat_sim.geo.FRONT
+    )
+    strafe_angle = np.deg2rad(strafe_angle)
+    strafe_angle = np.random.uniform(
+        (1 - noise_amount) * strafe_angle, (1 + noise_amount) * strafe_angle
+    )
+
+    rotation = habitat_sim.utils.quat_from_angle_axis(
+        strafe_angle, habitat_sim.geo.UP
+    )
+    move_ax = habitat_sim.utils.quat_rotate_vector(rotation, forward_ax)
+
+    move_amount = np.random.uniform(
+        (1 - noise_amount) * move_amount, (1 + noise_amount) * move_amount
+    )
+    scene_node.translate_local(move_ax * move_amount)
+
+
+@register_move_fn(body_action=True)
+class NoisyStrafeLeft(habitat_sim.SceneNodeControl):
+    def __call__(
+        self,
+        scene_node: habitat_sim.SceneNode,
+        actuation_spec: NoisyStrafeActuationSpec,
+    ):
+        print(f"strafing left with noise_amount={actuation_spec.noise_amount}")
+        _strafe_impl(
+            scene_node,
+            actuation_spec.move_amount,
+            actuation_spec.strafe_angle,
+            actuation_spec.noise_amount,
+        )
+
+
+@register_move_fn(body_action=True)
+class NoisyStrafeRight(habitat_sim.SceneNodeControl):
+    def __call__(
+        self,
+        scene_node: habitat_sim.SceneNode,
+        actuation_spec: NoisyStrafeActuationSpec,
+    ):
+        print(
+            f"strafing right with noise_amount={actuation_spec.noise_amount}"
+        )
+        _strafe_impl(
+            scene_node,
+            actuation_spec.move_amount,
+            -actuation_spec.strafe_angle,
+            actuation_spec.noise_amount,
+        )
+
+
+@habitat.registry.register_action_space_configuration
+class NoNoiseStrafe(HabitatSimV1ActionSpaceConfiguration):
+    def get(self):
+        config = super().get()
+
+        config[habitat.SimulatorActions.STRAFE_LEFT] = habitat_sim.ActionSpec(
+            "noisy_strafe_left",
+            NoisyStrafeActuationSpec(0.25, noise_amount=0.0),
+        )
+        config[habitat.SimulatorActions.STRAFE_RIGHT] = habitat_sim.ActionSpec(
+            "noisy_strafe_right",
+            NoisyStrafeActuationSpec(0.25, noise_amount=0.0),
+        )
+
+        return config
+
+
+@habitat.registry.register_action_space_configuration
+class NoiseStrafe(HabitatSimV1ActionSpaceConfiguration):
+    def get(self):
+        config = super().get()
+
+        config[habitat.SimulatorActions.STRAFE_LEFT] = habitat_sim.ActionSpec(
+            "noisy_strafe_left",
+            NoisyStrafeActuationSpec(0.25, noise_amount=0.05),
+        )
+        config[habitat.SimulatorActions.STRAFE_RIGHT] = habitat_sim.ActionSpec(
+            "noisy_strafe_right",
+            NoisyStrafeActuationSpec(0.25, noise_amount=0.05),
+        )
+
+        return config
+
+
+def main():
+    habitat.SimulatorActions.extend_action_space("STRAFE_LEFT")
+    habitat.SimulatorActions.extend_action_space("STRAFE_RIGHT")
+
+    config = habitat.get_config(config_paths="configs/tasks/pointnav.yaml")
+    config.defrost()
+    config.SIMULATOR.ACTION_SPACE_CONFIG = "NoNoiseStrafe"
+    config.freeze()
+
+    env = habitat.Env(config=config)
+    env.reset()
+    env.step(habitat.SimulatorActions.STRAFE_LEFT)
+    env.step(habitat.SimulatorActions.STRAFE_RIGHT)
+    env.close()
+
+    config.defrost()
+    config.SIMULATOR.ACTION_SPACE_CONFIG = "NoiseStrafe"
+    config.freeze()
+
+    env = habitat.Env(config=config)
+    env.reset()
+    env.step(habitat.SimulatorActions.STRAFE_LEFT)
+    env.step(habitat.SimulatorActions.STRAFE_RIGHT)
+    env.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/register_new_sensors_and_measures.py b/examples/register_new_sensors_and_measures.py
index a857f1866..0d17ed5f2 100644
--- a/examples/register_new_sensors_and_measures.py
+++ b/examples/register_new_sensors_and_measures.py
@@ -98,11 +98,7 @@ def main():
     env = habitat.Env(config=config)
     print(env.reset()["agent_position"])
     print(env.get_metrics()["episode_info"])
-    print(
-        env.step(
-            habitat.sims.habitat_simulator.SimulatorActions.MOVE_FORWARD.value
-        )["agent_position"]
-    )
+    print(env.step(habitat.SimulatorActions.MOVE_FORWARD)["agent_position"])
     print(env.get_metrics()["episode_info"])
 
 
diff --git a/examples/shortest_path_follower_example.py b/examples/shortest_path_follower_example.py
index e78cdd363..db781de5d 100644
--- a/examples/shortest_path_follower_example.py
+++ b/examples/shortest_path_follower_example.py
@@ -83,7 +83,7 @@ def shortest_path_example(mode):
             best_action = follower.get_next_action(
                 env.habitat_env.current_episode.goals[0].position
             )
-            observations, reward, done, info = env.step(best_action.value)
+            observations, reward, done, info = env.step(best_action)
             im = observations["rgb"]
             top_down_map = draw_top_down_map(
                 info, observations["heading"], im.shape[0]
diff --git a/habitat/__init__.py b/habitat/__init__.py
index a37e4dc41..2b82de2c5 100644
--- a/habitat/__init__.py
+++ b/habitat/__init__.py
@@ -13,7 +13,13 @@ from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements
 from habitat.core.env import Env, RLEnv
 from habitat.core.logging import logger
 from habitat.core.registry import registry
-from habitat.core.simulator import Sensor, SensorSuite, SensorTypes, Simulator
+from habitat.core.simulator import (
+    Sensor,
+    SensorSuite,
+    SensorTypes,
+    Simulator,
+    SimulatorActions,
+)
 from habitat.core.vector_env import ThreadedVectorEnv, VectorEnv
 from habitat.datasets import make_dataset
 from habitat.version import VERSION as __version__  # noqa
diff --git a/habitat/config/default.py b/habitat/config/default.py
index cd95808cf..8f90919ae 100644
--- a/habitat/config/default.py
+++ b/habitat/config/default.py
@@ -81,6 +81,7 @@ _C.TASK.COLLISIONS.TYPE = "Collisions"
 # -----------------------------------------------------------------------------
 _C.SIMULATOR = CN()
 _C.SIMULATOR.TYPE = "Sim-v0"
+_C.SIMULATOR.ACTION_SPACE_CONFIG = "v0"
 _C.SIMULATOR.FORWARD_STEP_SIZE = 0.25  # in metres
 _C.SIMULATOR.SCENE = (
     "data/scene_datasets/habitat-test-scenes/" "van-gogh-room.glb"
diff --git a/habitat/core/registry.py b/habitat/core/registry.py
index b77ccb4af..bf224e534 100644
--- a/habitat/core/registry.py
+++ b/habitat/core/registry.py
@@ -26,8 +26,10 @@ Various decorators for registry different kind of classes with unique keys
 import collections
 from typing import Optional
 
+from habitat.core.utils import Singleton
 
-class _Registry:
+
+class Registry(metaclass=Singleton):
     mapping = collections.defaultdict(dict)
 
     @classmethod
@@ -160,6 +162,25 @@ class _Registry:
             "dataset", to_register, name, assert_type=Dataset
         )
 
+    @classmethod
+    def register_action_space_configuration(
+        cls, to_register=None, *, name: Optional[str] = None
+    ):
+        r"""Register a action space configuration to registry with key 'name'
+
+        Args:
+            name: Key with which the action space will be registered.
+                If None will use the name of the class
+        """
+        from habitat.core.simulator import ActionSpaceConfiguration
+
+        return cls._register_impl(
+            "action_space_config",
+            to_register,
+            name,
+            assert_type=ActionSpaceConfiguration,
+        )
+
     @classmethod
     def _get_impl(cls, _type, name):
         return cls.mapping[_type].get(name, None)
@@ -184,5 +205,9 @@ class _Registry:
     def get_dataset(cls, name):
         return cls._get_impl("dataset", name)
 
+    @classmethod
+    def get_action_space_configuration(cls, name):
+        return cls._get_impl("action_space_config", name)
+
 
-registry = _Registry()
+registry = Registry()
diff --git a/habitat/core/simulator.py b/habitat/core/simulator.py
index b199ba5f3..a9b670fd2 100644
--- a/habitat/core/simulator.py
+++ b/habitat/core/simulator.py
@@ -4,14 +4,101 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import abc
 from collections import OrderedDict
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
+import attr
 from gym import Space
 from gym.spaces.dict_space import Dict as SpaceDict
 
 from habitat.config import Config
+from habitat.core.utils import Singleton
+
+
+@attr.s(auto_attribs=True)
+class ActionSpaceConfiguration(abc.ABC):
+    config: Config
+
+    @abc.abstractmethod
+    def get(self):
+        pass
+
+
+class _DefaultSimulatorActions(Enum):
+    STOP = 0
+    MOVE_FORWARD = 1
+    TURN_LEFT = 2
+    TURN_RIGHT = 3
+    LOOK_UP = 4
+    LOOK_DOWN = 5
+
+
+@attr.s(auto_attribs=True, slots=True)
+class SimulatorActionsSingleton(metaclass=Singleton):
+    r"""Implements an extendable Enum for the mapping of action names
+    to their integer values.  
+
+    This means that new action names can be added, but old action names cannot be
+    removed nor can their mapping be altered.  This also ensures that all actions
+    are always contigously mapped in ``[0, len(SimulatorActions) - 1]``
+
+    This accesible as the global singleton SimulatorActions
+    """
+
+    _known_actions: Dict[str, int] = attr.ib(init=False, factory=dict)
+
+    def __attrs_post_init__(self):
+        for action in _DefaultSimulatorActions:
+            self._known_actions[action.name] = action.value
+
+    def extend_action_space(self, name: str) -> int:
+        r"""Extends the action space to accomidate a new action with
+        the name ``name``
+
+        Args
+            name (str): The name of the new action
+
+        Returns
+            int: The number the action is registered on
+
+        Usage::
+
+            from habitat import SimulatorActions
+            SimulatorActions.extend_action_space("MY_ACTION")
+            print(SimulatorActions.MY_ACTION)
+        """
+        assert (
+            name not in self._known_actions
+        ), "Cannot register an action name twice"
+        self._known_actions[name] = len(self._known_actions)
+
+        return self._known_actions[name]
+
+    def has_action(self, name: str) -> bool:
+        r"""Checks to see if action ``name`` is already register
+
+        Args
+            name (str): The name to check
+
+        Returns
+            bool: Whether or not ``name`` already exists
+        """
+
+        return name in self._known_actions
+
+    def __getattr__(self, name):
+        return self._known_actions[name]
+
+    def __getitem__(self, name):
+        return self._known_actions[name]
+
+    def __len__(self):
+        return len(self._known_actions)
+
+
+SimulatorActions = SimulatorActionsSingleton()
 
 
 class SensorTypes(Enum):
@@ -368,6 +455,13 @@ class Simulator:
         raise NotImplementedError
 
     @property
+    def index_stop_action(self):
+        return SimulatorActions.STOP
+
+    @property
+    def index_forward_action(self):
+        return SimulatorActions.MOVE_FORWARD
+
     def previous_step_collided(self):
         r"""Whether or not the previous step resulted in a collision
 
diff --git a/habitat/core/utils.py b/habitat/core/utils.py
index b5e338fc3..e2fa28b7c 100644
--- a/habitat/core/utils.py
+++ b/habitat/core/utils.py
@@ -45,3 +45,14 @@ def tile_images(images: List[np.ndarray]) -> np.ndarray:
 def not_none_validator(self, attribute, value):
     if value is None:
         raise ValueError(f"Argument '{attribute.name}' must be set")
+
+
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(
+                *args, **kwargs
+            )
+        return cls._instances[cls]
diff --git a/habitat/datasets/utils.py b/habitat/datasets/utils.py
index 67395525c..9684d5371 100644
--- a/habitat/datasets/utils.py
+++ b/habitat/datasets/utils.py
@@ -7,8 +7,7 @@
 from typing import List
 
 from habitat.core.logging import logger
-from habitat.core.simulator import ShortestPathPoint
-from habitat.sims.habitat_simulator import SimulatorActions
+from habitat.core.simulator import ShortestPathPoint, SimulatorActions
 from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower
 from habitat.utils.geometry_utils import quaternion_to_list
 
@@ -30,17 +29,17 @@ def get_action_shortest_path(
     shortest_path = []
     action = None
     step_count = 0
-    while action != SimulatorActions.STOP and step_count < max_episode_steps:
+    while action != sim.index_stop_action and step_count < max_episode_steps:
         action = follower.get_next_action(goal_position)
         state = sim.get_agent_state()
         shortest_path.append(
             ShortestPathPoint(
                 state.position.tolist(),
                 quaternion_to_list(state.rotation),
-                action.value,
+                action,
             )
         )
-        sim.step(action.value)
+        sim.step(action)
         step_count += 1
     if step_count == max_episode_steps:
         logger.warning("Shortest path wasn't found.")
diff --git a/habitat/sims/habitat_simulator/__init__.py b/habitat/sims/habitat_simulator/__init__.py
new file mode 100644
index 000000000..36c03883f
--- /dev/null
+++ b/habitat/sims/habitat_simulator/__init__.py
@@ -0,0 +1,24 @@
+from habitat.core.registry import registry
+from habitat.core.simulator import Simulator
+
+
+def _try_register_habitat_sim():
+    try:
+        import habitat_sim
+
+        has_habitat_sim = True
+    except ImportError as e:
+        has_habitat_sim = False
+        habitat_sim_import_error = e
+
+    if has_habitat_sim:
+        from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim
+        from habitat.sims.habitat_simulator.action_spaces import (
+            HabitatSimV1ActionSpaceConfiguration,
+        )
+    else:
+
+        @registry.register_simulator(name="Sim-v0")
+        class HabitatSimImportError(Simulator):
+            def __init__(self, *args, **kwargs):
+                raise habitat_sim_import_error
diff --git a/habitat/sims/habitat_simulator/action_spaces.py b/habitat/sims/habitat_simulator/action_spaces.py
new file mode 100644
index 000000000..4cee1cbeb
--- /dev/null
+++ b/habitat/sims/habitat_simulator/action_spaces.py
@@ -0,0 +1,55 @@
+from enum import Enum
+
+import attr
+
+import habitat_sim
+from habitat.core.registry import registry
+from habitat.core.simulator import (
+    ActionSpaceConfiguration,
+    Config,
+    SimulatorActions,
+)
+
+
+@registry.register_action_space_configuration(name="v0")
+class HabitatSimV0ActionSpaceConfiguration(ActionSpaceConfiguration):
+    def get(self):
+        return {
+            SimulatorActions.STOP: habitat_sim.ActionSpec("stop"),
+            SimulatorActions.MOVE_FORWARD: habitat_sim.ActionSpec(
+                "move_forward",
+                habitat_sim.ActuationSpec(
+                    amount=self.config.FORWARD_STEP_SIZE
+                ),
+            ),
+            SimulatorActions.TURN_LEFT: habitat_sim.ActionSpec(
+                "turn_left",
+                habitat_sim.ActuationSpec(amount=self.config.TURN_ANGLE),
+            ),
+            SimulatorActions.TURN_RIGHT: habitat_sim.ActionSpec(
+                "turn_right",
+                habitat_sim.ActuationSpec(amount=self.config.TURN_ANGLE),
+            ),
+        }
+
+
+@registry.register_action_space_configuration(name="v1")
+class HabitatSimV1ActionSpaceConfiguration(
+    HabitatSimV0ActionSpaceConfiguration
+):
+    def get(self):
+        config = super().get()
+        new_config = {
+            SimulatorActions.LOOK_UP: habitat_sim.ActionSpec(
+                "look_up",
+                habitat_sim.ActuationSpec(amount=self.config.TILT_ANGLE),
+            ),
+            SimulatorActions.LOOK_DOWN: habitat_sim.ActionSpec(
+                "look_down",
+                habitat_sim.ActuationSpec(amount=self.config.TILT_ANGLE),
+            ),
+        }
+
+        config.update(new_config)
+
+        return config
diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator/habitat_simulator.py
similarity index 91%
rename from habitat/sims/habitat_simulator.py
rename to habitat/sims/habitat_simulator/habitat_simulator.py
index 32da8af5b..b5b896b34 100644
--- a/habitat/sims/habitat_simulator.py
+++ b/habitat/sims/habitat_simulator/habitat_simulator.py
@@ -24,6 +24,7 @@ from habitat.core.simulator import (
     SensorSuite,
     ShortestPathPoint,
     Simulator,
+    SimulatorActions,
 )
 
 RGBSENSOR_DIMENSION = 3
@@ -42,15 +43,6 @@ def check_sim_obs(obs, sensor):
     )
 
 
-class SimulatorActions(Enum):
-    STOP = 0
-    MOVE_FORWARD = 1
-    TURN_LEFT = 2
-    TURN_RIGHT = 3
-    LOOK_UP = 4
-    LOOK_DOWN = 5
-
-
 @registry.register_sensor
 class HabitatSimRGBSensor(RGBSensor):
     sim_sensor_type: habitat_sim.SensorType
@@ -205,31 +197,9 @@ class HabitatSim(Simulator):
             sensor_specifications.append(sim_sensor_cfg)
 
         agent_config.sensor_specifications = sensor_specifications
-        agent_config.action_space = {
-            SimulatorActions.STOP.value: habitat_sim.ActionSpec("stop"),
-            SimulatorActions.MOVE_FORWARD.value: habitat_sim.ActionSpec(
-                "move_forward",
-                habitat_sim.ActuationSpec(
-                    amount=self.config.FORWARD_STEP_SIZE
-                ),
-            ),
-            SimulatorActions.TURN_LEFT.value: habitat_sim.ActionSpec(
-                "turn_left",
-                habitat_sim.ActuationSpec(amount=self.config.TURN_ANGLE),
-            ),
-            SimulatorActions.TURN_RIGHT.value: habitat_sim.ActionSpec(
-                "turn_right",
-                habitat_sim.ActuationSpec(amount=self.config.TURN_ANGLE),
-            ),
-            SimulatorActions.LOOK_UP.value: habitat_sim.ActionSpec(
-                "look_up",
-                habitat_sim.ActuationSpec(amount=self.config.TILT_ANGLE),
-            ),
-            SimulatorActions.LOOK_DOWN.value: habitat_sim.ActionSpec(
-                "look_down",
-                habitat_sim.ActuationSpec(amount=self.config.TILT_ANGLE),
-            ),
-        }
+        agent_config.action_space = registry.get_action_space_configuration(
+            self.config.ACTION_SPACE_CONFIG
+        )(self.config).get()
 
         return habitat_sim.Configuration(sim_config, [agent_config])
 
@@ -256,6 +226,7 @@ class HabitatSim(Simulator):
                     agent_id,
                 )
                 is_updated = True
+
         return is_updated
 
     def reset(self):
@@ -273,7 +244,7 @@ class HabitatSim(Simulator):
             "STOP action called previously"
         )
 
-        if action == SimulatorActions.STOP.value:
+        if action == self.index_stop_action:
             self._is_episode_active = False
             sim_obs = self._sim.get_sensor_observations()
         else:
@@ -398,14 +369,6 @@ class HabitatSim(Simulator):
     def close(self):
         self._sim.close()
 
-    @property
-    def index_stop_action(self):
-        return SimulatorActions.STOP.value
-
-    @property
-    def index_forward_action(self):
-        return SimulatorActions.MOVE_FORWARD.value
-
     def _get_agent_config(self, agent_id: Optional[int] = None) -> Any:
         if agent_id is None:
             agent_id = self.config.DEFAULT_AGENT_ID
diff --git a/habitat/sims/registration.py b/habitat/sims/registration.py
index 0f59808d5..d756906d5 100644
--- a/habitat/sims/registration.py
+++ b/habitat/sims/registration.py
@@ -7,25 +7,7 @@
 from habitat.core.logging import logger
 from habitat.core.registry import registry
 from habitat.core.simulator import Simulator
-
-
-def _try_register_habitat_sim():
-    try:
-        import habitat_sim
-
-        has_habitat_sim = True
-    except ImportError as e:
-        has_habitat_sim = False
-        habitat_sim_import_error = e
-
-    if has_habitat_sim:
-        from habitat.sims.habitat_simulator import HabitatSim
-    else:
-
-        @registry.register_simulator(name="Sim-v0")
-        class HabitatSimImportError(Simulator):
-            def __init__(self, *args, **kwargs):
-                raise habitat_sim_import_error
+from habitat.sims.habitat_simulator import _try_register_habitat_sim
 
 
 def make_sim(id_sim, **kwargs):
diff --git a/habitat/tasks/nav/shortest_path_follower.py b/habitat/tasks/nav/shortest_path_follower.py
index a358cab5b..2d7b2f184 100644
--- a/habitat/tasks/nav/shortest_path_follower.py
+++ b/habitat/tasks/nav/shortest_path_follower.py
@@ -9,7 +9,8 @@ from typing import Union
 import numpy as np
 
 import habitat_sim
-from habitat.sims.habitat_simulator import HabitatSim, SimulatorActions
+from habitat.core.simulator import SimulatorActions
+from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim
 from habitat.utils.geometry_utils import (
     angle_between_quaternions,
     quaternion_from_two_vectors,
@@ -58,18 +59,14 @@ class ShortestPathFollower:
         )
         self._return_one_hot = return_one_hot
 
-    def _get_return_value(
-        self, action: SimulatorActions
-    ) -> Union[SimulatorActions, np.array]:
+    def _get_return_value(self, action) -> Union[int, np.array]:
         if self._return_one_hot:
-            return action_to_one_hot(action.value)
+            return action_to_one_hot(action)
         else:
             return action
 
-    def get_next_action(
-        self, goal_pos: np.array
-    ) -> Union[SimulatorActions, np.array]:
-        r"""Returns the next action along the shortest path.
+    def get_next_action(self, goal_pos: np.array) -> Union[int, np.array]:
+        """Returns the next action along the shortest path.
         """
         if (
             np.linalg.norm(goal_pos - self._sim.get_agent_state().position)
@@ -84,13 +81,13 @@ class ShortestPathFollower:
 
     def _step_along_grad(
         self, grad_dir: np.quaternion
-    ) -> Union[SimulatorActions, np.array]:
+    ) -> Union[int, np.array]:
         current_state = self._sim.get_agent_state()
         alpha = angle_between_quaternions(grad_dir, current_state.rotation)
         if alpha <= np.deg2rad(self._sim.config.TURN_ANGLE) + EPSILON:
             return self._get_return_value(SimulatorActions.MOVE_FORWARD)
         else:
-            sim_action = SimulatorActions.TURN_LEFT.value
+            sim_action = SimulatorActions.TURN_LEFT
             self._sim.step(sim_action)
             best_turn = (
                 SimulatorActions.TURN_LEFT
@@ -144,7 +141,7 @@ class ShortestPathFollower:
             best_geodesic_delta = -2 * self._max_delta
             best_rotation = current_rotation
             for _ in range(0, 360, self._sim.config.TURN_ANGLE):
-                sim_action = SimulatorActions.MOVE_FORWARD.value
+                sim_action = SimulatorActions.MOVE_FORWARD
                 self._sim.step(sim_action)
                 new_delta = current_dist - self._geo_dist(goal_pos)
 
@@ -168,7 +165,7 @@ class ShortestPathFollower:
                     reset_sensors=False,
                 )
 
-                sim_action = SimulatorActions.TURN_LEFT.value
+                sim_action = SimulatorActions.TURN_LEFT
                 self._sim.step(sim_action)
 
             self._reset_agent_state(current_state)
diff --git a/habitat_baselines/agents/simple_agents.py b/habitat_baselines/agents/simple_agents.py
index 733b82e4c..83568ab7c 100644
--- a/habitat_baselines/agents/simple_agents.py
+++ b/habitat_baselines/agents/simple_agents.py
@@ -11,13 +11,8 @@ from math import pi
 import numpy as np
 
 import habitat
-from habitat.config import Config
+from habitat import SimulatorActions
 from habitat.config.default import get_config
-from habitat.sims.habitat_simulator import SimulatorActions
-
-NON_STOP_ACTIONS = [
-    v for v in range(len(SimulatorActions)) if v != SimulatorActions.STOP.value
-]
 
 
 class RandomAgent(habitat.Agent):
@@ -34,18 +29,24 @@ class RandomAgent(habitat.Agent):
 
     def act(self, observations):
         if self.is_goal_reached(observations):
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
         else:
-            action = np.random.choice(NON_STOP_ACTIONS)
+            action = np.random.choice(
+                [
+                    SimulatorActions.MOVE_FORWARD,
+                    SimulatorActions.TURN_LEFT,
+                    SimulatorActions.TURN_RIGHT,
+                ]
+            )
         return action
 
 
 class ForwardOnlyAgent(RandomAgent):
     def act(self, observations):
         if self.is_goal_reached(observations):
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
         else:
-            action = SimulatorActions.MOVE_FORWARD.value
+            action = SimulatorActions.MOVE_FORWARD
         return action
 
 
@@ -56,16 +57,13 @@ class RandomForwardAgent(RandomAgent):
 
     def act(self, observations):
         if self.is_goal_reached(observations):
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
         else:
             if np.random.uniform(0, 1, 1) < self.FORWARD_PROBABILITY:
-                action = SimulatorActions.MOVE_FORWARD.value
+                action = SimulatorActions.MOVE_FORWARD
             else:
                 action = np.random.choice(
-                    [
-                        SimulatorActions.TURN_LEFT.value,
-                        SimulatorActions.TURN_RIGHT.value,
-                    ]
+                    [SimulatorActions.TURN_LEFT, SimulatorActions.TURN_RIGHT]
                 )
 
         return action
@@ -89,20 +87,20 @@ class GoalFollower(RandomAgent):
         if angle_to_goal > pi or (
             (angle_to_goal < 0) and (angle_to_goal > -pi)
         ):
-            action = SimulatorActions.TURN_RIGHT.value
+            action = SimulatorActions.TURN_RIGHT
         else:
-            action = SimulatorActions.TURN_LEFT.value
+            action = SimulatorActions.TURN_LEFT
         return action
 
     def act(self, observations):
         if self.is_goal_reached(observations):
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
         else:
             angle_to_goal = self.normalize_angle(
                 np.array(observations[self.goal_sensor_uuid][1])
             )
             if abs(angle_to_goal) < self.angle_th:
-                action = SimulatorActions.MOVE_FORWARD.value
+                action = SimulatorActions.MOVE_FORWARD
             else:
                 action = self.turn_towards_goal(angle_to_goal)
 
diff --git a/habitat_baselines/agents/slam_agents.py b/habitat_baselines/agents/slam_agents.py
index 846584062..9d7b291a5 100644
--- a/habitat_baselines/agents/slam_agents.py
+++ b/habitat_baselines/agents/slam_agents.py
@@ -19,8 +19,8 @@ import torch.nn.functional as F
 
 import habitat
 import orbslam2
+from habitat import SimulatorActions
 from habitat.config.default import get_config
-from habitat.sims.habitat_simulator import SimulatorActions
 from habitat_baselines.config.default import get_config as cfg_baseline
 from habitat_baselines.slambased.mappers import DirectDepthMapper
 from habitat_baselines.slambased.monodepth import MonoDepthEstimator
@@ -115,7 +115,7 @@ class RandomAgent(object):
         # Act
         # Check if we are done
         if self.is_goal_reached():
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
         else:
             action = random.randint(0, self.num_actions - 1)
         return action
@@ -132,20 +132,20 @@ class BlindAgent(RandomAgent):
     def decide_what_to_do(self):
         distance_to_goal = self.obs["pointgoal"][0]
         angle_to_goal = norm_ang(np.array(self.obs["pointgoal"][1]))
-        command = SimulatorActions.STOP.value
+        command = SimulatorActions.STOP
         if distance_to_goal <= self.pos_th:
             return command
         if abs(angle_to_goal) < self.angle_th:
-            command = SimulatorActions.MOVE_FORWARD.value
+            command = SimulatorActions.MOVE_FORWARD
         else:
             if (angle_to_goal > 0) and (angle_to_goal < pi):
-                command = SimulatorActions.TURN_LEFT.value
+                command = SimulatorActions.TURN_LEFT
             elif angle_to_goal > pi:
-                command = SimulatorActions.TURN_RIGHT.value
+                command = SimulatorActions.TURN_RIGHT
             elif (angle_to_goal < 0) and (angle_to_goal > -pi):
-                command = SimulatorActions.TURN_RIGHT.value
+                command = SimulatorActions.TURN_RIGHT
             else:
-                command = SimulatorActions.TURN_LEFT.value
+                command = SimulatorActions.TURN_LEFT
 
         return command
 
@@ -153,7 +153,7 @@ class BlindAgent(RandomAgent):
         self.update_internal_state(habitat_observation)
         # Act
         if self.is_goal_reached():
-            return SimulatorActions.STOP.value
+            return SimulatorActions.STOP
         command = self.decide_what_to_do()
         random_action = random.randint(0, self.num_actions - 1)
         act_randomly = np.random.uniform(0, 1, 1) < random_prob
@@ -266,10 +266,7 @@ class ORBSLAM2Agent(RandomAgent):
                     .view(4, 4)
                     .to(self.device),
                 )
-                if (
-                    self.action_history[-1]
-                    == SimulatorActions.MOVE_FORWARD.value
-                ):
+                if self.action_history[-1] == SimulatorActions.MOVE_FORWARD:
                     self.unseen_obstacle = (
                         previous_step.item() <= 0.001
                     )  # hardcoded threshold for not moving
@@ -333,7 +330,7 @@ class ORBSLAM2Agent(RandomAgent):
         )
         success = self.is_goal_reached()
         if success:
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
             self.action_history.append(action)
             return action
         # Plan action
@@ -485,7 +482,7 @@ class ORBSLAM2Agent(RandomAgent):
         return path, planned_waypoints
 
     def planner_prediction_to_command(self, p_next):
-        command = SimulatorActions.STOP.value
+        command = SimulatorActions.STOP
         p_init = self.pose6D.squeeze()
         d_angle_rot_th = self.angle_th
         pos_th = self.pos_th
@@ -495,27 +492,27 @@ class ORBSLAM2Agent(RandomAgent):
             get_direction(p_init, p_next, ang_th=d_angle_rot_th, pos_th=pos_th)
         )
         if abs(d_angle) < d_angle_rot_th:
-            command = SimulatorActions.MOVE_FORWARD.value
+            command = SimulatorActions.MOVE_FORWARD
         else:
             if (d_angle > 0) and (d_angle < pi):
-                command = SimulatorActions.TURN_LEFT.value
+                command = SimulatorActions.TURN_LEFT
             elif d_angle > pi:
-                command = SimulatorActions.TURN_RIGHT.value
+                command = SimulatorActions.TURN_RIGHT
             elif (d_angle < 0) and (d_angle > -pi):
-                command = SimulatorActions.TURN_RIGHT.value
+                command = SimulatorActions.TURN_RIGHT
             else:
-                command = SimulatorActions.TURN_LEFT.value
+                command = SimulatorActions.TURN_LEFT
         return command
 
     def decide_what_to_do(self):
         action = None
         if self.is_goal_reached():
-            action = SimulatorActions.STOP.value
+            action = SimulatorActions.STOP
             return action
         if self.unseen_obstacle:
-            command = SimulatorActions.TURN_RIGHT.value
+            command = SimulatorActions.TURN_RIGHT
             return command
-        command = SimulatorActions.STOP.value
+        command = SimulatorActions.STOP
         command = self.planner_prediction_to_command(self.waypointPose6D)
         return command
 
diff --git a/habitat_baselines/train_ppo.py b/habitat_baselines/train_ppo.py
index 97efa74a6..c57d79a7f 100644
--- a/habitat_baselines/train_ppo.py
+++ b/habitat_baselines/train_ppo.py
@@ -14,10 +14,9 @@ import torch
 
 import habitat
 from config.default import get_config as cfg_baseline
-from habitat import logger
+from habitat import SimulatorActions, logger
 from habitat.config.default import get_config as cfg_env
 from habitat.datasets.registration import make_dataset
-from habitat.sims.habitat_simulator import SimulatorActions
 from rl.ppo import PPO, Policy, RolloutStorage
 from rl.ppo.utils import batch_obs, ppo_args, update_linear_schedule
 
@@ -73,7 +72,7 @@ class NavRLEnv(habitat.RLEnv):
 
     def _episode_success(self):
         if (
-            self._previous_action == SimulatorActions.STOP.value
+            self._previous_action == SimulatorActions.STOP
             and self._distance_target() < self._config_env.SUCCESS_DISTANCE
         ):
             return True
diff --git a/notebooks/habitat-api-demo.ipynb b/notebooks/habitat-api-demo.ipynb
index d41b7f1d9..bd9ed01c9 100644
--- a/notebooks/habitat-api-demo.ipynb
+++ b/notebooks/habitat-api-demo.ipynb
@@ -39,7 +39,7 @@
     "\n",
     "config = habitat.get_config(config_paths='../configs/tasks/pointnav_mp3d.yaml')\n",
     "config.defrost()\n",
-    "config.DATASET.POINTNAVV1.DATA_PATH = '../data/datasets/pointnav/mp3d/v1/val/val.json.gz'\n",
+    "config.DATASET.DATA_PATH = '../data/datasets/pointnav/mp3d/v1/val/val.json.gz'\n",
     "config.DATASET.SCENES_DIR = '../data/scene_datasets/'\n",
     "config.freeze()\n",
     "\n",
@@ -236,7 +236,7 @@
    "source": [
     "config = habitat.get_config(config_paths='../configs/tasks/pointnav_mp3d.yaml')\n",
     "config.defrost()\n",
-    "config.DATASET.POINTNAVV1.DATA_PATH = '../data/datasets/pointnav/mp3d/v1/val/val.json.gz'\n",
+    "config.DATASET.DATA_PATH = '../data/datasets/pointnav/mp3d/v1/val/val.json.gz'\n",
     "config.DATASET.SCENES_DIR = '../data/scene_datasets/'\n",
     "config.SIMULATOR.AGENT_0.SENSORS = ['RGB_SENSOR', 'DEPTH_SENSOR', 'SEMANTIC_SENSOR']\n",
     "config.SIMULATOR.SEMANTIC_SENSOR.WIDTH = 256\n",
@@ -283,9 +283,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "notebooks",
+   "display_name": "Python 3",
    "language": "python",
-   "name": "notebooks"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -297,7 +297,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.8"
+   "version": "3.7.3"
   }
  },
  "nbformat": 4,
diff --git a/notebooks/relative_camera_views_transform_and_warping_demo.ipynb b/notebooks/relative_camera_views_transform_and_warping_demo.ipynb
index 57d8e07bb..769d1bd34 100644
--- a/notebooks/relative_camera_views_transform_and_warping_demo.ipynb
+++ b/notebooks/relative_camera_views_transform_and_warping_demo.ipynb
@@ -46,7 +46,7 @@
     "# Set up the environment for testing\n",
     "config = habitat.get_config(config_paths=\"../configs/tasks/pointnav_rgbd.yaml\")\n",
     "config.defrost()\n",
-    "config.DATASET.POINTNAVV1.DATA_PATH = '../data/datasets/pointnav/habitat-test-scenes/v1/val/val.json.gz'\n",
+    "config.DATASET.DATA_PATH = '../data/datasets/pointnav/habitat-test-scenes/v1/val/val.json.gz'\n",
     "config.DATASET.SCENES_DIR = '../data/scene_datasets/'\n",
     "config.freeze()\n",
     "\n",
diff --git a/test/test_demo_notebook.py b/test/test_demo_notebook.py
index ca4e5ad23..fd9168df9 100644
--- a/test/test_demo_notebook.py
+++ b/test/test_demo_notebook.py
@@ -3,6 +3,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import gc
 
 import pytest
 
@@ -21,3 +22,8 @@ def test_demo_notebook():
         )
     else:
         pytest.main(["--nbval-lax", "notebooks/habitat-api-demo.ipynb"])
+
+        # NB: Force a gc collect run, as cleanup can take a little
+        # while after the notebook finishes and we otherwise get
+        # a double context crash!
+        gc.collect()
diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py
index 7e8b08e77..61851e54f 100644
--- a/test/test_habitat_env.py
+++ b/test/test_habitat_env.py
@@ -12,9 +12,8 @@ import pytest
 
 import habitat
 from habitat.config.default import get_config
-from habitat.core.simulator import AgentState
+from habitat.core.simulator import AgentState, SimulatorActions
 from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1
-from habitat.sims.habitat_simulator import SimulatorActions
 from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal
 
 CFG_TEST = "configs/test/habitat_all_sensors_test.yaml"
@@ -81,9 +80,9 @@ def _vec_env_test_fn(configs, datasets, multiprocessing_start_method):
     )
     envs.reset()
     non_stop_actions = [
-        v
-        for v in range(len(SimulatorActions))
-        if v != SimulatorActions.STOP.value
+        act
+        for act in range(envs.action_spaces[0].n)
+        if act != SimulatorActions.STOP
     ]
 
     for _ in range(2 * configs[0].ENVIRONMENT.MAX_EPISODE_STEPS):
@@ -135,9 +134,9 @@ def test_threaded_vectorized_env():
     envs = habitat.ThreadedVectorEnv(env_fn_args=env_fn_args)
     envs.reset()
     non_stop_actions = [
-        v
-        for v in range(len(SimulatorActions))
-        if v != SimulatorActions.STOP.value
+        act
+        for act in range(envs.action_spaces[0].n)
+        if act != SimulatorActions.STOP
     ]
 
     for i in range(2 * configs[0].ENVIRONMENT.MAX_EPISODE_STEPS):
@@ -167,9 +166,9 @@ def test_env():
     env.reset()
 
     non_stop_actions = [
-        v
-        for v in range(len(SimulatorActions))
-        if v != SimulatorActions.STOP.value
+        act
+        for act in range(env.action_space.n)
+        if act != SimulatorActions.STOP
     ]
     for _ in range(config.ENVIRONMENT.MAX_EPISODE_STEPS):
         act = np.random.choice(non_stop_actions)
@@ -182,7 +181,7 @@ def test_env():
 
     env.reset()
 
-    env.step(SimulatorActions.STOP.value)
+    env.step(SimulatorActions.STOP)
     # check for STOP action
     assert env.episode_over is True, (
         "episode should be over after STOP " "action"
@@ -211,9 +210,9 @@ def test_rl_vectorized_envs():
     envs = habitat.VectorEnv(make_env_fn=make_rl_env, env_fn_args=env_fn_args)
     envs.reset()
     non_stop_actions = [
-        v
-        for v in range(len(SimulatorActions))
-        if v != SimulatorActions.STOP.value
+        act
+        for act in range(envs.action_spaces[0].n)
+        if act != SimulatorActions.STOP
     ]
 
     for i in range(2 * configs[0].ENVIRONMENT.MAX_EPISODE_STEPS):
@@ -263,9 +262,9 @@ def test_rl_env():
     observation = env.reset()
 
     non_stop_actions = [
-        v
-        for v in range(len(SimulatorActions))
-        if v != SimulatorActions.STOP.value
+        act
+        for act in range(env.action_space.n)
+        if act != SimulatorActions.STOP
     ]
     for _ in range(config.ENVIRONMENT.MAX_EPISODE_STEPS):
         observation, reward, done, info = env.step(
@@ -276,7 +275,7 @@ def test_rl_env():
     assert done is True, "episodes should be over after max_episode_steps"
 
     env.reset()
-    observation, reward, done, info = env.step(SimulatorActions.STOP.value)
+    observation, reward, done, info = env.step(SimulatorActions.STOP)
     assert done is True, "done should be true after STOP action"
 
     env.close()
@@ -381,10 +380,10 @@ def test_action_space_shortest_path():
             unreachable_targets.append(AgentState(position, rotation))
 
     targets = reachable_targets
-    shortest_path1 = env.sim.action_space_shortest_path(source, targets)
+    shortest_path1 = env.action_space_shortest_path(source, targets)
     assert shortest_path1 != []
 
     targets = unreachable_targets
-    shortest_path2 = env.sim.action_space_shortest_path(source, targets)
+    shortest_path2 = env.action_space_shortest_path(source, targets)
     assert shortest_path2 == []
     env.close()
diff --git a/test/test_habitat_example.py b/test/test_habitat_example.py
index d4884c122..493c2412f 100644
--- a/test/test_habitat_example.py
+++ b/test/test_habitat_example.py
@@ -8,6 +8,7 @@ import pytest
 
 import habitat
 from examples import (
+    new_actions,
     register_new_sensors_and_measures,
     shortest_path_follower_example,
     visualization_examples,
@@ -47,3 +48,12 @@ def test_register_new_sensors_and_measures():
         pytest.skip("Please download Habitat test data to data folder.")
 
     register_new_sensors_and_measures.main()
+
+
+def test_new_actions():
+    if not PointNavDatasetV1.check_config_paths_exist(
+        config=habitat.get_config().DATASET
+    ):
+        pytest.skip("Please download Habitat test data to data folder.")
+
+    new_actions.main()
diff --git a/test/test_pointnav_dataset.py b/test/test_pointnav_dataset.py
index 4b3ce98bf..7a3954c64 100644
--- a/test/test_pointnav_dataset.py
+++ b/test/test_pointnav_dataset.py
@@ -70,7 +70,7 @@ def test_multiple_files_scene_path():
         len(scenes) > 0
     ), "Expected dataset contains separate episode file per scene."
     dataset_config.defrost()
-    dataset_config.POINTNAVV1.CONTENT_SCENES = scenes[:PARTIAL_LOAD_SCENES]
+    dataset_config.CONTENT_SCENES = scenes[:PARTIAL_LOAD_SCENES]
     dataset_config.SCENES_DIR = os.path.join(
         os.getcwd(), DEFAULT_SCENE_PATH_PREFIX
     )
@@ -98,7 +98,7 @@ def test_multiple_files_pointnav_dataset():
         len(scenes) > 0
     ), "Expected dataset contains separate episode file per scene."
     dataset_config.defrost()
-    dataset_config.POINTNAVV1.CONTENT_SCENES = scenes[:PARTIAL_LOAD_SCENES]
+    dataset_config.CONTENT_SCENES = scenes[:PARTIAL_LOAD_SCENES]
     dataset_config.freeze()
     partial_dataset = make_dataset(
         id_dataset=dataset_config.TYPE, config=dataset_config
diff --git a/test/test_relative_camera.py b/test/test_relative_camera.py
index fd87dff85..0911861d4 100644
--- a/test/test_relative_camera.py
+++ b/test/test_relative_camera.py
@@ -3,6 +3,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import gc
 
 import pytest
 
@@ -32,3 +33,8 @@ def test_demo_notebook():
                 "notebooks/relative_camera_views_transform_and_warping_demo.ipynb",
             ]
         )
+
+        # NB: Force a gc collect run, as cleanup can take a little
+        # while after the notebook finishes and we otherwise get
+        # a double context crash!
+        gc.collect()
diff --git a/test/test_sensors.py b/test/test_sensors.py
index c99d4e826..5e34e6b5f 100644
--- a/test/test_sensors.py
+++ b/test/test_sensors.py
@@ -12,19 +12,9 @@ import pytest
 
 import habitat
 from habitat.config.default import get_config
-from habitat.sims.habitat_simulator import SimulatorActions
+from habitat.core.simulator import SimulatorActions
 from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal
 
-NON_STOP_ACTIONS = [
-    v for v in range(len(SimulatorActions)) if v != SimulatorActions.STOP.value
-]
-
-MOVEMENT_ACTIONS = [
-    SimulatorActions.MOVE_FORWARD.value,
-    SimulatorActions.TURN_LEFT.value,
-    SimulatorActions.TURN_RIGHT.value,
-]
-
 
 def _random_episode(env, config):
     random_location = env._sim.sample_navigable_point()
@@ -121,9 +111,9 @@ def test_collisions():
     np.random.seed(123)
 
     actions = [
-        SimulatorActions.MOVE_FORWARD.value,
-        SimulatorActions.TURN_LEFT.value,
-        SimulatorActions.TURN_RIGHT.value,
+        SimulatorActions.MOVE_FORWARD,
+        SimulatorActions.TURN_LEFT,
+        SimulatorActions.TURN_RIGHT,
     ]
 
     for _ in range(20):
@@ -187,9 +177,14 @@ def test_static_pointgoal_sensor():
         )
     ]
 
+    non_stop_actions = [
+        act
+        for act in range(env.action_space.n)
+        if act != SimulatorActions.STOP
+    ]
     env.reset()
     for _ in range(100):
-        obs = env.step(np.random.choice(NON_STOP_ACTIONS))
+        obs = env.step(np.random.choice(non_stop_actions))
         static_pointgoal = obs["static_pointgoal"]
         # check to see if taking non-stop actions will affect static point_goal
         assert np.allclose(static_pointgoal, expected_static_pointgoal)
@@ -225,13 +220,18 @@ def test_get_observations_at():
             goals=[NavigationGoal(position=goal_position)],
         )
     ]
+    non_stop_actions = [
+        act
+        for act in range(env.action_space.n)
+        if act != SimulatorActions.STOP
+    ]
 
     obs = env.reset()
     start_state = env.sim.get_agent_state()
     for _ in range(100):
         # Note, this test will not currently work for camera change actions
         # (look up/down), only for movement actions.
-        new_obs = env.step(np.random.choice(MOVEMENT_ACTIONS))
+        new_obs = env.step(np.random.choice(non_stop_actions))
         for key, val in new_obs.items():
             agent_state = env.sim.get_agent_state()
             if not (
diff --git a/test/test_trajectory_sim.py b/test/test_trajectory_sim.py
index 8d2302cba..285b8e12e 100644
--- a/test/test_trajectory_sim.py
+++ b/test/test_trajectory_sim.py
@@ -10,9 +10,9 @@ import os
 import numpy as np
 import pytest
 
+from habitat import SimulatorActions
 from habitat.config.default import get_config
 from habitat.sims import make_sim
-from habitat.sims.habitat_simulator import SimulatorActions
 
 
 def init_sim():
@@ -34,7 +34,7 @@ def test_sim_trajectory():
     )
 
     for i, action in enumerate(test_trajectory["actions"]):
-        action = SimulatorActions[action].value
+        action = SimulatorActions[action]
         if i > 0:  # ignore first step as habitat-sim doesn't update
             # agent until then
             state = sim.get_agent_state()
-- 
GitLab