From 8da461f537d950fe450aa14d0070d0816c789605 Mon Sep 17 00:00:00 2001
From: Facebook Community Bot <facebook-github-bot@users.noreply.github.com>
Date: Wed, 3 Apr 2019 04:25:15 -0700
Subject: [PATCH] Benchmark and challenge implementations; PPO agents, Random,
 ForwardOnly, GoalFollower agents; High resolution top down map and bird-eye
 visualizations; Collisions measure introduced; SPL success criteria fix;
 Added deleter that closes subprocesses

---
 Dockerfile                                    |   5 +-
 baselines/agents/__init__.py                  |   0
 baselines/agents/ppo_agents.py                | 140 ++++++++
 baselines/agents/simple_agents.py             | 155 +++++++++
 baselines/rl/ppo/__init__.py                  |   6 +-
 baselines/rl/ppo/policy.py                    | 128 ++++---
 configs/test/habitat_all_sensors_test.yaml    |  17 +
 examples/visualization_examples.py            | 107 ++++++
 habitat/__init__.py                           |   2 +
 habitat/config/default.py                     |   5 +
 habitat/core/benchmark.py                     |  11 +-
 habitat/core/challenge.py                     |  21 ++
 habitat/core/env.py                           |  12 +-
 habitat/core/simulator.py                     |   1 +
 habitat/core/vector_env.py                    |  17 +
 habitat/sims/habitat_simulator.py             |   7 +
 habitat/tasks/nav/nav_task.py                 |  56 +++-
 habitat/utils/__init__.py                     |   7 +
 habitat/utils/visualizations/__init__.py      |   9 +
 .../maps_topdown_agent_sprite/100x100.png     | Bin 0 -> 5616 bytes
 habitat/utils/visualizations/maps.py          | 314 ++++++++++++++++++
 habitat/utils/visualizations/utils.py         |  86 +++++
 requirements.txt                              |   5 +-
 test/test_baseline_agents.py                  |  57 ++++
 test/test_habitat_env.py                      |  50 ++-
 test/test_habitat_example.py                  |   9 +
 test/test_sensors.py                          |  53 +++
 test/test_trajectory_sim.py                   |  13 +-
 28 files changed, 1191 insertions(+), 102 deletions(-)
 create mode 100644 baselines/agents/__init__.py
 create mode 100644 baselines/agents/ppo_agents.py
 create mode 100644 baselines/agents/simple_agents.py
 create mode 100644 examples/visualization_examples.py
 create mode 100644 habitat/core/challenge.py
 create mode 100644 habitat/utils/__init__.py
 create mode 100644 habitat/utils/visualizations/__init__.py
 create mode 100644 habitat/utils/visualizations/assets/maps_topdown_agent_sprite/100x100.png
 create mode 100644 habitat/utils/visualizations/maps.py
 create mode 100644 habitat/utils/visualizations/utils.py
 create mode 100644 test/test_baseline_agents.py
 create mode 100644 test/test_sensors.py

diff --git a/Dockerfile b/Dockerfile
index 211bfb00b..d8dcc0b5b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -31,9 +31,9 @@ RUN curl -o ~/miniconda.sh -O  https://repo.continuum.io/miniconda/Miniconda3-la
 ENV PATH /opt/conda/bin:$PATH
 
 # Install cmake
-ADD cmake-3.13.3-Linux-x86_64.sh /cmake-3.13.3-Linux-x86_64.sh
+RUN wget https://github.com/Kitware/CMake/releases/download/v3.13.4/cmake-3.13.4-Linux-x86_64.sh
 RUN mkdir /opt/cmake
-RUN sh /cmake-3.13.3-Linux-x86_64.sh --prefix=/opt/cmake --skip-license
+RUN sh /cmake-3.13.4-Linux-x86_64.sh --prefix=/opt/cmake --skip-license
 RUN ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake
 RUN cmake --version
 
@@ -56,4 +56,3 @@ RUN rm habitat-test-scenes.zip
 # Silence habitat-sim logs
 ENV GLOG_minloglevel=2
 ENV MAGNUM_LOG="quiet"
-
diff --git a/baselines/agents/__init__.py b/baselines/agents/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/baselines/agents/ppo_agents.py b/baselines/agents/ppo_agents.py
new file mode 100644
index 000000000..9a0080bed
--- /dev/null
+++ b/baselines/agents/ppo_agents.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import argparse
+import random
+
+import numpy as np
+import torch
+from gym.spaces import Discrete, Dict, Box
+
+import habitat
+from baselines.rl.ppo import Policy
+from baselines.rl.ppo.utils import batch_obs
+from habitat import Config
+from habitat.core.agent import Agent
+
+
+def get_defaut_config():
+    c = Config()
+    c.INPUT_TYPE = "blind"
+    c.MODEL_PATH = "data/checkpoints/blind.pth"
+    c.RESOLUTION = 256
+    c.HIDDEN_SIZE = 512
+    c.RANDOM_SEED = 7
+    c.PTH_GPU_ID = 0
+    return c
+
+
+class PPOAgent(Agent):
+    def __init__(self, config: Config):
+        spaces = {
+            "pointgoal": Box(
+                low=np.finfo(np.float32).min,
+                high=np.finfo(np.float32).max,
+                shape=(2,),
+                dtype=np.float32,
+            )
+        }
+
+        if config.INPUT_TYPE in ["depth", "rgbd"]:
+            spaces["depth"] = Box(
+                low=0,
+                high=1,
+                shape=(config.RESOLUTION, config.RESOLUTION, 1),
+                dtype=np.float32,
+            )
+
+        if config.INPUT_TYPE in ["rgb", "rgbd"]:
+            spaces["rgb"] = Box(
+                low=0,
+                high=255,
+                shape=(config.RESOLUTION, config.RESOLUTION, 3),
+                dtype=np.uint8,
+            )
+        observation_spaces = Dict(spaces)
+
+        action_spaces = Discrete(4)
+
+        self.device = torch.device("cuda:{}".format(config.PTH_GPU_ID))
+        self.hidden_size = config.HIDDEN_SIZE
+
+        random.seed(config.RANDOM_SEED)
+        torch.random.manual_seed(config.RANDOM_SEED)
+        torch.backends.cudnn.deterministic = True
+
+        self.actor_critic = Policy(
+            observation_space=observation_spaces,
+            action_space=action_spaces,
+            hidden_size=self.hidden_size,
+        )
+        self.actor_critic.to(self.device)
+
+        if config.MODEL_PATH:
+            ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
+            #  Filter only actor_critic weights
+            self.actor_critic.load_state_dict(
+                {
+                    k.replace("actor_critic.", ""): v
+                    for k, v in ckpt["state_dict"].items()
+                    if "actor_critic" in k
+                }
+            )
+
+        else:
+            habitat.logger.error(
+                "Model checkpoint wasn't loaded, evaluating " "a random model."
+            )
+
+        self.test_recurrent_hidden_states = None
+        self.not_done_masks = None
+
+    def reset(self):
+        self.test_recurrent_hidden_states = torch.zeros(
+            1, self.hidden_size, device=self.device
+        )
+        self.not_done_masks = torch.zeros(1, 1, device=self.device)
+
+    def act(self, observations):
+        batch = batch_obs([observations])
+        for sensor in batch:
+            batch[sensor] = batch[sensor].to(self.device)
+
+        with torch.no_grad():
+            _, actions, _, self.test_recurrent_hidden_states = self.actor_critic.act(
+                batch,
+                self.test_recurrent_hidden_states,
+                self.not_done_masks,
+                deterministic=False,
+            )
+            #  Make masks not done till reset (end of episode) will be called
+            self.not_done_masks = torch.ones(1, 1, device=self.device)
+
+        return actions[0][0].item()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_type",
+        default="blind",
+        choices=["blind", "rgb", "depth", "rgbd"],
+    )
+    parser.add_argument("--model_path", default="", type=str)
+    args = parser.parse_args()
+
+    config = get_defaut_config()
+    config.INPUT_TYPE = args.input_type
+    config.MODEL_PATH = args.model_path
+
+    agent = PPOAgent(config)
+    challenge = habitat.Challenge()
+    challenge.submit(agent)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/baselines/agents/simple_agents.py b/baselines/agents/simple_agents.py
new file mode 100644
index 000000000..6e83cd0c9
--- /dev/null
+++ b/baselines/agents/simple_agents.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import argparse
+import random
+from math import pi
+
+import numpy as np
+
+import habitat
+from habitat.sims.habitat_simulator import (
+    SimulatorActions,
+    SIM_ACTION_TO_NAME,
+    SIM_NAME_TO_ACTION,
+)
+
+NON_STOP_ACTIONS = [
+    k
+    for k, v in SIM_ACTION_TO_NAME.items()
+    if v != SimulatorActions.STOP.value
+]
+
+
+class RandomAgent(habitat.Agent):
+    def __init__(self, config):
+        self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE
+
+    def reset(self):
+        pass
+
+    def act(self, observations):
+        action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value]
+        return action
+
+    def is_goal_reached(self, observations):
+        dist = observations["pointgoal"][0]
+        return dist <= self.dist_threshold_to_stop
+
+    def act(self, observations):
+        if self.is_goal_reached(observations):
+            action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value]
+        else:
+            action = np.random.choice(NON_STOP_ACTIONS)
+        return action
+
+
+class ForwardOnlyAgent(RandomAgent):
+    def act(self, observations):
+        if self.is_goal_reached(observations):
+            action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value]
+        else:
+            action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value]
+        return action
+
+
+class RandomForwardAgent(RandomAgent):
+    def __init__(self, config):
+        super(RandomForwardAgent, self).__init__(config)
+        self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE
+        self.FORWARD_PROBABILITY = 0.8
+
+    def act(self, observations):
+        if self.is_goal_reached(observations):
+            action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value]
+        else:
+            if np.random.uniform(0, 1, 1) < self.FORWARD_PROBABILITY:
+                action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value]
+            else:
+                action = np.random.choice(
+                    [
+                        SIM_NAME_TO_ACTION[SimulatorActions.LEFT.value],
+                        SIM_NAME_TO_ACTION[SimulatorActions.RIGHT.value],
+                    ]
+                )
+
+        return action
+
+
+class GoalFollower(RandomAgent):
+    def __init__(self, config):
+        super(GoalFollower, self).__init__(config)
+        self.pos_th = self.dist_threshold_to_stop
+        self.angle_th = float(np.deg2rad(15))
+        self.random_prob = 0
+
+    def normalize_angle(self, angle):
+        if angle < -pi:
+            angle = 2.0 * pi + angle
+        if angle > pi:
+            angle = -2.0 * pi + angle
+        return angle
+
+    def turn_towards_goal(self, angle_to_goal):
+        if angle_to_goal > pi or (
+            (angle_to_goal < 0) and (angle_to_goal > -pi)
+        ):
+            action = SIM_NAME_TO_ACTION[SimulatorActions.RIGHT.value]
+        else:
+            action = SIM_NAME_TO_ACTION[SimulatorActions.LEFT.value]
+        return action
+
+    def act(self, observations):
+        if self.is_goal_reached(observations):
+            action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value]
+        else:
+            angle_to_goal = self.normalize_angle(
+                np.array(observations["pointgoal"][1])
+            )
+            if abs(angle_to_goal) < self.angle_th:
+                action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value]
+            else:
+                action = self.turn_towards_goal(angle_to_goal)
+
+        return action
+
+
+def get_all_subclasses(cls):
+    return set(cls.__subclasses__()).union(
+        [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]
+    )
+
+
+def get_agent_cls(agent_class_name):
+    sub_classes = [
+        sub_class
+        for sub_class in get_all_subclasses(habitat.Agent)
+        if sub_class.__name__ == agent_class_name
+    ]
+    return sub_classes[0]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--task-config", type=str, default="tasks/pointnav.yaml"
+    )
+    parser.add_argument("--agent_class", type=str, default="GoalFollower")
+    args = parser.parse_args()
+
+    agent = get_agent_cls(args.agent_class)(
+        habitat.get_config(args.task_config)
+    )
+    benchmark = habitat.Benchmark(args.task_config)
+    metrics = benchmark.evaluate(agent)
+
+    for k, v in metrics.items():
+        habitat.logger.info("{}: {:.3f}".format(k, v))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/baselines/rl/ppo/__init__.py b/baselines/rl/ppo/__init__.py
index 843a232a1..248f2b30f 100644
--- a/baselines/rl/ppo/__init__.py
+++ b/baselines/rl/ppo/__init__.py
@@ -4,8 +4,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from rl.ppo.ppo import PPO
-from rl.ppo.policy import Policy
-from rl.ppo.utils import RolloutStorage
+from baselines.rl.ppo.ppo import PPO
+from baselines.rl.ppo.policy import Policy
+from baselines.rl.ppo.utils import RolloutStorage
 
 __all__ = ["PPO", "Policy", "RolloutStorage"]
diff --git a/baselines/rl/ppo/policy.py b/baselines/rl/ppo/policy.py
index dea82476c..2c85d8246 100644
--- a/baselines/rl/ppo/policy.py
+++ b/baselines/rl/ppo/policy.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn as nn
 
-from rl.ppo.utils import Flatten, CategoricalNet
+from baselines.rl.ppo.utils import Flatten, CategoricalNet
 
 import numpy as np
 
@@ -67,20 +67,34 @@ class Net(nn.Module):
     def __init__(self, observation_space, hidden_size):
         super().__init__()
 
+        self._n_input_goal = observation_space.spaces["pointgoal"].shape[0]
+        self._hidden_size = hidden_size
+
+        self.cnn = self._init_perception_model(observation_space)
+
+        if self.is_blind:
+            self.rnn = nn.GRU(self._n_input_goal, self._hidden_size)
+        else:
+            self.rnn = nn.GRU(
+                self.output_size + self._n_input_goal, self._hidden_size
+            )
+
+        self.critic_linear = nn.Linear(self._hidden_size, 1)
+
+        self.layer_init()
+        self.train()
+
+    def _init_perception_model(self, observation_space):
         if "rgb" in observation_space.spaces:
             self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
         else:
             self._n_input_rgb = 0
+
         if "depth" in observation_space.spaces:
             self._n_input_depth = observation_space.spaces["depth"].shape[2]
         else:
             self._n_input_depth = 0
 
-        assert self._n_input_rgb + self._n_input_depth > 0
-
-        self._n_input_goal = observation_space.spaces["pointgoal"].shape[0]
-        self._hidden_size = hidden_size
-
         # kernel size for different CNN layers
         self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)]
 
@@ -91,52 +105,50 @@ class Net(nn.Module):
             cnn_dims = np.array(
                 observation_space.spaces["rgb"].shape[:2], dtype=np.float32
             )
-        else:
+        elif self._n_input_depth > 0:
             cnn_dims = np.array(
                 observation_space.spaces["depth"].shape[:2], dtype=np.float32
             )
 
-        for kernel_size, stride in zip(
-            self._cnn_layers_kernel_size, self._cnn_layers_stride
-        ):
-            cnn_dims = self._conv_output_dim(
-                dimension=cnn_dims,
-                padding=np.array([0, 0], dtype=np.float32),
-                dilation=np.array([1, 1], dtype=np.float32),
-                kernel_size=np.array(kernel_size, dtype=np.float32),
-                stride=np.array(stride, dtype=np.float32),
-            )
-
-        self.cnn = nn.Sequential(
-            nn.Conv2d(
-                in_channels=self._n_input_rgb + self._n_input_depth,
-                out_channels=32,
-                kernel_size=self._cnn_layers_kernel_size[0],
-                stride=self._cnn_layers_stride[0],
-            ),
-            nn.ReLU(),
-            nn.Conv2d(
-                in_channels=32,
-                out_channels=64,
-                kernel_size=self._cnn_layers_kernel_size[1],
-                stride=self._cnn_layers_stride[1],
-            ),
-            nn.ReLU(),
-            nn.Conv2d(
-                in_channels=64,
-                out_channels=32,
-                kernel_size=self._cnn_layers_kernel_size[2],
-                stride=self._cnn_layers_stride[2],
-            ),
-            Flatten(),
-            nn.Linear(32 * cnn_dims[0] * cnn_dims[1], hidden_size),
-            nn.ReLU(),
-        )
-        self.rnn = nn.GRU(hidden_size + self._n_input_goal, hidden_size)
-        self.critic_linear = nn.Linear(hidden_size, 1)
+        if self.is_blind:
+            return nn.Sequential()
+        else:
+            for kernel_size, stride in zip(
+                self._cnn_layers_kernel_size, self._cnn_layers_stride
+            ):
+                cnn_dims = self._conv_output_dim(
+                    dimension=cnn_dims,
+                    padding=np.array([0, 0], dtype=np.float32),
+                    dilation=np.array([1, 1], dtype=np.float32),
+                    kernel_size=np.array(kernel_size, dtype=np.float32),
+                    stride=np.array(stride, dtype=np.float32),
+                )
 
-        self.layer_init()
-        self.train()
+            return nn.Sequential(
+                nn.Conv2d(
+                    in_channels=self._n_input_rgb + self._n_input_depth,
+                    out_channels=32,
+                    kernel_size=self._cnn_layers_kernel_size[0],
+                    stride=self._cnn_layers_stride[0],
+                ),
+                nn.ReLU(),
+                nn.Conv2d(
+                    in_channels=32,
+                    out_channels=64,
+                    kernel_size=self._cnn_layers_kernel_size[1],
+                    stride=self._cnn_layers_stride[1],
+                ),
+                nn.ReLU(),
+                nn.Conv2d(
+                    in_channels=64,
+                    out_channels=32,
+                    kernel_size=self._cnn_layers_kernel_size[2],
+                    stride=self._cnn_layers_stride[2],
+                ),
+                Flatten(),
+                nn.Linear(32 * cnn_dims[0] * cnn_dims[1], self._hidden_size),
+                nn.ReLU(),
+            )
 
     def _conv_output_dim(
         self, dimension, padding, dilation, kernel_size, stride
@@ -240,28 +252,36 @@ class Net(nn.Module):
 
         return x, hidden_states
 
-    def forward(self, observations, rnn_hidden_states, masks):
+    @property
+    def is_blind(self):
+        return self._n_input_rgb + self._n_input_depth == 0
+
+    def forward_perception_model(self, observations):
         cnn_input = []
         if self._n_input_rgb > 0:
             rgb_observations = observations["rgb"]
             # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
             rgb_observations = rgb_observations.permute(0, 3, 1, 2)
             rgb_observations = rgb_observations / 255.0  # normalize RGB
-
             cnn_input.append(rgb_observations)
+
         if self._n_input_depth > 0:
             depth_observations = observations["depth"]
-
             # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
             depth_observations = depth_observations.permute(0, 3, 1, 2)
-
             cnn_input.append(depth_observations)
 
         cnn_input = torch.cat(cnn_input, dim=1)
-        goal_observations = observations["pointgoal"]
 
-        x = self.cnn(cnn_input)
-        x = torch.cat([x, goal_observations], dim=1)  # concatenate goal vector
+        return self.cnn(cnn_input)
+
+    def forward(self, observations, rnn_hidden_states, masks):
+        x = observations["pointgoal"]
+
+        if not self.is_blind:
+            perception_embed = self.forward_perception_model(observations)
+            x = torch.cat([perception_embed, x], dim=1)
+
         x, rnn_hidden_states = self.forward_rnn(x, rnn_hidden_states, masks)
 
         return self.critic_linear(x), x, rnn_hidden_states
diff --git a/configs/test/habitat_all_sensors_test.yaml b/configs/test/habitat_all_sensors_test.yaml
index 79375eef1..1ea4783fb 100644
--- a/configs/test/habitat_all_sensors_test.yaml
+++ b/configs/test/habitat_all_sensors_test.yaml
@@ -3,8 +3,25 @@ ENVIRONMENT:
 SIMULATOR:
   AGENT_0:
     SENSORS: ['RGB_SENSOR', 'DEPTH_SENSOR']
+  RGB_SENSOR:
+    WIDTH: 256
+    HEIGHT: 256
+  DEPTH_SENSOR:
+    WIDTH: 256
+    HEIGHT: 256
 DATASET:
   TYPE: PointNav-v1
   SPLIT: train
   POINTNAVV1:
     DATA_PATH: data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz
+TASK:
+  TYPE: Nav-v0
+  SUCCESS_DISTANCE: 0.2
+  SENSORS: ['POINTGOAL_SENSOR']
+  POINTGOAL_SENSOR:
+    TYPE: PointGoalSensor
+    GOAL_FORMAT: POLAR
+  MEASUREMENTS: ['SPL']
+  SPL:
+    TYPE: SPL
+    SUCCESS_DISTANCE: 0.2
\ No newline at end of file
diff --git a/examples/visualization_examples.py b/examples/visualization_examples.py
new file mode 100644
index 000000000..ab5e4d281
--- /dev/null
+++ b/examples/visualization_examples.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import numpy as np
+import imageio
+
+import habitat
+from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal
+from habitat.utils.visualizations import maps
+
+
+def example_pointnav_draw_target_birdseye_view():
+    goal_radius = 0.5
+    goal = NavigationGoal([10, 0.25, 10], goal_radius)
+    agent_position = np.array([0, 0.25, 0])
+    agent_rotation = -np.pi / 4
+
+    dummy_episode = NavigationEpisode(
+        [goal],
+        episode_id="dummy_id",
+        scene_id="dummy_scene",
+        start_position=agent_position,
+        start_rotation=agent_rotation,
+    )
+    target_image = maps.pointnav_draw_target_birdseye_view(
+        agent_position,
+        agent_rotation,
+        np.asarray(dummy_episode.goals[0].position),
+        goal_radius=dummy_episode.goals[0].radius,
+        agent_radius_px=25,
+    )
+
+    imageio.imsave("pointnav_target_image.png", target_image)
+
+
+def example_pointnav_draw_target_birdseye_view_agent_on_border():
+    goal_radius = 0.5
+    goal = NavigationGoal([0, 0.25, 0], goal_radius)
+    ii = 0
+    for x_edge in [-1, 0, 1]:
+        for y_edge in [-1, 0, 1]:
+            if not np.bitwise_xor(x_edge == 0, y_edge == 0):
+                continue
+            ii += 1
+            agent_position = np.array([7.8 * x_edge, 0.25, 7.8 * y_edge])
+            agent_rotation = np.pi / 2
+
+            dummy_episode = NavigationEpisode(
+                [goal],
+                episode_id="dummy_id",
+                scene_id="dummy_scene",
+                start_position=agent_position,
+                start_rotation=agent_rotation,
+            )
+            target_image = maps.pointnav_draw_target_birdseye_view(
+                agent_position,
+                agent_rotation,
+                np.asarray(dummy_episode.goals[0].position),
+                goal_radius=dummy_episode.goals[0].radius,
+                agent_radius_px=25,
+            )
+            imageio.imsave(
+                "pointnav_target_image_edge_%d.png" % ii, target_image
+            )
+
+
+def example_get_topdown_map():
+    config = habitat.get_config(config_file="tasks/pointnav.yaml")
+    dataset = habitat.make_dataset(
+        id_dataset=config.DATASET.TYPE, config=config.DATASET
+    )
+    env = habitat.Env(config=config, dataset=dataset)
+    env.reset()
+    top_down_map = maps.get_topdown_map(env.sim, map_resolution=(5000, 5000))
+    recolor_map = np.array(
+        [[255, 255, 255], [128, 128, 128], [0, 0, 0]], dtype=np.uint8
+    )
+    range_x = np.where(np.any(top_down_map, axis=1))[0]
+    range_y = np.where(np.any(top_down_map, axis=0))[0]
+    padding = int(np.ceil(top_down_map.shape[0] / 125))
+    range_x = (
+        max(range_x[0] - padding, 0),
+        min(range_x[-1] + padding + 1, top_down_map.shape[0]),
+    )
+    range_y = (
+        max(range_y[0] - padding, 0),
+        min(range_y[-1] + padding + 1, top_down_map.shape[1]),
+    )
+    top_down_map = top_down_map[
+        range_x[0] : range_x[1], range_y[0] : range_y[1]
+    ]
+    top_down_map = recolor_map[top_down_map]
+    imageio.imsave("top_down_map.png", top_down_map)
+
+
+def main():
+    example_pointnav_draw_target_birdseye_view()
+    example_get_topdown_map()
+    example_pointnav_draw_target_birdseye_view_agent_on_border()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/habitat/__init__.py b/habitat/__init__.py
index 04d5e3cf7..847fb295c 100644
--- a/habitat/__init__.py
+++ b/habitat/__init__.py
@@ -6,6 +6,7 @@
 
 from habitat.core.agent import Agent
 from habitat.core.benchmark import Benchmark
+from habitat.core.challenge import Challenge
 from habitat.config import Config, get_config
 from habitat.core.dataset import Dataset
 from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements
@@ -19,6 +20,7 @@ from habitat.version import VERSION as __version__  # noqa
 __all__ = [
     "Agent",
     "Benchmark",
+    "Challenge",
     "Config",
     "Dataset",
     "EmbodiedTask",
diff --git a/habitat/config/default.py b/habitat/config/default.py
index 10ac1d886..8bd9189b2 100644
--- a/habitat/config/default.py
+++ b/habitat/config/default.py
@@ -37,6 +37,11 @@ _C.TASK.POINTGOAL_SENSOR = CN()
 _C.TASK.POINTGOAL_SENSOR.TYPE = "PointGoalSensor"
 _C.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "POLAR"
 # -----------------------------------------------------------------------------
+# # HEADING SENSOR
+# -----------------------------------------------------------------------------
+_C.TASK.HEADING_SENSOR = CN()
+_C.TASK.HEADING_SENSOR.TYPE = "HeadingSensor"
+# -----------------------------------------------------------------------------
 # # SPL MEASUREMENT
 # -----------------------------------------------------------------------------
 _C.TASK.SPL = CN()
diff --git a/habitat/core/benchmark.py b/habitat/core/benchmark.py
index eaf97ad11..cb49a6b81 100644
--- a/habitat/core/benchmark.py
+++ b/habitat/core/benchmark.py
@@ -7,8 +7,8 @@
 from collections import defaultdict
 from typing import Dict, Optional
 
+from habitat.config.default import get_config, DEFAULT_CONFIG_DIR
 from habitat.core.agent import Agent
-from habitat.config.default import get_config
 from habitat.core.env import Env
 
 
@@ -18,10 +18,15 @@ class Benchmark:
 
     Args:
         config_file: file to be used for creating the environment.
+        config_dir: directory where config_file is located.
     """
 
-    def __init__(self, config_file: Optional[str] = None) -> None:
-        config_env = get_config(config_file=config_file)
+    def __init__(
+        self,
+        config_file: Optional[str] = None,
+        config_dir: str = DEFAULT_CONFIG_DIR,
+    ) -> None:
+        config_env = get_config(config_file=config_file, config_dir=config_dir)
         self._env = Env(config=config_env)
 
     def evaluate(
diff --git a/habitat/core/challenge.py b/habitat/core/challenge.py
new file mode 100644
index 000000000..5094333e5
--- /dev/null
+++ b/habitat/core/challenge.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from habitat.core.benchmark import Benchmark
+from habitat.core.logging import logger
+
+
+class Challenge(Benchmark):
+    def __init__(self):
+        config_file = os.environ["CHALLENGE_CONFIG_FILE"]
+        super().__init__(config_file)
+
+    def submit(self, agent):
+        metrics = super().evaluate(agent)
+        for k, v in metrics.items():
+            logger.info("{}: {}".format(k, v))
diff --git a/habitat/core/env.py b/habitat/core/env.py
index d48695d23..cfa3438c2 100644
--- a/habitat/core/env.py
+++ b/habitat/core/env.py
@@ -68,6 +68,16 @@ class Env:
                 id_dataset=config.DATASET.TYPE, config=config.DATASET
             )
         self._episodes = self._dataset.episodes if self._dataset else []
+
+        # load the first scene if dataset is present
+        if self._dataset:
+            assert len(self._dataset.episodes) > 0, (
+                "dataset should have " "non-empty episodes list"
+            )
+            self._config.defrost()
+            self._config.SIMULATOR.SCENE = self._dataset.episodes[0].scene_id
+            self._config.freeze()
+
         self._sim = make_sim(
             id_sim=self._config.SIMULATOR.TYPE, config=self._config.SIMULATOR
         )
@@ -75,7 +85,7 @@ class Env:
             self._config.TASK.TYPE,
             task_config=self._config.TASK,
             sim=self._sim,
-            dataset=dataset,
+            dataset=self._dataset,
         )
         self.observation_space = SpaceDict(
             {
diff --git a/habitat/core/simulator.py b/habitat/core/simulator.py
index 6cd0c1cba..fbeaac1e9 100644
--- a/habitat/core/simulator.py
+++ b/habitat/core/simulator.py
@@ -28,6 +28,7 @@ class SensorTypes(Enum):
     TENSOR = 8
     TEXT = 9
     MEASUREMENT = 10
+    HEADING = 11
 
 
 class Sensor:
diff --git a/habitat/core/vector_env.py b/habitat/core/vector_env.py
index 0451c8a16..4618e77c7 100644
--- a/habitat/core/vector_env.py
+++ b/habitat/core/vector_env.py
@@ -83,6 +83,7 @@ class VectorEnv:
     ) -> None:
 
         self._is_waiting = False
+        self._is_closed = True
 
         assert (
             env_fn_args is not None and len(env_fn_args) > 0
@@ -103,6 +104,8 @@ class VectorEnv:
             env_fn_args, make_env_fn
         )
 
+        self._is_closed = False
+
         for write_fn in self._connection_write_fns:
             write_fn((OBSERVATION_SPACE_COMMAND, None))
         self.observation_spaces = [
@@ -294,6 +297,9 @@ class VectorEnv:
         return self.wait_step()
 
     def close(self) -> None:
+        if self._is_closed:
+            return
+
         if self._is_waiting:
             for read_fn in self._connection_read_fns:
                 read_fn()
@@ -302,6 +308,8 @@ class VectorEnv:
         for process in self._workers:
             process.join()
 
+        self._is_closed = True
+
     def render(
         self, mode: str = "human", *args, **kwargs
     ) -> Union[np.ndarray, None]:
@@ -326,6 +334,15 @@ class VectorEnv:
     def _valid_start_methods(self) -> Set[str]:
         return {"forkserver", "spawn", "fork"}
 
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
 
 class ThreadedVectorEnv(VectorEnv):
     def _spawn_workers(
diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py
index d6d29b27e..3a5311454 100644
--- a/habitat/sims/habitat_simulator.py
+++ b/habitat/sims/habitat_simulator.py
@@ -198,6 +198,13 @@ class HabitatSim(habitat.Simulator):
             sim_sensor_cfg.sensor_type = sensor.sim_sensor_type  # type: ignore
             sensor_specifications.append(sim_sensor_cfg)
 
+        # If there is no sensors specified create a dummy sensor so simulator
+        # won't throw an error
+        if not _sensor_suite.sensors.values():
+            sim_sensor_cfg = habitat_sim.SensorSpec()
+            sim_sensor_cfg.resolution = [1, 1]
+            sensor_specifications.append(sim_sensor_cfg)
+
         agent_config.sensor_specifications = sensor_specifications
         agent_config.action_space = {
             SimulatorActions.LEFT.value: habitat_sim.ActionSpec(
diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py
index b0737e4d6..bf44de919 100644
--- a/habitat/tasks/nav/nav_task.py
+++ b/habitat/tasks/nav/nav_task.py
@@ -21,7 +21,9 @@ from habitat.core.simulator import (
 from habitat.tasks.utils import quaternion_to_rotation, cartesian_to_polar
 
 
-def merge_sim_episode_config(sim_config: Any, episode: Type[Episode]) -> Any:
+def merge_sim_episode_config(
+    sim_config: Config, episode: Type[Episode]
+) -> Any:
     sim_config.defrost()
     sim_config.SCENE = episode.scene_id
     sim_config.freeze()
@@ -148,7 +150,7 @@ class PointGoalSensor(habitat.Sensor):
             in cartesian or polar coordinates.
     """
 
-    def __init__(self, sim, config):
+    def __init__(self, sim: Simulator, config: Config):
         self._sim = sim
 
         self._goal_format = getattr(config, "GOAL_FORMAT", "CARTESIAN")
@@ -199,6 +201,45 @@ class PointGoalSensor(habitat.Sensor):
         return direction_vector_agent
 
 
+class HeadingSensor(habitat.Sensor):
+    """
+       Sensor for observing the agent's heading in the global coordinate frame.
+
+       Args:
+           sim: reference to the simulator for calculating task observations.
+           config: config for the sensor.
+       """
+
+    def __init__(self, sim: Simulator, config: Config):
+        self._sim = sim
+        super().__init__(config=config)
+
+    def _get_uuid(self, *args: Any, **kwargs: Any):
+        return "heading"
+
+    def _get_sensor_type(self, *args: Any, **kwargs: Any):
+        return SensorTypes.HEADING
+
+    def _get_observation_space(self, *args: Any, **kwargs: Any):
+        return spaces.Box(low=-np.pi, high=np.pi, shape=(1,), dtype=np.float)
+
+    def get_observation(self, observations, episode):
+        agent_state = self._sim.get_agent_state()
+        # Quaternion is in x, y, z, w format
+        ref_rotation = agent_state.rotation
+
+        direction_vector = np.array([0, 0, -1])
+
+        rotation_world_agent = quaternion_to_rotation(
+            ref_rotation[3], ref_rotation[0], ref_rotation[1], ref_rotation[2]
+        )
+
+        heading_vector = np.dot(rotation_world_agent.T, direction_vector)
+
+        phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1]
+        return np.array(phi)
+
+
 class SPL(habitat.Measure):
     """SPL (Success weighted by Path Length)
 
@@ -206,7 +247,7 @@ class SPL(habitat.Measure):
     https://arxiv.org/pdf/1807.06757.pdf
     """
 
-    def __init__(self, sim, config):
+    def __init__(self, sim: Simulator, config: Config):
         self._previous_position = None
         self._start_end_episode_distance = None
         self._agent_episode_distance = None
@@ -233,12 +274,13 @@ class SPL(habitat.Measure):
         ep_success = 0
         current_position = self._sim.get_agent_state().position.tolist()
 
+        distance_to_target = self._sim.geodesic_distance(
+            current_position, episode.goals[0].position
+        )
+
         if (
             action == self._sim.index_stop_action
-            and self._euclidean_distance(
-                current_position, episode.goals[0].position
-            )
-            < self._config.SUCCESS_DISTANCE
+            and distance_to_target < self._config.SUCCESS_DISTANCE
         ):
             ep_success = 1
 
diff --git a/habitat/utils/__init__.py b/habitat/utils/__init__.py
new file mode 100644
index 000000000..d1f5f3fbc
--- /dev/null
+++ b/habitat/utils/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+__all__ = ["visualizations"]
diff --git a/habitat/utils/visualizations/__init__.py b/habitat/utils/visualizations/__init__.py
new file mode 100644
index 000000000..5f27292dc
--- /dev/null
+++ b/habitat/utils/visualizations/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from habitat.utils.visualizations import maps
+
+__all__ = ["maps"]
diff --git a/habitat/utils/visualizations/assets/maps_topdown_agent_sprite/100x100.png b/habitat/utils/visualizations/assets/maps_topdown_agent_sprite/100x100.png
new file mode 100644
index 0000000000000000000000000000000000000000..94b0f9396ebe80a25ac06a9baec83b462f815804
GIT binary patch
literal 5616
zcmeAS@N?(olHy`uVBq!ia0y~yU`PRB4mJh`hJr^^Ll_uDw|lxchE&XXJ2!hmOzO<z
z|IdG(_s*_%|CtoE#zc>R$7%s;({y$SDk?|0Wfht_P0<SEZi*?%PPBi)t~+_{qEp<d
z3%9;{r5EsYsh7B*Zsg7#2Ut>8il%yoad1a-smw^xd-?lm$@j|p>GoeNCwp2?-nk_I
z#J|<g4wuZgd;aD3`{#f6ZvVWe#m?J>gMnE=sp$5-&0R+yv90~8#xUD8WpT$GU8U?!
z4P_-m21UV*7pxk&7x=lP3Z4|YXsA@x+7ZW65TN<BMrmS2=EBzu0t{^(OT_P28wT#z
z(^1-zf8Wm3(D328>fYj=cPh`+Of-C^BzTkS!5aq!0ZtYs28LvwE!QqY9hP7gZ1h;o
z)5+n&lGYZ)<)vw0%&_2Q&Z3(+VT^m6Tz=2|$(ejc_+`njNF}3xwmyD-QMb3}U;k^g
z?0H*9hdl4C-}AkeUNTEQ!`^FX%`j)Shl+yIitDdeTz`Ebiko5og^msdB^MX&d;7QC
z6FIN3KOs`nbzuO<!T=7JmZXgw4Cf6$H=p@)xx(h(lKhqq4F;q8YvNX<w{4N0csA|0
z?&)70hR?Lp?n&{rKeXVJIo{Z@<KEth-)HvSS^oUO9eKgXZMlcH<zC4$l}P?({Fy)Y
z`rkmc$(B!Qmn}b5a#vu!;6A~fF?rg?^`|5@Ys>Yk%N%#y8YRkb%errOt$*(K^b1{f
zqMa@$w&hOoTB>AxKey*+n$_HU?`4-?U_MrIRX{NCvr-bbV6E7JbrCwtx{fMsxu*9k
z;ZV}2vw>cboi0pIiz1nKM7!*pb25dm;`o``<#{W#+bXV37MOVE%$Yffk=LGgW|XT4
zaf)!UN+g>I7A{}4$j>{<=**f)CskHm%?fsTR}lOCnrZemnU9O+?F#HI{WIOt(lS$U
zqN38tw9V#h%?HyTef-gTHqCg#>8B<P2Aicqj2pr{jW;tj#ON)beKt)&uK&IS58Hf@
zYN3e&6KBquv&E|Jt>4@XJMFmIuUE~FCvD7kS$wg4`sq}Y*}l^|Le$q^Pwwn^;h>Ok
zG|4bTYij%Hr=OM|KXxqcg-Nv4)v84Mn)LMa^_(nB+uNNMx@?RnVJxwg`ZMXpw%c*q
zzUo4ptK-&-&p!K1_weDv<(2OX)mLq}{EyG>NnFjxqZdw`IAQOmH1YMG6&LDmlm|08
zREn-Q+so;g+{4Ghq$Sq9Flz0;(_g-Pxu17u$Hse{59W(G+r+&%+|J+M+tXts)qi}w
zwg}g{*Sjuh8|$cvXiFXKYCGBzSF(O)<j!^4txgAfk3YT+Qd#xm!$U)*OUqAB4gJ5s
z<%p8f_p9OYwb?du|7Ygje*ClMUh1>H<~+kp<r7DAYu?^dW@wmQKCdz7Wb31Q3r|Tr
zUKXTzx^Ry3(x8{j$;rv*f23}n^?9w9(;sd*tCAO!QjC7C2+G<Mpz-AjkDx+^;XjMp
z2lm|;-SEDg^-tlp#sx<w$n`lHZVuHDiCKE3!f0mAmkSG>rwdGc&0n$ZM~(L1hWuk2
zlaK59F2DR;Iyq>nm#c8#`7;a*%1s3wAG?<qX-qNYO6`rend)qD#9(g;3xi<fR4-TQ
z<e--&zy2j3?~C0P6&#k*ah&J-iHXYM?n)DXZ@DITywFF#!Z5zBM8@{s0sZ;s6h%^6
zntI~x<{rNIcyGVZf}^6x3wy3v33a->vo3ma;;74$byGA?o4L4LIdY_>`p1Wdx*H?j
zJXaE0cl~wS-;ZzQPWAV(*G)cLvahp2iA5yQ>i!Yt{V)BeFfbSw?~OB866!meWN%;d
z<HNMhPR{GESL?i&Ic<3&JNfm@nKO+aef&{-W=@HW`}e~=J71hI`|#khrHd0!GRrXz
zCB<XS{?7|cZqL_Xj9ONaI=Scg<L_tBoSEb3a!cAT>epv)*$4f%J3J=Le9zl{I6YKr
z>bx(8VwYk{Y^EOe+1J@{wD|E$#%{yO`HvdyHXac-UhTDX(cy<5BG1&U?_jy8by#(F
z{pmtq!HM(c&HERoHFej8>|5WC_WXQP@gj5mg|nSb=F3vAIoGCi>ZgVMthwjAG)Uaf
z-@pF!YXiN=yyYfGURUl44i1(-pJH@f&c1HXa^Kr!TQ@7ea6NwK@1oR*{gIElXRdvG
zynpvhpJ%sLu3Xt)t-C&W)}|GTadXp`En9Xk;r2K4Jv$rsxwyTq6H9a7&DN^=dC94#
zo3`ctmJSXM&SyyQJ32Yom7`L5M){3xve_n5FP=~Mne-;(y23`54AG>;7aMYK8#0tv
zr0&YF)XV6w_b@Xy{(N%b!i8_+^!NO~xRJy9mZ^shW0Z}i2-mv{7cP98Idi6e)!KEe
zr$77i2nycYUA}&b*HXcySHw?0-Q+00S}P&5Y)kfqh)14NPp2-P@+)a&h}6y9<?D4q
zO8<Jk_;JzD)b#7yg$ozH-h6Y-?bnGDC%<RkAfgyw__6%|7xs`WsVvFGyIQV2NH$sX
z;j8p{N3E=bSDWh21$TAKKAE!S@$vrp^pcV<ZLMvucfMujIdJCOxjepBr|Zu?S3W-b
z<*V=MRQKI%y{ZfjE?V=?&koQidHd$gn=~DUAKxY@DhaLqI?qETNivz|*REZKsc$l_
zFWAV!Aero=BD8VKwQ4b@KXVlSi9ETu*!?^=TeG~DSogB|p`sg#%C=++L^`>IZHzEE
znDD{by<g6@zr(L<(oaFB4Bfjwe@Eyn+w<bZ>}j{(M|d&c&~G$6b|lI0=$Y*czV4pc
zEyUJfTm9_}uUyo$t*aw+xY(AaZgm2i$mAf}?HcG+dqckF&gOJ}=6S~dmRvq_=1lsf
zmtUr{FG~$AEBw6ue?!pQFV9YMHe|DW+;Hf6o(22L8xOYZDB5`^{rZ~7&!YF4_n*3?
zw5Z2WzF3C+#UZ9$Q(u(aytvLWg_EI7hW*%#{>r&Wt|>XK*VPFQ4i;}Y{IEEC>#Ud0
zqi*|cTpZhcOi1Zb)Y`N|4?pY#NmumyYqW*uYA#*Z;iB}f>G8gm!Jve)$;HJ-MW`|8
z71w*!q&X|r>Afy5E#39EZ>dScU!ncCcAZ(jZAQnF8yk~nGc_KVoxOF|m(9CgnONU4
z&CvO{N|%#`Dc`Q_%?&lBtWu8|KOdzSc`9An`dn^W(DoY_*Exd9{S+h5Gc|?}LjBMD
z@k{3U^|QqRBE>0qP=YzFPW?hwYNXPpt#RughZslR4cQLzU{YtthVXCEu16ldeEIS#
z!?o+LOBeVvcisi%fgaT-Q~&+~iJv($JNNdst=BH~Yu<_j`S_5KlG56(H@4*(e@-(y
zmcz%^e4z9Bf$CKtN#<ih!aQt=JvYys*%SDyM0qd7tb}L1I!dqiPW(16JRB_4q55Pg
zXXJT?Ygua><)km^mu<;*h;W+6o-OHc?Dz8Juf(lynQG{Kv^ilQ=(^W0k^l3&FVkh;
zzB2dj=&<2m)DyOE<-J(hzv=F~*&0>nu`BLhe(x|hKb!TftNGI^D)+DXXZ!a`dfArj
zfQUmbPrh-<Hy`UO`#M=Q&dg@|#Z!SdCU^5mS>L+qKdnOUN~DX{nK#crzL>=}Rrc*G
zc5fFK8-6Vwh1DV7rzh}gS>L*<KdnM8Me=~*`}yl)OE28KxXx~aiw%FDq2S&(ws-Gr
z^)8z&>+<5oqV0AYTxLkRd@+(vj(hWJ>aH_<M_IzTV{~PGtuFIVlSqzBNwd3=W7epr
z`~Rhw^$qPmHAh{wvJ^$PPFTM~eQ&wdD?{%cH*=b1@<$z;(cJgg!fR#Ud8x8y*}BN-
z*-~W<AI$5g9&Fc-WM=qr{dn4k5A{z2*GeCGbM5`>5<T<zs_%cFE(yQJ(%qr4^hBYI
zy9?i^6DLlbW{Bf!b#iF8yyRIXuyxh)sKX0_%x`Sd?7O}0c0+X8yv~v{PF86Tr_D`&
zwC%N7!tG@iTBXfVhaC*p?qaAuBJ}apzo!!?PMrRC-mg0Szh%dl6<QzVb2{CzOj{!)
zX3xtr2H)yRWUAN69iM%UXK_H`bBRa$_dkDC_~!46&hNAHy7gCoTzXqaQR(s2zyEe#
zeSd65e5LH<zStGQZvuP_m_)Sn+JqSI?D@=dXIr)Hw%oqOx5OAfE3b$$zp>4-?{?Yx
z<7chU9|%xta5~y5xJKg7-Ly{gQtsq4>lyN=eRy$k@$*)vh5z4OIrH$P-Ti|;0)kVr
zrB1GB-M`PEnLp=rf`*{O&Si-`I-Gg`XGBCitk~E&-*)EF{eP{xosPQdtl7hwb$a(=
zJ^OeiFU>2jziwX>z5QLJm|)cIvbUelrhRTLu<)51@poTTnZedGD=egV!#~{ne0j&;
zU%WT=CW}0r;FQR^R^$H+r=zYGOEy#=*T~yn$9%84PKq}?pwMIn=V{Ivc^h3^d@N*^
zrEUIs=+@Tk`Qm~}6*uc<{H(jZt?_aU-xjH|A6I)X=eQodeNW=>^mUQ_=VL6}N}U~_
zXzP8^`YqYT8O8lcRQz>Y(~mP#lYd+?ef@Ep`LTd?nHIj0Vn22t5lX&oHsfcVeYC*g
z?8wZu?^lMcwzzxsy>@E*{f*ioF>`E%4n1Edb$IQ3iQ|>8l{La*O{PUAET5ELu~<_h
zEOyRzyQXM+!@l#erfr+uor-yzgC+0rP7@FeUmcpU_Weiw$M(q{3>8&XzZ^`YY_~_P
zUH3|J)zl5nhK{qtTzZx<-}u>=xaWP7Nzyi_oy!ytY+5z>>Isi-?u{FqZeFZ-uu|sT
z-~LFSr?Wa*uin|ab@gPOk8;-@dM*8Q=fsH<Vihk>{dM}jdhPqUep|Q7lx?`Kq!kt`
z(4eg^ckFt;4f}3ANw$O4b9bd5bNasF^`vXIF_j-zzF&A)=i}<t@6Y<@9b?!~{yy%0
zgZz`vKf8jw)|S+zXqNrB>b5oSfeDY=wk02~zBYfb?DTa-!Hv%v?w(dTktO#-SdiO1
zb{}8--a2NrZA%_hoPGZ7)h1z`k5$jL=ARE|_`_iFknybb34g2NbhXJl)ux|b`{H6o
zF3ZUkt-Iys9QrOl)6FQ$QTk(DCd=}~3kAH}<_I4R70xdBH#@0d>w;?r30AglB~D6l
zE$ouhl>U6HKmFt36CS=l3!4@k-O#)?Ltwpr@QNKfeq9ob*_<rWdg|1s%<hiae#_V9
zm{nhY@#4kS&yz2lF>mSg=##UpT9<qK_ttA$?mTxDOKDd!5S*2ly(0PWir(1iJjqGk
zLX6!3D>ECWU4QdtX%EAMFtLTpv}f1^9Jt$m_R&&J_KO0_30r##zrKoea&gmC?R%UP
zyFKr2@jbS5wi)J!UO&GRu|4nZ=3ckQ>YXl4d(ZBi_EFU@tz9;!b>n^Z9c7?YoXA`@
zXTIBI(U7du@0q_u71u7D<#t*0jfH*h?fe@J+0VbWx-eMpu$tD<G226B$Hk04DR+03
zp5K;V?3l}L%wSpf=ZD(qO?*4novxd8<EVat()_Es_v<)Q-4CynZfR=yzQ?px`<`Xb
zMpMNQnY$k?yW>2h8<%nfMnAkd)9T&Nt2+Pw&a`>O+bElKB&YRu{*4WHYGRyRk|L)S
z#+IMHaN)w|jN4W(muQ`=Zf<XH=fC#)t99^75&f?UD+C>nq}hDAEyL$~{bC7g{(X_-
z;br*;&RU;8Qc(2uu}AZQpN}&SpJd3pSE;sdQ#JFweVTpCv@J|xk2~)_tC*xOtE#jp
zc%{ho*I%tc1~1ym^LmNcTc-=tY^%Q&s0jUZS{SgyNP=gXRPV8-&nij{vI#Qt);&Jn
zfBuo~bC-gzOO-2r#~*pVk41Q|TJMbYKb6mYUDn;3-u96{y#KNN|K%<3?Tvcn?d`hD
z_nWqwo@9`EcPLOvNr;DSXMjeE^28J3ds2++A3aF2PCci3&}aGO`;j`5y@HM=OFmpy
z?_Ye+V8^3Ji7X-u-)!-z@BO`D#h+_6XVOfzKA3krU69MjZbvT1)Lm@)4^He!F}}V&
za`UtK;<tOBi{#A6X7gVbwDRrGn!1Bu9w|8lb6+y#7EP~I-dL~qX@B;WuoF{bS*{#9
zTb|tGw=_s={`vRPva)k6u3qz4o^`E@{mb3H$5WaQK4|KG*}-7#pb=_qKL2mj#Q(A0
zq2gr;iN<?YarHf(!uaOfsgl4OI$0h3KQ3Opc;CA3^UDYwv*}JNV%_56-L`bwe$ZEv
z?LIo?vBi8jS=qm~Vhle$pIcpR-P5xF+KU%2$~WJ9GuPSKd75EbXtnlUxzwJ*_O`Y)
zIcC+*L9%;QQkUm!)%&@5ny~bR)j4LrD|Ezm7Z}d_v)}K;bGMx_|4P*+@03`3EA=F=
z#&j8}kA4a+R}7B5S#ve3eeuN=M#sMB2A*ATT4M3yiz`-L&1&y8wC-HNxITEtmGtSR
zVNZEiuK@MWHK%$_mrmx|miyhd_@%-riR3+-G^cniJ-zpA&()&vFFMsir|z44T(kL{
zRe;{~`~EIl12m5O&p*W}ne3w>lBBe1>glK7=a*iKWdC#4e0p`CsCn<R{Mq8YZuNOR
zofXx3d)LiO>bbZrS37N8nAX(q>tlC+n|L+HVCSREsQUVSTeGj*O*;KlFZo%GiB#xm
zS;>FJQ#y8B%m^_owwr!Bc3YHjL#W<{Z@!Z9^7U&1H1>FB*J{l4*{3)A$T11?nLbfb
zubo?+9>zX?{J8&X&Fjzwsi&SiSz>Bt_RVKwM2-8}u<D|UiW^K1%6@;`;W**D?A?bC
z6IX|=wvE&gyKiV__N{0C&u3ZZf6uzVZ@2ubXW#DDi%H-4?iCkOQri5x&t}QvufZNJ
zS3Z3B5W6(!CDZXjIltwX!xRJ*_<yW_n&snigqMYBir;eG<Arj1Q@zA_`1$XvDrNnf
z^W){JWPe}5i9c$mZrr$0JZR;WYOkf2e9zd-J(;4jrO0XTv#Rg+Y=hM%Z=6xND@5z=
z6tAWCjLpoxHQhfR;m2RI^OaZVqw{BlCssr!)XeKBb61+^urQ#(&`ns7KPyyCS28W^
zWQx$$tgV_#<vZ3@ufG0SZOhr#a<5q32|p*CPTL&a@nSPW+0i7!P#v*_Tgxs=^}TzY
z$FpzG)c<apsj5Pp6|*z%=_<vaOEF^Klk6vbXL-=|`i1{8HvU&F+Ic4|?cWr`v~U|a
ze~Z56tgTuZy=~v_$lPaP&=TsLkb7HWs+VipzS|O?&z4y2y>IQ8D7kp+^3`%`y25SK
zMMW7l-F&knHB#R2Sj^Eg(T5)z$Q*BcZ1F*g@!u!qS0PuQZJqS&U&o0}xzgKm7l*Ar
zD6x6|%=#C*-t`;JJfryD_>}cF_2jbJ6PK$u87es~4a%Ht_^kQZ3|WSAzA8cjf>&Ob
zUU^-bv6fX`tU*}tVD?R&Z63~h{#rB$ueA>I3e*s3`tBp&bySJL=J*+YqtE`2KmM3p
z$MWSCM||xaA^(i~`)ae2k`B%NF7>=<=jRzem6Ff6Ds2)yaI?`tL5PEef#H}$#@dBj
z+oqo^es|MD|ES;$NiR)}r4Le9I&8k_vias3ri#{%Z^oa6kIhiO^zzH*&)l^|MVD^v
ztF3mA_R|i$A8q+LO=V^62}hSTNgH=KX7A-ZkS2IjLn*!cM3mC04iBESBCTRk>IE0Q
z%tOT{>IljzFADE?;mW{1HR!y$Q0H{9U3^g!uavs!Ed9^Sx<3AV+P1s{3=9kmp00i_
I>zopr082ad6aWAK

literal 0
HcmV?d00001

diff --git a/habitat/utils/visualizations/maps.py b/habitat/utils/visualizations/maps.py
new file mode 100644
index 000000000..f845cbf2c
--- /dev/null
+++ b/habitat/utils/visualizations/maps.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import cv2
+import imageio
+import scipy.ndimage
+import os
+from typing import List, Tuple, Optional
+from habitat.core.simulator import Simulator
+from habitat.utils.visualizations import utils
+
+AGENT_SPRITE = imageio.imread(
+    os.path.join(
+        os.path.dirname(__file__),
+        "assets",
+        "maps_topdown_agent_sprite",
+        "100x100.png",
+    )
+)
+AGENT_SPRITE = np.ascontiguousarray(np.flipud(AGENT_SPRITE))
+COORDINATE_EPSILON = 1e-6
+COORDINATE_MIN = -62.3241 - COORDINATE_EPSILON
+COORDINATE_MAX = 90.0399 + COORDINATE_EPSILON
+
+
+def draw_agent(
+    image: np.ndarray,
+    agent_center_coord: Tuple[int, int],
+    agent_rotation: float,
+    agent_radius_px: int = 5,
+) -> np.ndarray:
+    """Return an image with the agent image composited onto it.
+    Args:
+        image: the image onto which to put the agent.
+        agent_center_coord: the image coordinates where to paste the agent.
+        agent_rotation: the agent's current rotation in radians.
+        agent_radius_px: 1/2 number of pixels the agent will be resized to.
+    Returns:
+        The modified background image. This operation is in place.
+    """
+
+    # Rotate before resize to keep good resolution.
+    rotated_agent = scipy.ndimage.interpolation.rotate(
+        AGENT_SPRITE, agent_rotation * -180 / np.pi
+    )
+    # Rescale because rotation may result in larger image than original, but the
+    # agent sprite size should stay the same.
+    initial_agent_size = AGENT_SPRITE.shape[0]
+    new_size = rotated_agent.shape[0]
+    agent_size_px = max(
+        1, int(agent_radius_px * 2 * new_size / initial_agent_size)
+    )
+    resized_agent = cv2.resize(
+        rotated_agent,
+        (agent_size_px, agent_size_px),
+        interpolation=cv2.INTER_LINEAR,
+    )
+    utils.paste_overlapping_image(image, resized_agent, agent_center_coord)
+    return image
+
+
+def pointnav_draw_target_birdseye_view(
+    agent_position: np.ndarray,
+    agent_heading: float,
+    goal_position: np.ndarray,
+    resolution_px: int = 800,
+    goal_radius: float = 0.2,
+    agent_radius_px: int = 20,
+    target_band_radii: Optional[List[float]] = None,
+    target_band_colors: Optional[List[Tuple[int, int, int]]] = None,
+) -> np.ndarray:
+    """Return an image of agent w.r.t. centered target location for pointnav
+    tasks.
+
+    Args:
+        agent_position: the agent's current position.
+        agent_heading: the agent's current rotation in radians. This can be
+            found using the HeadingSensor.
+        goal_position: the pointnav task goal position.
+        resolution_px: number of pixels for the output image width and height.
+        goal_radius: how near the agent needs to be to be successful for the
+            pointnav task.
+        agent_radius_px: 1/2 number of pixels the agent will be resized to.
+        target_band_radii: distance in meters to the outer-radius of each band
+            in the target image.
+        target_band_colors: colors in RGB 0-255 for the bands in the target.
+    Returns:
+        Image centered on the goal with the agent's current relative position
+        and rotation represented by an arrow. To make the rotations align
+        visually with habitat, positive-z is up, positive-x is left and a
+        rotation of 0 points upwards in the output image and rotates clockwise.
+    """
+    if target_band_radii is None:
+        target_band_radii = [20, 10, 5, 2.5, 1]
+    if target_band_colors is None:
+        target_band_colors = [
+            (47, 19, 122),
+            (22, 99, 170),
+            (92, 177, 0),
+            (226, 169, 0),
+            (226, 12, 29),
+        ]
+
+    assert len(target_band_radii) == len(
+        target_band_colors
+    ), "There must be an equal number of scales and colors."
+
+    goal_agent_dist = np.linalg.norm(agent_position - goal_position, 2)
+
+    goal_distance_padding = max(
+        2, 2 ** np.ceil(np.log(max(1e-6, goal_agent_dist)) / np.log(2))
+    )
+    movement_scale = 1.0 / goal_distance_padding
+    half_res = resolution_px // 2
+    im_position = np.full(
+        (resolution_px, resolution_px, 3), 255, dtype=np.uint8
+    )
+
+    # Draw bands:
+    for scale, color in zip(target_band_radii, target_band_colors):
+        if goal_distance_padding * 4 > scale:
+            cv2.circle(
+                im_position,
+                (half_res, half_res),
+                max(2, int(half_res * scale * movement_scale)),
+                color,
+                thickness=-1,
+            )
+
+    # Draw such that the agent being inside the radius is the circles
+    # overlapping.
+    cv2.circle(
+        im_position,
+        (half_res, half_res),
+        max(2, int(half_res * goal_radius * movement_scale)),
+        (127, 0, 0),
+        thickness=-1,
+    )
+
+    relative_position = agent_position - goal_position
+    # swap x and z, remove y for (x,y,z) -> image coordinates.
+    relative_position = relative_position[[2, 0]]
+    relative_position *= half_res * movement_scale
+    relative_position += half_res
+    relative_position = np.round(relative_position).astype(np.int32)
+
+    # Draw the agent
+    draw_agent(im_position, relative_position, agent_heading, agent_radius_px)
+
+    # Rotate twice to fix coordinate system to upwards being positive-z.
+    # Rotate instead of flip to keep agent rotations in sync with egocentric
+    # view.
+    im_position = np.rot90(im_position, 2)
+    return im_position
+
+
+def _to_grid(
+    realworld_x: float,
+    realworld_y: float,
+    coordinate_min: float,
+    coordinate_max: float,
+    grid_resolution: Tuple[int, int],
+) -> Tuple[int, int]:
+    """Return gridworld index of realworld coordinates assuming top-left corner
+    is the origin. The real world coordinates of lower left corner are
+    (coordinate_min, coordinate_min) and of top right corner are
+    (coordinate_max, coordinate_max)
+    """
+    grid_size = (
+        (coordinate_max - coordinate_min) / grid_resolution[0],
+        (coordinate_max - coordinate_min) / grid_resolution[1],
+    )
+    grid_x = int((coordinate_max - realworld_x) / grid_size[0])
+    grid_y = int((realworld_y - coordinate_min) / grid_size[1])
+    return grid_x, grid_y
+
+
+def _from_grid(
+    grid_x: int,
+    grid_y: int,
+    coordinate_min: float,
+    coordinate_max: float,
+    grid_resolution: Tuple[int, int],
+) -> Tuple[float, float]:
+    """Inverse of _to_grid function. Return real world coordinate from gridworld
+    assuming top-left corner is the origin. The real world coordinates of lower
+    left corner are (coordinate_min, coordinate_min) and of top right corner
+    are (coordinate_max, coordinate_max)
+    """
+    grid_size = (
+        (coordinate_max - coordinate_min) / grid_resolution[0],
+        (coordinate_max - coordinate_min) / grid_resolution[1],
+    )
+    realworld_x = coordinate_max - grid_x * grid_size[0]
+    realworld_y = coordinate_min + grid_y * grid_size[1]
+    return realworld_x, realworld_y
+
+
+def _check_valid_nav_point(sim: Simulator, point: List[float]) -> bool:
+    """Checks if a given point is inside a wall or other object or not."""
+    return (
+        0.01
+        < sim.geodesic_distance(point, [point[0], point[1] + 0.1, point[2]])
+        < 0.11
+    )
+
+
+def _outline_border(top_down_map):
+    left_right_block_nav = (top_down_map[:, :-1] == 1) & (
+        top_down_map[:, :-1] != top_down_map[:, 1:]
+    )
+    left_right_nav_block = (top_down_map[:, 1:] == 1) & (
+        top_down_map[:, :-1] != top_down_map[:, 1:]
+    )
+
+    up_down_block_nav = (top_down_map[:-1] == 1) & (
+        top_down_map[:-1] != top_down_map[1:]
+    )
+    up_down_nav_block = (top_down_map[1:] == 1) & (
+        top_down_map[:-1] != top_down_map[1:]
+    )
+
+    top_down_map[:, :-1][left_right_block_nav] = 2
+    top_down_map[:, 1:][left_right_nav_block] = 2
+
+    top_down_map[:-1][up_down_block_nav] = 2
+    top_down_map[1:][up_down_nav_block] = 2
+
+
+def get_topdown_map(
+    sim: Simulator,
+    map_resolution: Tuple[int, int] = (1250, 1250),
+    num_samples: int = 20000,
+    draw_border: bool = True,
+) -> np.ndarray:
+    """Return a top-down occupancy map for a sim. Note, this only returns valid
+    values for whatever floor the agent is currently on.
+
+    Args:
+        sim: The simulator.
+        map_resolution: The resolution of map which will be computed and returned.
+        num_samples: The number of random navigable points which will be initially
+            sampled. For large environments it may need to be increased.
+        draw_border: Whether to outline the border of the occupied spaces.
+
+    Returns:
+        Image containing 0 if occupied, 1 if unoccupied, and 2 if border (if
+        the flag is set).
+    """
+    top_down_map = np.zeros(map_resolution, dtype=np.uint8)
+    grid_delta = 3
+
+    start_height = sim.get_agent_state().position[1]
+
+    # Use sampling to find the extrema points that might be navigable.
+    for _ in range(num_samples):
+        point = sim.sample_navigable_point()
+        # Check if on same level as original
+        if np.abs(start_height - point[1]) > 0.5:
+            continue
+        g_x, g_y = _to_grid(
+            point[0], point[2], COORDINATE_MIN, COORDINATE_MAX, map_resolution
+        )
+
+        top_down_map[g_x, g_y] = 1
+
+    range_x = np.where(np.any(top_down_map, axis=1))[0]
+    range_y = np.where(np.any(top_down_map, axis=0))[0]
+    # Pad the range just in case not enough points were sampled to get the true
+    # extrema.
+    padding = int(np.ceil(map_resolution[0] / 125))
+    range_x = (
+        max(range_x[0] - padding, 0),
+        min(range_x[-1] + padding + 1, top_down_map.shape[0]),
+    )
+    range_y = (
+        max(range_y[0] - padding, 0),
+        min(range_y[-1] + padding + 1, top_down_map.shape[1]),
+    )
+    top_down_map[:] = 0
+    # Search over grid for valid points.
+    for ii in range(range_x[0], range_x[1]):
+        for jj in range(range_y[0], range_y[1]):
+            realworld_x, realworld_y = _from_grid(
+                ii, jj, COORDINATE_MIN, COORDINATE_MAX, map_resolution
+            )
+            valid_point = _check_valid_nav_point(
+                sim, [realworld_x, start_height + 0.5, realworld_y]
+            )
+            if valid_point:
+                top_down_map[ii, jj] = 1
+
+    # Draw border if necessary
+    if draw_border:
+        # Recompute range in case padding added any more values.
+        range_x = np.where(np.any(top_down_map, axis=1))[0]
+        range_y = np.where(np.any(top_down_map, axis=0))[0]
+        range_x = (
+            max(range_x[0] - grid_delta, 0),
+            min(range_x[-1] + grid_delta + 1, top_down_map.shape[0]),
+        )
+        range_y = (
+            max(range_y[0] - grid_delta, 0),
+            min(range_y[-1] + grid_delta + 1, top_down_map.shape[1]),
+        )
+
+        _outline_border(
+            top_down_map[range_x[0] : range_x[1], range_y[0] : range_y[1]]
+        )
+    return top_down_map
diff --git a/habitat/utils/visualizations/utils.py b/habitat/utils/visualizations/utils.py
new file mode 100644
index 000000000..b4d5980bd
--- /dev/null
+++ b/habitat/utils/visualizations/utils.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+from typing import Tuple, Optional
+
+
+def paste_overlapping_image(
+    background: np.ndarray,
+    foreground: np.ndarray,
+    location: Tuple[int, int],
+    mask: Optional[np.ndarray] = None,
+):
+    """Composites the foreground onto the background dealing with edge
+    boundaries.
+    Args:
+        background: the background image to paste on.
+        foreground: the image to paste. Can be RGB or RGBA. If using alpha
+            blending, values for foreground and background should both be
+            between 0 and 255. Otherwise behavior is undefined.
+        location: the image coordinates to paste the foreground.
+        mask: If not None, a mask for deciding what part of the foreground to
+            use. Must be the same size as the foreground if provided.
+    Returns:
+        The modified background image. This operation is in place.
+    """
+    assert mask is None or mask.shape[:2] == foreground.shape[:2]
+    foreground_size = foreground.shape[:2]
+    min_pad = (
+        max(0, foreground_size[0] // 2 - location[0]),
+        max(0, foreground_size[1] // 2 - location[1]),
+    )
+
+    max_pad = (
+        max(
+            0,
+            (location[0] + (foreground_size[0] - foreground_size[0] // 2))
+            - background.shape[0],
+        ),
+        max(
+            0,
+            (location[1] + (foreground_size[1] - foreground_size[1] // 2))
+            - background.shape[1],
+        ),
+    )
+
+    background_patch = background[
+        (location[0] - foreground_size[0] // 2 + min_pad[0]) : (
+            location[0]
+            + (foreground_size[0] - foreground_size[0] // 2)
+            - max_pad[0]
+        ),
+        (location[1] - foreground_size[1] // 2 + min_pad[1]) : (
+            location[1]
+            + (foreground_size[1] - foreground_size[1] // 2)
+            - max_pad[1]
+        ),
+    ]
+    foreground = foreground[
+        min_pad[0] : foreground.shape[0] - max_pad[0],
+        min_pad[1] : foreground.shape[1] - max_pad[1],
+    ]
+    if foreground.size == 0 or background_patch.size == 0:
+        # Nothing to do, no overlap.
+        return background
+
+    if mask is not None:
+        mask = mask[
+            min_pad[0] : foreground.shape[0] - max_pad[0],
+            min_pad[1] : foreground.shape[1] - max_pad[1],
+        ]
+
+    if foreground.shape[2] == 4:
+        # Alpha blending
+        foreground = (
+            background_patch.astype(np.int32) * (255 - foreground[:, :, [3]])
+            + foreground[:, :, :3].astype(np.int32) * foreground[:, :, [3]]
+        ) // 255
+    if mask is not None:
+        background_patch[mask] = foreground[mask]
+    else:
+        background_patch[:] = foreground
+    return background
diff --git a/requirements.txt b/requirements.txt
index fe169669a..f73903283 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,7 @@
-h5py
 gym==0.10.9
 numpy==1.15
 yacs>=0.1.5
+# visualization optional dependencies
+imageio>=2.2.0
+opencv-python>=3.3.0
+scipy>=1.0.0
diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py
new file mode 100644
index 000000000..f849669e6
--- /dev/null
+++ b/test/test_baseline_agents.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from baselines.agents import simple_agents, ppo_agents
+import habitat
+import os
+import pytest
+
+CFG_TEST = "test/habitat_all_sensors_test.yaml"
+
+
+def test_ppo_agents():
+    config = ppo_agents.get_defaut_config()
+    config.MODEL_PATH = ""
+    config_env = habitat.get_config(config_file=CFG_TEST)
+    config_env.defrost()
+    if not os.path.exists(config_env.SIMULATOR.SCENE):
+        pytest.skip("Please download Habitat test data to data folder.")
+
+    benchmark = habitat.Benchmark(config_file=CFG_TEST, config_dir="configs")
+
+    for input_type in ["blind", "rgb", "depth", "rgbd"]:
+        config_env.defrost()
+        config_env.SIMULATOR.AGENT_0.SENSORS = []
+        if input_type in ["rgb", "rgbd"]:
+            config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"]
+        if input_type in ["depth", "rgbd"]:
+            config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"]
+        config_env.freeze()
+        del benchmark._env
+        benchmark._env = habitat.Env(config=config_env)
+        config.INPUT_TYPE = input_type
+
+        agent = ppo_agents.PPOAgent(config)
+        habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))
+
+
+def test_simple_agents():
+    config_env = habitat.get_config(config_file=CFG_TEST)
+
+    if not os.path.exists(config_env.SIMULATOR.SCENE):
+        pytest.skip("Please download Habitat test data to data folder.")
+
+    benchmark = habitat.Benchmark(config_file=CFG_TEST, config_dir="configs")
+
+    for agent_class in [
+        simple_agents.ForwardOnlyAgent,
+        simple_agents.GoalFollower,
+        simple_agents.RandomAgent,
+        simple_agents.RandomForwardAgent,
+    ]:
+        agent = agent_class(config_env)
+        habitat.logger.info(agent_class.__name__)
+        habitat.logger.info(benchmark.evaluate(agent, num_episodes=100))
diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py
index 75e3cb812..72d9a2380 100644
--- a/test/test_habitat_env.py
+++ b/test/test_habitat_env.py
@@ -19,7 +19,7 @@ from habitat.sims.habitat_simulator import (
     SIM_ACTION_TO_NAME,
     SIM_NAME_TO_ACTION,
 )
-from habitat.tasks.nav.nav_task import NavigationEpisode
+from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal
 
 CFG_TEST = "test/habitat_all_sensors_test.yaml"
 NUM_ENVS = 2
@@ -84,8 +84,6 @@ def _vec_env_test_fn(configs, datasets, multiprocessing_start_method):
         observations = envs.step(np.random.choice(non_stop_actions, num_envs))
         assert len(observations) == num_envs
 
-    envs.close()
-
 
 def test_vectorized_envs_forkserver():
     configs, datasets = _load_test_data()
@@ -112,6 +110,18 @@ def test_vectorized_envs_fork():
     assert p.exitcode == 0
 
 
+def test_with_scope():
+    configs, datasets = _load_test_data()
+    num_envs = len(configs)
+    env_fn_args = tuple(zip(configs, datasets, range(num_envs)))
+    with habitat.VectorEnv(
+        env_fn_args=env_fn_args, multiprocessing_start_method="forkserver"
+    ) as envs:
+        envs.reset()
+
+    assert envs._is_closed
+
+
 def test_threaded_vectorized_env():
     configs, datasets = _load_test_data()
     num_envs = len(configs)
@@ -140,9 +150,10 @@ def test_env():
         NavigationEpisode(
             episode_id="0",
             scene_id=config.SIMULATOR.SCENE,
-            start_position=[03.00611, 0.072447, -2.67867],
+            start_position=[3.00611, 0.072447, -2.67867],
             start_rotation=[0, 0.163276, 0, 0.98658],
-            goals=[],
+            goals=[NavigationGoal([3.00611, 0.072447, -2.67867])],
+            info={"geodesic_distance": 0.001},
         )
     ]
 
@@ -219,9 +230,10 @@ def test_rl_env():
         NavigationEpisode(
             episode_id="0",
             scene_id=config.SIMULATOR.SCENE,
-            start_position=[03.00611, 0.072447, -2.67867],
+            start_position=[3.00611, 0.072447, -2.67867],
             start_rotation=[0, 0.163276, 0, 0.98658],
-            goals=[],
+            goals=[NavigationGoal([3.00611, 0.072447, -2.67867])],
+            info={"geodesic_distance": 0.001},
         )
     ]
 
@@ -276,6 +288,8 @@ def test_action_space_shortest_path():
 
     while len(unreachable_targets) < 3:
         position = env.sim.sample_navigable_point()
+        # Change height of the point to make it unreachable
+        position[1] = 100
         angles = [x for x in range(-180, 180, config.SIMULATOR.TURN_ANGLE)]
         angle = np.radians(np.random.choice(angles))
         rotation = [0, np.sin(angle / 2), 0, np.cos(angle / 2)]
@@ -289,26 +303,4 @@ def test_action_space_shortest_path():
     targets = unreachable_targets
     shortest_path2 = env.sim.action_space_shortest_path(source, targets)
     assert shortest_path2 == []
-
-    targets = reachable_targets + unreachable_targets
-    shortest_path3 = env.sim.action_space_shortest_path(source, targets)
-
-    # shortest_path1 should be identical to shortest_path3
-    assert len(shortest_path1) == len(shortest_path3)
-    for i in range(len(shortest_path1)):
-        assert np.allclose(
-            shortest_path1[i].position, shortest_path3[i].position
-        )
-        assert np.allclose(
-            shortest_path1[i].rotation, shortest_path3[i].rotation
-        )
-        assert shortest_path1[i].action == shortest_path3[i].action
-
-    targets = unreachable_targets + [source]
-    shortest_path4 = env.sim.action_space_shortest_path(source, targets)
-    assert len(shortest_path4) == 1
-    assert np.allclose(shortest_path4[0].position, source.position)
-    assert np.allclose(shortest_path4[0].rotation, source.rotation)
-    assert shortest_path4[0].action is None
-
     env.close()
diff --git a/test/test_habitat_example.py b/test/test_habitat_example.py
index 11f193446..227014d60 100644
--- a/test/test_habitat_example.py
+++ b/test/test_habitat_example.py
@@ -9,6 +9,7 @@ import pytest
 import habitat
 from examples.example import example
 from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1
+from examples import visualization_examples
 
 
 def test_readme_example():
@@ -17,3 +18,11 @@ def test_readme_example():
     ):
         pytest.skip("Please download Habitat test data to data folder.")
     example()
+
+
+def test_visualizations_example():
+    if not PointNavDatasetV1.check_config_paths_exist(
+        config=habitat.get_config().DATASET
+    ):
+        pytest.skip("Please download Habitat test data to data folder.")
+    visualization_examples.main()
diff --git a/test/test_sensors.py b/test/test_sensors.py
new file mode 100644
index 000000000..d6f8d1b50
--- /dev/null
+++ b/test/test_sensors.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import os
+import pytest
+import random
+
+import habitat
+from habitat.config.default import get_config
+from habitat.tasks.nav.nav_task import NavigationEpisode
+
+CFG_TEST = "test/habitat_all_sensors_test.yaml"
+
+
+def test_heading_sensor():
+    config = get_config(CFG_TEST)
+    if not os.path.exists(config.SIMULATOR.SCENE):
+        pytest.skip("Please download Habitat test data to data folder.")
+    config = get_config()
+    config.defrost()
+    config.TASK.SENSORS = ["HEADING_SENSOR"]
+    config.freeze()
+    env = habitat.Env(config=config, dataset=None)
+    env.reset()
+    random.seed(1234)
+
+    for _ in range(100):
+        random_heading = np.random.uniform(-np.pi, np.pi)
+        random_rotation = [
+            0,
+            np.sin(random_heading / 2),
+            0,
+            np.cos(random_heading / 2),
+        ]
+        env.episodes = [
+            NavigationEpisode(
+                episode_id="0",
+                scene_id=config.SIMULATOR.SCENE,
+                start_position=[03.00611, 0.072447, -2.67867],
+                start_rotation=random_rotation,
+                goals=[],
+            )
+        ]
+
+        obs = env.reset()
+        heading = obs["heading"]
+        assert np.allclose(heading, random_heading)
+
+    env.close()
diff --git a/test/test_trajectory_sim.py b/test/test_trajectory_sim.py
index e50384d09..9123cbf11 100644
--- a/test/test_trajectory_sim.py
+++ b/test/test_trajectory_sim.py
@@ -21,7 +21,7 @@ def init_sim():
     return make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)
 
 
-def test_sim():
+def test_sim_trajectory():
     with open("test/data/habitat-sim_trajectory_data.json", "r") as f:
         test_trajectory = json.load(f)
     sim = init_sim()
@@ -61,3 +61,14 @@ def test_sim():
             assert sim.is_episode_active is False
 
     sim.close()
+
+
+def test_sim_no_sensors():
+    config = get_config()
+    config.defrost()
+    config.SIMULATOR.AGENT_0.SENSORS = []
+    if not os.path.exists(config.SIMULATOR.SCENE):
+        pytest.skip("Please download Habitat test data to data folder.")
+    sim = make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR)
+    sim.reset()
+    sim.close()
-- 
GitLab