From 8da461f537d950fe450aa14d0070d0816c789605 Mon Sep 17 00:00:00 2001 From: Facebook Community Bot <facebook-github-bot@users.noreply.github.com> Date: Wed, 3 Apr 2019 04:25:15 -0700 Subject: [PATCH] Benchmark and challenge implementations; PPO agents, Random, ForwardOnly, GoalFollower agents; High resolution top down map and bird-eye visualizations; Collisions measure introduced; SPL success criteria fix; Added deleter that closes subprocesses --- Dockerfile | 5 +- baselines/agents/__init__.py | 0 baselines/agents/ppo_agents.py | 140 ++++++++ baselines/agents/simple_agents.py | 155 +++++++++ baselines/rl/ppo/__init__.py | 6 +- baselines/rl/ppo/policy.py | 128 ++++--- configs/test/habitat_all_sensors_test.yaml | 17 + examples/visualization_examples.py | 107 ++++++ habitat/__init__.py | 2 + habitat/config/default.py | 5 + habitat/core/benchmark.py | 11 +- habitat/core/challenge.py | 21 ++ habitat/core/env.py | 12 +- habitat/core/simulator.py | 1 + habitat/core/vector_env.py | 17 + habitat/sims/habitat_simulator.py | 7 + habitat/tasks/nav/nav_task.py | 56 +++- habitat/utils/__init__.py | 7 + habitat/utils/visualizations/__init__.py | 9 + .../maps_topdown_agent_sprite/100x100.png | Bin 0 -> 5616 bytes habitat/utils/visualizations/maps.py | 314 ++++++++++++++++++ habitat/utils/visualizations/utils.py | 86 +++++ requirements.txt | 5 +- test/test_baseline_agents.py | 57 ++++ test/test_habitat_env.py | 50 ++- test/test_habitat_example.py | 9 + test/test_sensors.py | 53 +++ test/test_trajectory_sim.py | 13 +- 28 files changed, 1191 insertions(+), 102 deletions(-) create mode 100644 baselines/agents/__init__.py create mode 100644 baselines/agents/ppo_agents.py create mode 100644 baselines/agents/simple_agents.py create mode 100644 examples/visualization_examples.py create mode 100644 habitat/core/challenge.py create mode 100644 habitat/utils/__init__.py create mode 100644 habitat/utils/visualizations/__init__.py create mode 100644 habitat/utils/visualizations/assets/maps_topdown_agent_sprite/100x100.png create mode 100644 habitat/utils/visualizations/maps.py create mode 100644 habitat/utils/visualizations/utils.py create mode 100644 test/test_baseline_agents.py create mode 100644 test/test_sensors.py diff --git a/Dockerfile b/Dockerfile index 211bfb00b..d8dcc0b5b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,9 +31,9 @@ RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-la ENV PATH /opt/conda/bin:$PATH # Install cmake -ADD cmake-3.13.3-Linux-x86_64.sh /cmake-3.13.3-Linux-x86_64.sh +RUN wget https://github.com/Kitware/CMake/releases/download/v3.13.4/cmake-3.13.4-Linux-x86_64.sh RUN mkdir /opt/cmake -RUN sh /cmake-3.13.3-Linux-x86_64.sh --prefix=/opt/cmake --skip-license +RUN sh /cmake-3.13.4-Linux-x86_64.sh --prefix=/opt/cmake --skip-license RUN ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake RUN cmake --version @@ -56,4 +56,3 @@ RUN rm habitat-test-scenes.zip # Silence habitat-sim logs ENV GLOG_minloglevel=2 ENV MAGNUM_LOG="quiet" - diff --git a/baselines/agents/__init__.py b/baselines/agents/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/baselines/agents/ppo_agents.py b/baselines/agents/ppo_agents.py new file mode 100644 index 000000000..9a0080bed --- /dev/null +++ b/baselines/agents/ppo_agents.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import random + +import numpy as np +import torch +from gym.spaces import Discrete, Dict, Box + +import habitat +from baselines.rl.ppo import Policy +from baselines.rl.ppo.utils import batch_obs +from habitat import Config +from habitat.core.agent import Agent + + +def get_defaut_config(): + c = Config() + c.INPUT_TYPE = "blind" + c.MODEL_PATH = "data/checkpoints/blind.pth" + c.RESOLUTION = 256 + c.HIDDEN_SIZE = 512 + c.RANDOM_SEED = 7 + c.PTH_GPU_ID = 0 + return c + + +class PPOAgent(Agent): + def __init__(self, config: Config): + spaces = { + "pointgoal": Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(2,), + dtype=np.float32, + ) + } + + if config.INPUT_TYPE in ["depth", "rgbd"]: + spaces["depth"] = Box( + low=0, + high=1, + shape=(config.RESOLUTION, config.RESOLUTION, 1), + dtype=np.float32, + ) + + if config.INPUT_TYPE in ["rgb", "rgbd"]: + spaces["rgb"] = Box( + low=0, + high=255, + shape=(config.RESOLUTION, config.RESOLUTION, 3), + dtype=np.uint8, + ) + observation_spaces = Dict(spaces) + + action_spaces = Discrete(4) + + self.device = torch.device("cuda:{}".format(config.PTH_GPU_ID)) + self.hidden_size = config.HIDDEN_SIZE + + random.seed(config.RANDOM_SEED) + torch.random.manual_seed(config.RANDOM_SEED) + torch.backends.cudnn.deterministic = True + + self.actor_critic = Policy( + observation_space=observation_spaces, + action_space=action_spaces, + hidden_size=self.hidden_size, + ) + self.actor_critic.to(self.device) + + if config.MODEL_PATH: + ckpt = torch.load(config.MODEL_PATH, map_location=self.device) + # Filter only actor_critic weights + self.actor_critic.load_state_dict( + { + k.replace("actor_critic.", ""): v + for k, v in ckpt["state_dict"].items() + if "actor_critic" in k + } + ) + + else: + habitat.logger.error( + "Model checkpoint wasn't loaded, evaluating " "a random model." + ) + + self.test_recurrent_hidden_states = None + self.not_done_masks = None + + def reset(self): + self.test_recurrent_hidden_states = torch.zeros( + 1, self.hidden_size, device=self.device + ) + self.not_done_masks = torch.zeros(1, 1, device=self.device) + + def act(self, observations): + batch = batch_obs([observations]) + for sensor in batch: + batch[sensor] = batch[sensor].to(self.device) + + with torch.no_grad(): + _, actions, _, self.test_recurrent_hidden_states = self.actor_critic.act( + batch, + self.test_recurrent_hidden_states, + self.not_done_masks, + deterministic=False, + ) + # Make masks not done till reset (end of episode) will be called + self.not_done_masks = torch.ones(1, 1, device=self.device) + + return actions[0][0].item() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_type", + default="blind", + choices=["blind", "rgb", "depth", "rgbd"], + ) + parser.add_argument("--model_path", default="", type=str) + args = parser.parse_args() + + config = get_defaut_config() + config.INPUT_TYPE = args.input_type + config.MODEL_PATH = args.model_path + + agent = PPOAgent(config) + challenge = habitat.Challenge() + challenge.submit(agent) + + +if __name__ == "__main__": + main() diff --git a/baselines/agents/simple_agents.py b/baselines/agents/simple_agents.py new file mode 100644 index 000000000..6e83cd0c9 --- /dev/null +++ b/baselines/agents/simple_agents.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import random +from math import pi + +import numpy as np + +import habitat +from habitat.sims.habitat_simulator import ( + SimulatorActions, + SIM_ACTION_TO_NAME, + SIM_NAME_TO_ACTION, +) + +NON_STOP_ACTIONS = [ + k + for k, v in SIM_ACTION_TO_NAME.items() + if v != SimulatorActions.STOP.value +] + + +class RandomAgent(habitat.Agent): + def __init__(self, config): + self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE + + def reset(self): + pass + + def act(self, observations): + action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value] + return action + + def is_goal_reached(self, observations): + dist = observations["pointgoal"][0] + return dist <= self.dist_threshold_to_stop + + def act(self, observations): + if self.is_goal_reached(observations): + action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value] + else: + action = np.random.choice(NON_STOP_ACTIONS) + return action + + +class ForwardOnlyAgent(RandomAgent): + def act(self, observations): + if self.is_goal_reached(observations): + action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value] + else: + action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value] + return action + + +class RandomForwardAgent(RandomAgent): + def __init__(self, config): + super(RandomForwardAgent, self).__init__(config) + self.dist_threshold_to_stop = config.TASK.SUCCESS_DISTANCE + self.FORWARD_PROBABILITY = 0.8 + + def act(self, observations): + if self.is_goal_reached(observations): + action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value] + else: + if np.random.uniform(0, 1, 1) < self.FORWARD_PROBABILITY: + action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value] + else: + action = np.random.choice( + [ + SIM_NAME_TO_ACTION[SimulatorActions.LEFT.value], + SIM_NAME_TO_ACTION[SimulatorActions.RIGHT.value], + ] + ) + + return action + + +class GoalFollower(RandomAgent): + def __init__(self, config): + super(GoalFollower, self).__init__(config) + self.pos_th = self.dist_threshold_to_stop + self.angle_th = float(np.deg2rad(15)) + self.random_prob = 0 + + def normalize_angle(self, angle): + if angle < -pi: + angle = 2.0 * pi + angle + if angle > pi: + angle = -2.0 * pi + angle + return angle + + def turn_towards_goal(self, angle_to_goal): + if angle_to_goal > pi or ( + (angle_to_goal < 0) and (angle_to_goal > -pi) + ): + action = SIM_NAME_TO_ACTION[SimulatorActions.RIGHT.value] + else: + action = SIM_NAME_TO_ACTION[SimulatorActions.LEFT.value] + return action + + def act(self, observations): + if self.is_goal_reached(observations): + action = SIM_NAME_TO_ACTION[SimulatorActions.STOP.value] + else: + angle_to_goal = self.normalize_angle( + np.array(observations["pointgoal"][1]) + ) + if abs(angle_to_goal) < self.angle_th: + action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value] + else: + action = self.turn_towards_goal(angle_to_goal) + + return action + + +def get_all_subclasses(cls): + return set(cls.__subclasses__()).union( + [s for c in cls.__subclasses__() for s in get_all_subclasses(c)] + ) + + +def get_agent_cls(agent_class_name): + sub_classes = [ + sub_class + for sub_class in get_all_subclasses(habitat.Agent) + if sub_class.__name__ == agent_class_name + ] + return sub_classes[0] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--task-config", type=str, default="tasks/pointnav.yaml" + ) + parser.add_argument("--agent_class", type=str, default="GoalFollower") + args = parser.parse_args() + + agent = get_agent_cls(args.agent_class)( + habitat.get_config(args.task_config) + ) + benchmark = habitat.Benchmark(args.task_config) + metrics = benchmark.evaluate(agent) + + for k, v in metrics.items(): + habitat.logger.info("{}: {:.3f}".format(k, v)) + + +if __name__ == "__main__": + main() diff --git a/baselines/rl/ppo/__init__.py b/baselines/rl/ppo/__init__.py index 843a232a1..248f2b30f 100644 --- a/baselines/rl/ppo/__init__.py +++ b/baselines/rl/ppo/__init__.py @@ -4,8 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from rl.ppo.ppo import PPO -from rl.ppo.policy import Policy -from rl.ppo.utils import RolloutStorage +from baselines.rl.ppo.ppo import PPO +from baselines.rl.ppo.policy import Policy +from baselines.rl.ppo.utils import RolloutStorage __all__ = ["PPO", "Policy", "RolloutStorage"] diff --git a/baselines/rl/ppo/policy.py b/baselines/rl/ppo/policy.py index dea82476c..2c85d8246 100644 --- a/baselines/rl/ppo/policy.py +++ b/baselines/rl/ppo/policy.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from rl.ppo.utils import Flatten, CategoricalNet +from baselines.rl.ppo.utils import Flatten, CategoricalNet import numpy as np @@ -67,20 +67,34 @@ class Net(nn.Module): def __init__(self, observation_space, hidden_size): super().__init__() + self._n_input_goal = observation_space.spaces["pointgoal"].shape[0] + self._hidden_size = hidden_size + + self.cnn = self._init_perception_model(observation_space) + + if self.is_blind: + self.rnn = nn.GRU(self._n_input_goal, self._hidden_size) + else: + self.rnn = nn.GRU( + self.output_size + self._n_input_goal, self._hidden_size + ) + + self.critic_linear = nn.Linear(self._hidden_size, 1) + + self.layer_init() + self.train() + + def _init_perception_model(self, observation_space): if "rgb" in observation_space.spaces: self._n_input_rgb = observation_space.spaces["rgb"].shape[2] else: self._n_input_rgb = 0 + if "depth" in observation_space.spaces: self._n_input_depth = observation_space.spaces["depth"].shape[2] else: self._n_input_depth = 0 - assert self._n_input_rgb + self._n_input_depth > 0 - - self._n_input_goal = observation_space.spaces["pointgoal"].shape[0] - self._hidden_size = hidden_size - # kernel size for different CNN layers self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)] @@ -91,52 +105,50 @@ class Net(nn.Module): cnn_dims = np.array( observation_space.spaces["rgb"].shape[:2], dtype=np.float32 ) - else: + elif self._n_input_depth > 0: cnn_dims = np.array( observation_space.spaces["depth"].shape[:2], dtype=np.float32 ) - for kernel_size, stride in zip( - self._cnn_layers_kernel_size, self._cnn_layers_stride - ): - cnn_dims = self._conv_output_dim( - dimension=cnn_dims, - padding=np.array([0, 0], dtype=np.float32), - dilation=np.array([1, 1], dtype=np.float32), - kernel_size=np.array(kernel_size, dtype=np.float32), - stride=np.array(stride, dtype=np.float32), - ) - - self.cnn = nn.Sequential( - nn.Conv2d( - in_channels=self._n_input_rgb + self._n_input_depth, - out_channels=32, - kernel_size=self._cnn_layers_kernel_size[0], - stride=self._cnn_layers_stride[0], - ), - nn.ReLU(), - nn.Conv2d( - in_channels=32, - out_channels=64, - kernel_size=self._cnn_layers_kernel_size[1], - stride=self._cnn_layers_stride[1], - ), - nn.ReLU(), - nn.Conv2d( - in_channels=64, - out_channels=32, - kernel_size=self._cnn_layers_kernel_size[2], - stride=self._cnn_layers_stride[2], - ), - Flatten(), - nn.Linear(32 * cnn_dims[0] * cnn_dims[1], hidden_size), - nn.ReLU(), - ) - self.rnn = nn.GRU(hidden_size + self._n_input_goal, hidden_size) - self.critic_linear = nn.Linear(hidden_size, 1) + if self.is_blind: + return nn.Sequential() + else: + for kernel_size, stride in zip( + self._cnn_layers_kernel_size, self._cnn_layers_stride + ): + cnn_dims = self._conv_output_dim( + dimension=cnn_dims, + padding=np.array([0, 0], dtype=np.float32), + dilation=np.array([1, 1], dtype=np.float32), + kernel_size=np.array(kernel_size, dtype=np.float32), + stride=np.array(stride, dtype=np.float32), + ) - self.layer_init() - self.train() + return nn.Sequential( + nn.Conv2d( + in_channels=self._n_input_rgb + self._n_input_depth, + out_channels=32, + kernel_size=self._cnn_layers_kernel_size[0], + stride=self._cnn_layers_stride[0], + ), + nn.ReLU(), + nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=self._cnn_layers_kernel_size[1], + stride=self._cnn_layers_stride[1], + ), + nn.ReLU(), + nn.Conv2d( + in_channels=64, + out_channels=32, + kernel_size=self._cnn_layers_kernel_size[2], + stride=self._cnn_layers_stride[2], + ), + Flatten(), + nn.Linear(32 * cnn_dims[0] * cnn_dims[1], self._hidden_size), + nn.ReLU(), + ) def _conv_output_dim( self, dimension, padding, dilation, kernel_size, stride @@ -240,28 +252,36 @@ class Net(nn.Module): return x, hidden_states - def forward(self, observations, rnn_hidden_states, masks): + @property + def is_blind(self): + return self._n_input_rgb + self._n_input_depth == 0 + + def forward_perception_model(self, observations): cnn_input = [] if self._n_input_rgb > 0: rgb_observations = observations["rgb"] # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] rgb_observations = rgb_observations.permute(0, 3, 1, 2) rgb_observations = rgb_observations / 255.0 # normalize RGB - cnn_input.append(rgb_observations) + if self._n_input_depth > 0: depth_observations = observations["depth"] - # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] depth_observations = depth_observations.permute(0, 3, 1, 2) - cnn_input.append(depth_observations) cnn_input = torch.cat(cnn_input, dim=1) - goal_observations = observations["pointgoal"] - x = self.cnn(cnn_input) - x = torch.cat([x, goal_observations], dim=1) # concatenate goal vector + return self.cnn(cnn_input) + + def forward(self, observations, rnn_hidden_states, masks): + x = observations["pointgoal"] + + if not self.is_blind: + perception_embed = self.forward_perception_model(observations) + x = torch.cat([perception_embed, x], dim=1) + x, rnn_hidden_states = self.forward_rnn(x, rnn_hidden_states, masks) return self.critic_linear(x), x, rnn_hidden_states diff --git a/configs/test/habitat_all_sensors_test.yaml b/configs/test/habitat_all_sensors_test.yaml index 79375eef1..1ea4783fb 100644 --- a/configs/test/habitat_all_sensors_test.yaml +++ b/configs/test/habitat_all_sensors_test.yaml @@ -3,8 +3,25 @@ ENVIRONMENT: SIMULATOR: AGENT_0: SENSORS: ['RGB_SENSOR', 'DEPTH_SENSOR'] + RGB_SENSOR: + WIDTH: 256 + HEIGHT: 256 + DEPTH_SENSOR: + WIDTH: 256 + HEIGHT: 256 DATASET: TYPE: PointNav-v1 SPLIT: train POINTNAVV1: DATA_PATH: data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz +TASK: + TYPE: Nav-v0 + SUCCESS_DISTANCE: 0.2 + SENSORS: ['POINTGOAL_SENSOR'] + POINTGOAL_SENSOR: + TYPE: PointGoalSensor + GOAL_FORMAT: POLAR + MEASUREMENTS: ['SPL'] + SPL: + TYPE: SPL + SUCCESS_DISTANCE: 0.2 \ No newline at end of file diff --git a/examples/visualization_examples.py b/examples/visualization_examples.py new file mode 100644 index 000000000..ab5e4d281 --- /dev/null +++ b/examples/visualization_examples.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import imageio + +import habitat +from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal +from habitat.utils.visualizations import maps + + +def example_pointnav_draw_target_birdseye_view(): + goal_radius = 0.5 + goal = NavigationGoal([10, 0.25, 10], goal_radius) + agent_position = np.array([0, 0.25, 0]) + agent_rotation = -np.pi / 4 + + dummy_episode = NavigationEpisode( + [goal], + episode_id="dummy_id", + scene_id="dummy_scene", + start_position=agent_position, + start_rotation=agent_rotation, + ) + target_image = maps.pointnav_draw_target_birdseye_view( + agent_position, + agent_rotation, + np.asarray(dummy_episode.goals[0].position), + goal_radius=dummy_episode.goals[0].radius, + agent_radius_px=25, + ) + + imageio.imsave("pointnav_target_image.png", target_image) + + +def example_pointnav_draw_target_birdseye_view_agent_on_border(): + goal_radius = 0.5 + goal = NavigationGoal([0, 0.25, 0], goal_radius) + ii = 0 + for x_edge in [-1, 0, 1]: + for y_edge in [-1, 0, 1]: + if not np.bitwise_xor(x_edge == 0, y_edge == 0): + continue + ii += 1 + agent_position = np.array([7.8 * x_edge, 0.25, 7.8 * y_edge]) + agent_rotation = np.pi / 2 + + dummy_episode = NavigationEpisode( + [goal], + episode_id="dummy_id", + scene_id="dummy_scene", + start_position=agent_position, + start_rotation=agent_rotation, + ) + target_image = maps.pointnav_draw_target_birdseye_view( + agent_position, + agent_rotation, + np.asarray(dummy_episode.goals[0].position), + goal_radius=dummy_episode.goals[0].radius, + agent_radius_px=25, + ) + imageio.imsave( + "pointnav_target_image_edge_%d.png" % ii, target_image + ) + + +def example_get_topdown_map(): + config = habitat.get_config(config_file="tasks/pointnav.yaml") + dataset = habitat.make_dataset( + id_dataset=config.DATASET.TYPE, config=config.DATASET + ) + env = habitat.Env(config=config, dataset=dataset) + env.reset() + top_down_map = maps.get_topdown_map(env.sim, map_resolution=(5000, 5000)) + recolor_map = np.array( + [[255, 255, 255], [128, 128, 128], [0, 0, 0]], dtype=np.uint8 + ) + range_x = np.where(np.any(top_down_map, axis=1))[0] + range_y = np.where(np.any(top_down_map, axis=0))[0] + padding = int(np.ceil(top_down_map.shape[0] / 125)) + range_x = ( + max(range_x[0] - padding, 0), + min(range_x[-1] + padding + 1, top_down_map.shape[0]), + ) + range_y = ( + max(range_y[0] - padding, 0), + min(range_y[-1] + padding + 1, top_down_map.shape[1]), + ) + top_down_map = top_down_map[ + range_x[0] : range_x[1], range_y[0] : range_y[1] + ] + top_down_map = recolor_map[top_down_map] + imageio.imsave("top_down_map.png", top_down_map) + + +def main(): + example_pointnav_draw_target_birdseye_view() + example_get_topdown_map() + example_pointnav_draw_target_birdseye_view_agent_on_border() + + +if __name__ == "__main__": + main() diff --git a/habitat/__init__.py b/habitat/__init__.py index 04d5e3cf7..847fb295c 100644 --- a/habitat/__init__.py +++ b/habitat/__init__.py @@ -6,6 +6,7 @@ from habitat.core.agent import Agent from habitat.core.benchmark import Benchmark +from habitat.core.challenge import Challenge from habitat.config import Config, get_config from habitat.core.dataset import Dataset from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements @@ -19,6 +20,7 @@ from habitat.version import VERSION as __version__ # noqa __all__ = [ "Agent", "Benchmark", + "Challenge", "Config", "Dataset", "EmbodiedTask", diff --git a/habitat/config/default.py b/habitat/config/default.py index 10ac1d886..8bd9189b2 100644 --- a/habitat/config/default.py +++ b/habitat/config/default.py @@ -37,6 +37,11 @@ _C.TASK.POINTGOAL_SENSOR = CN() _C.TASK.POINTGOAL_SENSOR.TYPE = "PointGoalSensor" _C.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "POLAR" # ----------------------------------------------------------------------------- +# # HEADING SENSOR +# ----------------------------------------------------------------------------- +_C.TASK.HEADING_SENSOR = CN() +_C.TASK.HEADING_SENSOR.TYPE = "HeadingSensor" +# ----------------------------------------------------------------------------- # # SPL MEASUREMENT # ----------------------------------------------------------------------------- _C.TASK.SPL = CN() diff --git a/habitat/core/benchmark.py b/habitat/core/benchmark.py index eaf97ad11..cb49a6b81 100644 --- a/habitat/core/benchmark.py +++ b/habitat/core/benchmark.py @@ -7,8 +7,8 @@ from collections import defaultdict from typing import Dict, Optional +from habitat.config.default import get_config, DEFAULT_CONFIG_DIR from habitat.core.agent import Agent -from habitat.config.default import get_config from habitat.core.env import Env @@ -18,10 +18,15 @@ class Benchmark: Args: config_file: file to be used for creating the environment. + config_dir: directory where config_file is located. """ - def __init__(self, config_file: Optional[str] = None) -> None: - config_env = get_config(config_file=config_file) + def __init__( + self, + config_file: Optional[str] = None, + config_dir: str = DEFAULT_CONFIG_DIR, + ) -> None: + config_env = get_config(config_file=config_file, config_dir=config_dir) self._env = Env(config=config_env) def evaluate( diff --git a/habitat/core/challenge.py b/habitat/core/challenge.py new file mode 100644 index 000000000..5094333e5 --- /dev/null +++ b/habitat/core/challenge.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + +from habitat.core.benchmark import Benchmark +from habitat.core.logging import logger + + +class Challenge(Benchmark): + def __init__(self): + config_file = os.environ["CHALLENGE_CONFIG_FILE"] + super().__init__(config_file) + + def submit(self, agent): + metrics = super().evaluate(agent) + for k, v in metrics.items(): + logger.info("{}: {}".format(k, v)) diff --git a/habitat/core/env.py b/habitat/core/env.py index d48695d23..cfa3438c2 100644 --- a/habitat/core/env.py +++ b/habitat/core/env.py @@ -68,6 +68,16 @@ class Env: id_dataset=config.DATASET.TYPE, config=config.DATASET ) self._episodes = self._dataset.episodes if self._dataset else [] + + # load the first scene if dataset is present + if self._dataset: + assert len(self._dataset.episodes) > 0, ( + "dataset should have " "non-empty episodes list" + ) + self._config.defrost() + self._config.SIMULATOR.SCENE = self._dataset.episodes[0].scene_id + self._config.freeze() + self._sim = make_sim( id_sim=self._config.SIMULATOR.TYPE, config=self._config.SIMULATOR ) @@ -75,7 +85,7 @@ class Env: self._config.TASK.TYPE, task_config=self._config.TASK, sim=self._sim, - dataset=dataset, + dataset=self._dataset, ) self.observation_space = SpaceDict( { diff --git a/habitat/core/simulator.py b/habitat/core/simulator.py index 6cd0c1cba..fbeaac1e9 100644 --- a/habitat/core/simulator.py +++ b/habitat/core/simulator.py @@ -28,6 +28,7 @@ class SensorTypes(Enum): TENSOR = 8 TEXT = 9 MEASUREMENT = 10 + HEADING = 11 class Sensor: diff --git a/habitat/core/vector_env.py b/habitat/core/vector_env.py index 0451c8a16..4618e77c7 100644 --- a/habitat/core/vector_env.py +++ b/habitat/core/vector_env.py @@ -83,6 +83,7 @@ class VectorEnv: ) -> None: self._is_waiting = False + self._is_closed = True assert ( env_fn_args is not None and len(env_fn_args) > 0 @@ -103,6 +104,8 @@ class VectorEnv: env_fn_args, make_env_fn ) + self._is_closed = False + for write_fn in self._connection_write_fns: write_fn((OBSERVATION_SPACE_COMMAND, None)) self.observation_spaces = [ @@ -294,6 +297,9 @@ class VectorEnv: return self.wait_step() def close(self) -> None: + if self._is_closed: + return + if self._is_waiting: for read_fn in self._connection_read_fns: read_fn() @@ -302,6 +308,8 @@ class VectorEnv: for process in self._workers: process.join() + self._is_closed = True + def render( self, mode: str = "human", *args, **kwargs ) -> Union[np.ndarray, None]: @@ -326,6 +334,15 @@ class VectorEnv: def _valid_start_methods(self) -> Set[str]: return {"forkserver", "spawn", "fork"} + def __del__(self): + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + class ThreadedVectorEnv(VectorEnv): def _spawn_workers( diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py index d6d29b27e..3a5311454 100644 --- a/habitat/sims/habitat_simulator.py +++ b/habitat/sims/habitat_simulator.py @@ -198,6 +198,13 @@ class HabitatSim(habitat.Simulator): sim_sensor_cfg.sensor_type = sensor.sim_sensor_type # type: ignore sensor_specifications.append(sim_sensor_cfg) + # If there is no sensors specified create a dummy sensor so simulator + # won't throw an error + if not _sensor_suite.sensors.values(): + sim_sensor_cfg = habitat_sim.SensorSpec() + sim_sensor_cfg.resolution = [1, 1] + sensor_specifications.append(sim_sensor_cfg) + agent_config.sensor_specifications = sensor_specifications agent_config.action_space = { SimulatorActions.LEFT.value: habitat_sim.ActionSpec( diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index b0737e4d6..bf44de919 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -21,7 +21,9 @@ from habitat.core.simulator import ( from habitat.tasks.utils import quaternion_to_rotation, cartesian_to_polar -def merge_sim_episode_config(sim_config: Any, episode: Type[Episode]) -> Any: +def merge_sim_episode_config( + sim_config: Config, episode: Type[Episode] +) -> Any: sim_config.defrost() sim_config.SCENE = episode.scene_id sim_config.freeze() @@ -148,7 +150,7 @@ class PointGoalSensor(habitat.Sensor): in cartesian or polar coordinates. """ - def __init__(self, sim, config): + def __init__(self, sim: Simulator, config: Config): self._sim = sim self._goal_format = getattr(config, "GOAL_FORMAT", "CARTESIAN") @@ -199,6 +201,45 @@ class PointGoalSensor(habitat.Sensor): return direction_vector_agent +class HeadingSensor(habitat.Sensor): + """ + Sensor for observing the agent's heading in the global coordinate frame. + + Args: + sim: reference to the simulator for calculating task observations. + config: config for the sensor. + """ + + def __init__(self, sim: Simulator, config: Config): + self._sim = sim + super().__init__(config=config) + + def _get_uuid(self, *args: Any, **kwargs: Any): + return "heading" + + def _get_sensor_type(self, *args: Any, **kwargs: Any): + return SensorTypes.HEADING + + def _get_observation_space(self, *args: Any, **kwargs: Any): + return spaces.Box(low=-np.pi, high=np.pi, shape=(1,), dtype=np.float) + + def get_observation(self, observations, episode): + agent_state = self._sim.get_agent_state() + # Quaternion is in x, y, z, w format + ref_rotation = agent_state.rotation + + direction_vector = np.array([0, 0, -1]) + + rotation_world_agent = quaternion_to_rotation( + ref_rotation[3], ref_rotation[0], ref_rotation[1], ref_rotation[2] + ) + + heading_vector = np.dot(rotation_world_agent.T, direction_vector) + + phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1] + return np.array(phi) + + class SPL(habitat.Measure): """SPL (Success weighted by Path Length) @@ -206,7 +247,7 @@ class SPL(habitat.Measure): https://arxiv.org/pdf/1807.06757.pdf """ - def __init__(self, sim, config): + def __init__(self, sim: Simulator, config: Config): self._previous_position = None self._start_end_episode_distance = None self._agent_episode_distance = None @@ -233,12 +274,13 @@ class SPL(habitat.Measure): ep_success = 0 current_position = self._sim.get_agent_state().position.tolist() + distance_to_target = self._sim.geodesic_distance( + current_position, episode.goals[0].position + ) + if ( action == self._sim.index_stop_action - and self._euclidean_distance( - current_position, episode.goals[0].position - ) - < self._config.SUCCESS_DISTANCE + and distance_to_target < self._config.SUCCESS_DISTANCE ): ep_success = 1 diff --git a/habitat/utils/__init__.py b/habitat/utils/__init__.py new file mode 100644 index 000000000..d1f5f3fbc --- /dev/null +++ b/habitat/utils/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +__all__ = ["visualizations"] diff --git a/habitat/utils/visualizations/__init__.py b/habitat/utils/visualizations/__init__.py new file mode 100644 index 000000000..5f27292dc --- /dev/null +++ b/habitat/utils/visualizations/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from habitat.utils.visualizations import maps + +__all__ = ["maps"] diff --git a/habitat/utils/visualizations/assets/maps_topdown_agent_sprite/100x100.png b/habitat/utils/visualizations/assets/maps_topdown_agent_sprite/100x100.png new file mode 100644 index 0000000000000000000000000000000000000000..94b0f9396ebe80a25ac06a9baec83b462f815804 GIT binary patch literal 5616 zcmeAS@N?(olHy`uVBq!ia0y~yU`PRB4mJh`hJr^^Ll_uDw|lxchE&XXJ2!hmOzO<z z|IdG(_s*_%|CtoE#zc>R$7%s;({y$SDk?|0Wfht_P0<SEZi*?%PPBi)t~+_{qEp<d z3%9;{r5EsYsh7B*Zsg7#2Ut>8il%yoad1a-smw^xd-?lm$@j|p>GoeNCwp2?-nk_I z#J|<g4wuZgd;aD3`{#f6ZvVWe#m?J>gMnE=sp$5-&0R+yv90~8#xUD8WpT$GU8U?! z4P_-m21UV*7pxk&7x=lP3Z4|YXsA@x+7ZW65TN<BMrmS2=EBzu0t{^(OT_P28wT#z z(^1-zf8Wm3(D328>fYj=cPh`+Of-C^BzTkS!5aq!0ZtYs28LvwE!QqY9hP7gZ1h;o z)5+n&lGYZ)<)vw0%&_2Q&Z3(+VT^m6Tz=2|$(ejc_+`njNF}3xwmyD-QMb3}U;k^g z?0H*9hdl4C-}AkeUNTEQ!`^FX%`j)Shl+yIitDdeTz`Ebiko5og^msdB^MX&d;7QC z6FIN3KOs`nbzuO<!T=7JmZXgw4Cf6$H=p@)xx(h(lKhqq4F;q8YvNX<w{4N0csA|0 z?&)70hR?Lp?n&{rKeXVJIo{Z@<KEth-)HvSS^oUO9eKgXZMlcH<zC4$l}P?({Fy)Y z`rkmc$(B!Qmn}b5a#vu!;6A~fF?rg?^`|5@Ys>Yk%N%#y8YRkb%errOt$*(K^b1{f zqMa@$w&hOoTB>AxKey*+n$_HU?`4-?U_MrIRX{NCvr-bbV6E7JbrCwtx{fMsxu*9k z;ZV}2vw>cboi0pIiz1nKM7!*pb25dm;`o``<#{W#+bXV37MOVE%$Yffk=LGgW|XT4 zaf)!UN+g>I7A{}4$j>{<=**f)CskHm%?fsTR}lOCnrZemnU9O+?F#HI{WIOt(lS$U zqN38tw9V#h%?HyTef-gTHqCg#>8B<P2Aicqj2pr{jW;tj#ON)beKt)&uK&IS58Hf@ zYN3e&6KBquv&E|Jt>4@XJMFmIuUE~FCvD7kS$wg4`sq}Y*}l^|Le$q^Pwwn^;h>Ok zG|4bTYij%Hr=OM|KXxqcg-Nv4)v84Mn)LMa^_(nB+uNNMx@?RnVJxwg`ZMXpw%c*q zzUo4ptK-&-&p!K1_weDv<(2OX)mLq}{EyG>NnFjxqZdw`IAQOmH1YMG6&LDmlm|08 zREn-Q+so;g+{4Ghq$Sq9Flz0;(_g-Pxu17u$Hse{59W(G+r+&%+|J+M+tXts)qi}w zwg}g{*Sjuh8|$cvXiFXKYCGBzSF(O)<j!^4txgAfk3YT+Qd#xm!$U)*OUqAB4gJ5s z<%p8f_p9OYwb?du|7Ygje*ClMUh1>H<~+kp<r7DAYu?^dW@wmQKCdz7Wb31Q3r|Tr zUKXTzx^Ry3(x8{j$;rv*f23}n^?9w9(;sd*tCAO!QjC7C2+G<Mpz-AjkDx+^;XjMp z2lm|;-SEDg^-tlp#sx<w$n`lHZVuHDiCKE3!f0mAmkSG>rwdGc&0n$ZM~(L1hWuk2 zlaK59F2DR;Iyq>nm#c8#`7;a*%1s3wAG?<qX-qNYO6`rend)qD#9(g;3xi<fR4-TQ z<e--&zy2j3?~C0P6&#k*ah&J-iHXYM?n)DXZ@DITywFF#!Z5zBM8@{s0sZ;s6h%^6 zntI~x<{rNIcyGVZf}^6x3wy3v33a->vo3ma;;74$byGA?o4L4LIdY_>`p1Wdx*H?j zJXaE0cl~wS-;ZzQPWAV(*G)cLvahp2iA5yQ>i!Yt{V)BeFfbSw?~OB866!meWN%;d z<HNMhPR{GESL?i&Ic<3&JNfm@nKO+aef&{-W=@HW`}e~=J71hI`|#khrHd0!GRrXz zCB<XS{?7|cZqL_Xj9ONaI=Scg<L_tBoSEb3a!cAT>epv)*$4f%J3J=Le9zl{I6YKr z>bx(8VwYk{Y^EOe+1J@{wD|E$#%{yO`HvdyHXac-UhTDX(cy<5BG1&U?_jy8by#(F z{pmtq!HM(c&HERoHFej8>|5WC_WXQP@gj5mg|nSb=F3vAIoGCi>ZgVMthwjAG)Uaf z-@pF!YXiN=yyYfGURUl44i1(-pJH@f&c1HXa^Kr!TQ@7ea6NwK@1oR*{gIElXRdvG zynpvhpJ%sLu3Xt)t-C&W)}|GTadXp`En9Xk;r2K4Jv$rsxwyTq6H9a7&DN^=dC94# zo3`ctmJSXM&SyyQJ32Yom7`L5M){3xve_n5FP=~Mne-;(y23`54AG>;7aMYK8#0tv zr0&YF)XV6w_b@Xy{(N%b!i8_+^!NO~xRJy9mZ^shW0Z}i2-mv{7cP98Idi6e)!KEe zr$77i2nycYUA}&b*HXcySHw?0-Q+00S}P&5Y)kfqh)14NPp2-P@+)a&h}6y9<?D4q zO8<Jk_;JzD)b#7yg$ozH-h6Y-?bnGDC%<RkAfgyw__6%|7xs`WsVvFGyIQV2NH$sX z;j8p{N3E=bSDWh21$TAKKAE!S@$vrp^pcV<ZLMvucfMujIdJCOxjepBr|Zu?S3W-b z<*V=MRQKI%y{ZfjE?V=?&koQidHd$gn=~DUAKxY@DhaLqI?qETNivz|*REZKsc$l_ zFWAV!Aero=BD8VKwQ4b@KXVlSi9ETu*!?^=TeG~DSogB|p`sg#%C=++L^`>IZHzEE znDD{by<g6@zr(L<(oaFB4Bfjwe@Eyn+w<bZ>}j{(M|d&c&~G$6b|lI0=$Y*czV4pc zEyUJfTm9_}uUyo$t*aw+xY(AaZgm2i$mAf}?HcG+dqckF&gOJ}=6S~dmRvq_=1lsf zmtUr{FG~$AEBw6ue?!pQFV9YMHe|DW+;Hf6o(22L8xOYZDB5`^{rZ~7&!YF4_n*3? zw5Z2WzF3C+#UZ9$Q(u(aytvLWg_EI7hW*%#{>r&Wt|>XK*VPFQ4i;}Y{IEEC>#Ud0 zqi*|cTpZhcOi1Zb)Y`N|4?pY#NmumyYqW*uYA#*Z;iB}f>G8gm!Jve)$;HJ-MW`|8 z71w*!q&X|r>Afy5E#39EZ>dScU!ncCcAZ(jZAQnF8yk~nGc_KVoxOF|m(9CgnONU4 z&CvO{N|%#`Dc`Q_%?&lBtWu8|KOdzSc`9An`dn^W(DoY_*Exd9{S+h5Gc|?}LjBMD z@k{3U^|QqRBE>0qP=YzFPW?hwYNXPpt#RughZslR4cQLzU{YtthVXCEu16ldeEIS# z!?o+LOBeVvcisi%fgaT-Q~&+~iJv($JNNdst=BH~Yu<_j`S_5KlG56(H@4*(e@-(y zmcz%^e4z9Bf$CKtN#<ih!aQt=JvYys*%SDyM0qd7tb}L1I!dqiPW(16JRB_4q55Pg zXXJT?Ygua><)km^mu<;*h;W+6o-OHc?Dz8Juf(lynQG{Kv^ilQ=(^W0k^l3&FVkh; zzB2dj=&<2m)DyOE<-J(hzv=F~*&0>nu`BLhe(x|hKb!TftNGI^D)+DXXZ!a`dfArj zfQUmbPrh-<Hy`UO`#M=Q&dg@|#Z!SdCU^5mS>L+qKdnOUN~DX{nK#crzL>=}Rrc*G zc5fFK8-6Vwh1DV7rzh}gS>L*<KdnM8Me=~*`}yl)OE28KxXx~aiw%FDq2S&(ws-Gr z^)8z&>+<5oqV0AYTxLkRd@+(vj(hWJ>aH_<M_IzTV{~PGtuFIVlSqzBNwd3=W7epr z`~Rhw^$qPmHAh{wvJ^$PPFTM~eQ&wdD?{%cH*=b1@<$z;(cJgg!fR#Ud8x8y*}BN- z*-~W<AI$5g9&Fc-WM=qr{dn4k5A{z2*GeCGbM5`>5<T<zs_%cFE(yQJ(%qr4^hBYI zy9?i^6DLlbW{Bf!b#iF8yyRIXuyxh)sKX0_%x`Sd?7O}0c0+X8yv~v{PF86Tr_D`& zwC%N7!tG@iTBXfVhaC*p?qaAuBJ}apzo!!?PMrRC-mg0Szh%dl6<QzVb2{CzOj{!) zX3xtr2H)yRWUAN69iM%UXK_H`bBRa$_dkDC_~!46&hNAHy7gCoTzXqaQR(s2zyEe# zeSd65e5LH<zStGQZvuP_m_)Sn+JqSI?D@=dXIr)Hw%oqOx5OAfE3b$$zp>4-?{?Yx z<7chU9|%xta5~y5xJKg7-Ly{gQtsq4>lyN=eRy$k@$*)vh5z4OIrH$P-Ti|;0)kVr zrB1GB-M`PEnLp=rf`*{O&Si-`I-Gg`XGBCitk~E&-*)EF{eP{xosPQdtl7hwb$a(= zJ^OeiFU>2jziwX>z5QLJm|)cIvbUelrhRTLu<)51@poTTnZedGD=egV!#~{ne0j&; zU%WT=CW}0r;FQR^R^$H+r=zYGOEy#=*T~yn$9%84PKq}?pwMIn=V{Ivc^h3^d@N*^ zrEUIs=+@Tk`Qm~}6*uc<{H(jZt?_aU-xjH|A6I)X=eQodeNW=>^mUQ_=VL6}N}U~_ zXzP8^`YqYT8O8lcRQz>Y(~mP#lYd+?ef@Ep`LTd?nHIj0Vn22t5lX&oHsfcVeYC*g z?8wZu?^lMcwzzxsy>@E*{f*ioF>`E%4n1Edb$IQ3iQ|>8l{La*O{PUAET5ELu~<_h zEOyRzyQXM+!@l#erfr+uor-yzgC+0rP7@FeUmcpU_Weiw$M(q{3>8&XzZ^`YY_~_P zUH3|J)zl5nhK{qtTzZx<-}u>=xaWP7Nzyi_oy!ytY+5z>>Isi-?u{FqZeFZ-uu|sT z-~LFSr?Wa*uin|ab@gPOk8;-@dM*8Q=fsH<Vihk>{dM}jdhPqUep|Q7lx?`Kq!kt` z(4eg^ckFt;4f}3ANw$O4b9bd5bNasF^`vXIF_j-zzF&A)=i}<t@6Y<@9b?!~{yy%0 zgZz`vKf8jw)|S+zXqNrB>b5oSfeDY=wk02~zBYfb?DTa-!Hv%v?w(dTktO#-SdiO1 zb{}8--a2NrZA%_hoPGZ7)h1z`k5$jL=ARE|_`_iFknybb34g2NbhXJl)ux|b`{H6o zF3ZUkt-Iys9QrOl)6FQ$QTk(DCd=}~3kAH}<_I4R70xdBH#@0d>w;?r30AglB~D6l zE$ouhl>U6HKmFt36CS=l3!4@k-O#)?Ltwpr@QNKfeq9ob*_<rWdg|1s%<hiae#_V9 zm{nhY@#4kS&yz2lF>mSg=##UpT9<qK_ttA$?mTxDOKDd!5S*2ly(0PWir(1iJjqGk zLX6!3D>ECWU4QdtX%EAMFtLTpv}f1^9Jt$m_R&&J_KO0_30r##zrKoea&gmC?R%UP zyFKr2@jbS5wi)J!UO&GRu|4nZ=3ckQ>YXl4d(ZBi_EFU@tz9;!b>n^Z9c7?YoXA`@ zXTIBI(U7du@0q_u71u7D<#t*0jfH*h?fe@J+0VbWx-eMpu$tD<G226B$Hk04DR+03 zp5K;V?3l}L%wSpf=ZD(qO?*4novxd8<EVat()_Es_v<)Q-4CynZfR=yzQ?px`<`Xb zMpMNQnY$k?yW>2h8<%nfMnAkd)9T&Nt2+Pw&a`>O+bElKB&YRu{*4WHYGRyRk|L)S z#+IMHaN)w|jN4W(muQ`=Zf<XH=fC#)t99^75&f?UD+C>nq}hDAEyL$~{bC7g{(X_- z;br*;&RU;8Qc(2uu}AZQpN}&SpJd3pSE;sdQ#JFweVTpCv@J|xk2~)_tC*xOtE#jp zc%{ho*I%tc1~1ym^LmNcTc-=tY^%Q&s0jUZS{SgyNP=gXRPV8-&nij{vI#Qt);&Jn zfBuo~bC-gzOO-2r#~*pVk41Q|TJMbYKb6mYUDn;3-u96{y#KNN|K%<3?Tvcn?d`hD z_nWqwo@9`EcPLOvNr;DSXMjeE^28J3ds2++A3aF2PCci3&}aGO`;j`5y@HM=OFmpy z?_Ye+V8^3Ji7X-u-)!-z@BO`D#h+_6XVOfzKA3krU69MjZbvT1)Lm@)4^He!F}}V& za`UtK;<tOBi{#A6X7gVbwDRrGn!1Bu9w|8lb6+y#7EP~I-dL~qX@B;WuoF{bS*{#9 zTb|tGw=_s={`vRPva)k6u3qz4o^`E@{mb3H$5WaQK4|KG*}-7#pb=_qKL2mj#Q(A0 zq2gr;iN<?YarHf(!uaOfsgl4OI$0h3KQ3Opc;CA3^UDYwv*}JNV%_56-L`bwe$ZEv z?LIo?vBi8jS=qm~Vhle$pIcpR-P5xF+KU%2$~WJ9GuPSKd75EbXtnlUxzwJ*_O`Y) zIcC+*L9%;QQkUm!)%&@5ny~bR)j4LrD|Ezm7Z}d_v)}K;bGMx_|4P*+@03`3EA=F= z#&j8}kA4a+R}7B5S#ve3eeuN=M#sMB2A*ATT4M3yiz`-L&1&y8wC-HNxITEtmGtSR zVNZEiuK@MWHK%$_mrmx|miyhd_@%-riR3+-G^cniJ-zpA&()&vFFMsir|z44T(kL{ zRe;{~`~EIl12m5O&p*W}ne3w>lBBe1>glK7=a*iKWdC#4e0p`CsCn<R{Mq8YZuNOR zofXx3d)LiO>bbZrS37N8nAX(q>tlC+n|L+HVCSREsQUVSTeGj*O*;KlFZo%GiB#xm zS;>FJQ#y8B%m^_owwr!Bc3YHjL#W<{Z@!Z9^7U&1H1>FB*J{l4*{3)A$T11?nLbfb zubo?+9>zX?{J8&X&Fjzwsi&SiSz>Bt_RVKwM2-8}u<D|UiW^K1%6@;`;W**D?A?bC z6IX|=wvE&gyKiV__N{0C&u3ZZf6uzVZ@2ubXW#DDi%H-4?iCkOQri5x&t}QvufZNJ zS3Z3B5W6(!CDZXjIltwX!xRJ*_<yW_n&snigqMYBir;eG<Arj1Q@zA_`1$XvDrNnf z^W){JWPe}5i9c$mZrr$0JZR;WYOkf2e9zd-J(;4jrO0XTv#Rg+Y=hM%Z=6xND@5z= z6tAWCjLpoxHQhfR;m2RI^OaZVqw{BlCssr!)XeKBb61+^urQ#(&`ns7KPyyCS28W^ zWQx$$tgV_#<vZ3@ufG0SZOhr#a<5q32|p*CPTL&a@nSPW+0i7!P#v*_Tgxs=^}TzY z$FpzG)c<apsj5Pp6|*z%=_<vaOEF^Klk6vbXL-=|`i1{8HvU&F+Ic4|?cWr`v~U|a ze~Z56tgTuZy=~v_$lPaP&=TsLkb7HWs+VipzS|O?&z4y2y>IQ8D7kp+^3`%`y25SK zMMW7l-F&knHB#R2Sj^Eg(T5)z$Q*BcZ1F*g@!u!qS0PuQZJqS&U&o0}xzgKm7l*Ar zD6x6|%=#C*-t`;JJfryD_>}cF_2jbJ6PK$u87es~4a%Ht_^kQZ3|WSAzA8cjf>&Ob zUU^-bv6fX`tU*}tVD?R&Z63~h{#rB$ueA>I3e*s3`tBp&bySJL=J*+YqtE`2KmM3p z$MWSCM||xaA^(i~`)ae2k`B%NF7>=<=jRzem6Ff6Ds2)yaI?`tL5PEef#H}$#@dBj z+oqo^es|MD|ES;$NiR)}r4Le9I&8k_vias3ri#{%Z^oa6kIhiO^zzH*&)l^|MVD^v ztF3mA_R|i$A8q+LO=V^62}hSTNgH=KX7A-ZkS2IjLn*!cM3mC04iBESBCTRk>IE0Q z%tOT{>IljzFADE?;mW{1HR!y$Q0H{9U3^g!uavs!Ed9^Sx<3AV+P1s{3=9kmp00i_ I>zopr082ad6aWAK literal 0 HcmV?d00001 diff --git a/habitat/utils/visualizations/maps.py b/habitat/utils/visualizations/maps.py new file mode 100644 index 000000000..f845cbf2c --- /dev/null +++ b/habitat/utils/visualizations/maps.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import cv2 +import imageio +import scipy.ndimage +import os +from typing import List, Tuple, Optional +from habitat.core.simulator import Simulator +from habitat.utils.visualizations import utils + +AGENT_SPRITE = imageio.imread( + os.path.join( + os.path.dirname(__file__), + "assets", + "maps_topdown_agent_sprite", + "100x100.png", + ) +) +AGENT_SPRITE = np.ascontiguousarray(np.flipud(AGENT_SPRITE)) +COORDINATE_EPSILON = 1e-6 +COORDINATE_MIN = -62.3241 - COORDINATE_EPSILON +COORDINATE_MAX = 90.0399 + COORDINATE_EPSILON + + +def draw_agent( + image: np.ndarray, + agent_center_coord: Tuple[int, int], + agent_rotation: float, + agent_radius_px: int = 5, +) -> np.ndarray: + """Return an image with the agent image composited onto it. + Args: + image: the image onto which to put the agent. + agent_center_coord: the image coordinates where to paste the agent. + agent_rotation: the agent's current rotation in radians. + agent_radius_px: 1/2 number of pixels the agent will be resized to. + Returns: + The modified background image. This operation is in place. + """ + + # Rotate before resize to keep good resolution. + rotated_agent = scipy.ndimage.interpolation.rotate( + AGENT_SPRITE, agent_rotation * -180 / np.pi + ) + # Rescale because rotation may result in larger image than original, but the + # agent sprite size should stay the same. + initial_agent_size = AGENT_SPRITE.shape[0] + new_size = rotated_agent.shape[0] + agent_size_px = max( + 1, int(agent_radius_px * 2 * new_size / initial_agent_size) + ) + resized_agent = cv2.resize( + rotated_agent, + (agent_size_px, agent_size_px), + interpolation=cv2.INTER_LINEAR, + ) + utils.paste_overlapping_image(image, resized_agent, agent_center_coord) + return image + + +def pointnav_draw_target_birdseye_view( + agent_position: np.ndarray, + agent_heading: float, + goal_position: np.ndarray, + resolution_px: int = 800, + goal_radius: float = 0.2, + agent_radius_px: int = 20, + target_band_radii: Optional[List[float]] = None, + target_band_colors: Optional[List[Tuple[int, int, int]]] = None, +) -> np.ndarray: + """Return an image of agent w.r.t. centered target location for pointnav + tasks. + + Args: + agent_position: the agent's current position. + agent_heading: the agent's current rotation in radians. This can be + found using the HeadingSensor. + goal_position: the pointnav task goal position. + resolution_px: number of pixels for the output image width and height. + goal_radius: how near the agent needs to be to be successful for the + pointnav task. + agent_radius_px: 1/2 number of pixels the agent will be resized to. + target_band_radii: distance in meters to the outer-radius of each band + in the target image. + target_band_colors: colors in RGB 0-255 for the bands in the target. + Returns: + Image centered on the goal with the agent's current relative position + and rotation represented by an arrow. To make the rotations align + visually with habitat, positive-z is up, positive-x is left and a + rotation of 0 points upwards in the output image and rotates clockwise. + """ + if target_band_radii is None: + target_band_radii = [20, 10, 5, 2.5, 1] + if target_band_colors is None: + target_band_colors = [ + (47, 19, 122), + (22, 99, 170), + (92, 177, 0), + (226, 169, 0), + (226, 12, 29), + ] + + assert len(target_band_radii) == len( + target_band_colors + ), "There must be an equal number of scales and colors." + + goal_agent_dist = np.linalg.norm(agent_position - goal_position, 2) + + goal_distance_padding = max( + 2, 2 ** np.ceil(np.log(max(1e-6, goal_agent_dist)) / np.log(2)) + ) + movement_scale = 1.0 / goal_distance_padding + half_res = resolution_px // 2 + im_position = np.full( + (resolution_px, resolution_px, 3), 255, dtype=np.uint8 + ) + + # Draw bands: + for scale, color in zip(target_band_radii, target_band_colors): + if goal_distance_padding * 4 > scale: + cv2.circle( + im_position, + (half_res, half_res), + max(2, int(half_res * scale * movement_scale)), + color, + thickness=-1, + ) + + # Draw such that the agent being inside the radius is the circles + # overlapping. + cv2.circle( + im_position, + (half_res, half_res), + max(2, int(half_res * goal_radius * movement_scale)), + (127, 0, 0), + thickness=-1, + ) + + relative_position = agent_position - goal_position + # swap x and z, remove y for (x,y,z) -> image coordinates. + relative_position = relative_position[[2, 0]] + relative_position *= half_res * movement_scale + relative_position += half_res + relative_position = np.round(relative_position).astype(np.int32) + + # Draw the agent + draw_agent(im_position, relative_position, agent_heading, agent_radius_px) + + # Rotate twice to fix coordinate system to upwards being positive-z. + # Rotate instead of flip to keep agent rotations in sync with egocentric + # view. + im_position = np.rot90(im_position, 2) + return im_position + + +def _to_grid( + realworld_x: float, + realworld_y: float, + coordinate_min: float, + coordinate_max: float, + grid_resolution: Tuple[int, int], +) -> Tuple[int, int]: + """Return gridworld index of realworld coordinates assuming top-left corner + is the origin. The real world coordinates of lower left corner are + (coordinate_min, coordinate_min) and of top right corner are + (coordinate_max, coordinate_max) + """ + grid_size = ( + (coordinate_max - coordinate_min) / grid_resolution[0], + (coordinate_max - coordinate_min) / grid_resolution[1], + ) + grid_x = int((coordinate_max - realworld_x) / grid_size[0]) + grid_y = int((realworld_y - coordinate_min) / grid_size[1]) + return grid_x, grid_y + + +def _from_grid( + grid_x: int, + grid_y: int, + coordinate_min: float, + coordinate_max: float, + grid_resolution: Tuple[int, int], +) -> Tuple[float, float]: + """Inverse of _to_grid function. Return real world coordinate from gridworld + assuming top-left corner is the origin. The real world coordinates of lower + left corner are (coordinate_min, coordinate_min) and of top right corner + are (coordinate_max, coordinate_max) + """ + grid_size = ( + (coordinate_max - coordinate_min) / grid_resolution[0], + (coordinate_max - coordinate_min) / grid_resolution[1], + ) + realworld_x = coordinate_max - grid_x * grid_size[0] + realworld_y = coordinate_min + grid_y * grid_size[1] + return realworld_x, realworld_y + + +def _check_valid_nav_point(sim: Simulator, point: List[float]) -> bool: + """Checks if a given point is inside a wall or other object or not.""" + return ( + 0.01 + < sim.geodesic_distance(point, [point[0], point[1] + 0.1, point[2]]) + < 0.11 + ) + + +def _outline_border(top_down_map): + left_right_block_nav = (top_down_map[:, :-1] == 1) & ( + top_down_map[:, :-1] != top_down_map[:, 1:] + ) + left_right_nav_block = (top_down_map[:, 1:] == 1) & ( + top_down_map[:, :-1] != top_down_map[:, 1:] + ) + + up_down_block_nav = (top_down_map[:-1] == 1) & ( + top_down_map[:-1] != top_down_map[1:] + ) + up_down_nav_block = (top_down_map[1:] == 1) & ( + top_down_map[:-1] != top_down_map[1:] + ) + + top_down_map[:, :-1][left_right_block_nav] = 2 + top_down_map[:, 1:][left_right_nav_block] = 2 + + top_down_map[:-1][up_down_block_nav] = 2 + top_down_map[1:][up_down_nav_block] = 2 + + +def get_topdown_map( + sim: Simulator, + map_resolution: Tuple[int, int] = (1250, 1250), + num_samples: int = 20000, + draw_border: bool = True, +) -> np.ndarray: + """Return a top-down occupancy map for a sim. Note, this only returns valid + values for whatever floor the agent is currently on. + + Args: + sim: The simulator. + map_resolution: The resolution of map which will be computed and returned. + num_samples: The number of random navigable points which will be initially + sampled. For large environments it may need to be increased. + draw_border: Whether to outline the border of the occupied spaces. + + Returns: + Image containing 0 if occupied, 1 if unoccupied, and 2 if border (if + the flag is set). + """ + top_down_map = np.zeros(map_resolution, dtype=np.uint8) + grid_delta = 3 + + start_height = sim.get_agent_state().position[1] + + # Use sampling to find the extrema points that might be navigable. + for _ in range(num_samples): + point = sim.sample_navigable_point() + # Check if on same level as original + if np.abs(start_height - point[1]) > 0.5: + continue + g_x, g_y = _to_grid( + point[0], point[2], COORDINATE_MIN, COORDINATE_MAX, map_resolution + ) + + top_down_map[g_x, g_y] = 1 + + range_x = np.where(np.any(top_down_map, axis=1))[0] + range_y = np.where(np.any(top_down_map, axis=0))[0] + # Pad the range just in case not enough points were sampled to get the true + # extrema. + padding = int(np.ceil(map_resolution[0] / 125)) + range_x = ( + max(range_x[0] - padding, 0), + min(range_x[-1] + padding + 1, top_down_map.shape[0]), + ) + range_y = ( + max(range_y[0] - padding, 0), + min(range_y[-1] + padding + 1, top_down_map.shape[1]), + ) + top_down_map[:] = 0 + # Search over grid for valid points. + for ii in range(range_x[0], range_x[1]): + for jj in range(range_y[0], range_y[1]): + realworld_x, realworld_y = _from_grid( + ii, jj, COORDINATE_MIN, COORDINATE_MAX, map_resolution + ) + valid_point = _check_valid_nav_point( + sim, [realworld_x, start_height + 0.5, realworld_y] + ) + if valid_point: + top_down_map[ii, jj] = 1 + + # Draw border if necessary + if draw_border: + # Recompute range in case padding added any more values. + range_x = np.where(np.any(top_down_map, axis=1))[0] + range_y = np.where(np.any(top_down_map, axis=0))[0] + range_x = ( + max(range_x[0] - grid_delta, 0), + min(range_x[-1] + grid_delta + 1, top_down_map.shape[0]), + ) + range_y = ( + max(range_y[0] - grid_delta, 0), + min(range_y[-1] + grid_delta + 1, top_down_map.shape[1]), + ) + + _outline_border( + top_down_map[range_x[0] : range_x[1], range_y[0] : range_y[1]] + ) + return top_down_map diff --git a/habitat/utils/visualizations/utils.py b/habitat/utils/visualizations/utils.py new file mode 100644 index 000000000..b4d5980bd --- /dev/null +++ b/habitat/utils/visualizations/utils.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from typing import Tuple, Optional + + +def paste_overlapping_image( + background: np.ndarray, + foreground: np.ndarray, + location: Tuple[int, int], + mask: Optional[np.ndarray] = None, +): + """Composites the foreground onto the background dealing with edge + boundaries. + Args: + background: the background image to paste on. + foreground: the image to paste. Can be RGB or RGBA. If using alpha + blending, values for foreground and background should both be + between 0 and 255. Otherwise behavior is undefined. + location: the image coordinates to paste the foreground. + mask: If not None, a mask for deciding what part of the foreground to + use. Must be the same size as the foreground if provided. + Returns: + The modified background image. This operation is in place. + """ + assert mask is None or mask.shape[:2] == foreground.shape[:2] + foreground_size = foreground.shape[:2] + min_pad = ( + max(0, foreground_size[0] // 2 - location[0]), + max(0, foreground_size[1] // 2 - location[1]), + ) + + max_pad = ( + max( + 0, + (location[0] + (foreground_size[0] - foreground_size[0] // 2)) + - background.shape[0], + ), + max( + 0, + (location[1] + (foreground_size[1] - foreground_size[1] // 2)) + - background.shape[1], + ), + ) + + background_patch = background[ + (location[0] - foreground_size[0] // 2 + min_pad[0]) : ( + location[0] + + (foreground_size[0] - foreground_size[0] // 2) + - max_pad[0] + ), + (location[1] - foreground_size[1] // 2 + min_pad[1]) : ( + location[1] + + (foreground_size[1] - foreground_size[1] // 2) + - max_pad[1] + ), + ] + foreground = foreground[ + min_pad[0] : foreground.shape[0] - max_pad[0], + min_pad[1] : foreground.shape[1] - max_pad[1], + ] + if foreground.size == 0 or background_patch.size == 0: + # Nothing to do, no overlap. + return background + + if mask is not None: + mask = mask[ + min_pad[0] : foreground.shape[0] - max_pad[0], + min_pad[1] : foreground.shape[1] - max_pad[1], + ] + + if foreground.shape[2] == 4: + # Alpha blending + foreground = ( + background_patch.astype(np.int32) * (255 - foreground[:, :, [3]]) + + foreground[:, :, :3].astype(np.int32) * foreground[:, :, [3]] + ) // 255 + if mask is not None: + background_patch[mask] = foreground[mask] + else: + background_patch[:] = foreground + return background diff --git a/requirements.txt b/requirements.txt index fe169669a..f73903283 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,7 @@ -h5py gym==0.10.9 numpy==1.15 yacs>=0.1.5 +# visualization optional dependencies +imageio>=2.2.0 +opencv-python>=3.3.0 +scipy>=1.0.0 diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py new file mode 100644 index 000000000..f849669e6 --- /dev/null +++ b/test/test_baseline_agents.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from baselines.agents import simple_agents, ppo_agents +import habitat +import os +import pytest + +CFG_TEST = "test/habitat_all_sensors_test.yaml" + + +def test_ppo_agents(): + config = ppo_agents.get_defaut_config() + config.MODEL_PATH = "" + config_env = habitat.get_config(config_file=CFG_TEST) + config_env.defrost() + if not os.path.exists(config_env.SIMULATOR.SCENE): + pytest.skip("Please download Habitat test data to data folder.") + + benchmark = habitat.Benchmark(config_file=CFG_TEST, config_dir="configs") + + for input_type in ["blind", "rgb", "depth", "rgbd"]: + config_env.defrost() + config_env.SIMULATOR.AGENT_0.SENSORS = [] + if input_type in ["rgb", "rgbd"]: + config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"] + if input_type in ["depth", "rgbd"]: + config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"] + config_env.freeze() + del benchmark._env + benchmark._env = habitat.Env(config=config_env) + config.INPUT_TYPE = input_type + + agent = ppo_agents.PPOAgent(config) + habitat.logger.info(benchmark.evaluate(agent, num_episodes=10)) + + +def test_simple_agents(): + config_env = habitat.get_config(config_file=CFG_TEST) + + if not os.path.exists(config_env.SIMULATOR.SCENE): + pytest.skip("Please download Habitat test data to data folder.") + + benchmark = habitat.Benchmark(config_file=CFG_TEST, config_dir="configs") + + for agent_class in [ + simple_agents.ForwardOnlyAgent, + simple_agents.GoalFollower, + simple_agents.RandomAgent, + simple_agents.RandomForwardAgent, + ]: + agent = agent_class(config_env) + habitat.logger.info(agent_class.__name__) + habitat.logger.info(benchmark.evaluate(agent, num_episodes=100)) diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py index 75e3cb812..72d9a2380 100644 --- a/test/test_habitat_env.py +++ b/test/test_habitat_env.py @@ -19,7 +19,7 @@ from habitat.sims.habitat_simulator import ( SIM_ACTION_TO_NAME, SIM_NAME_TO_ACTION, ) -from habitat.tasks.nav.nav_task import NavigationEpisode +from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal CFG_TEST = "test/habitat_all_sensors_test.yaml" NUM_ENVS = 2 @@ -84,8 +84,6 @@ def _vec_env_test_fn(configs, datasets, multiprocessing_start_method): observations = envs.step(np.random.choice(non_stop_actions, num_envs)) assert len(observations) == num_envs - envs.close() - def test_vectorized_envs_forkserver(): configs, datasets = _load_test_data() @@ -112,6 +110,18 @@ def test_vectorized_envs_fork(): assert p.exitcode == 0 +def test_with_scope(): + configs, datasets = _load_test_data() + num_envs = len(configs) + env_fn_args = tuple(zip(configs, datasets, range(num_envs))) + with habitat.VectorEnv( + env_fn_args=env_fn_args, multiprocessing_start_method="forkserver" + ) as envs: + envs.reset() + + assert envs._is_closed + + def test_threaded_vectorized_env(): configs, datasets = _load_test_data() num_envs = len(configs) @@ -140,9 +150,10 @@ def test_env(): NavigationEpisode( episode_id="0", scene_id=config.SIMULATOR.SCENE, - start_position=[03.00611, 0.072447, -2.67867], + start_position=[3.00611, 0.072447, -2.67867], start_rotation=[0, 0.163276, 0, 0.98658], - goals=[], + goals=[NavigationGoal([3.00611, 0.072447, -2.67867])], + info={"geodesic_distance": 0.001}, ) ] @@ -219,9 +230,10 @@ def test_rl_env(): NavigationEpisode( episode_id="0", scene_id=config.SIMULATOR.SCENE, - start_position=[03.00611, 0.072447, -2.67867], + start_position=[3.00611, 0.072447, -2.67867], start_rotation=[0, 0.163276, 0, 0.98658], - goals=[], + goals=[NavigationGoal([3.00611, 0.072447, -2.67867])], + info={"geodesic_distance": 0.001}, ) ] @@ -276,6 +288,8 @@ def test_action_space_shortest_path(): while len(unreachable_targets) < 3: position = env.sim.sample_navigable_point() + # Change height of the point to make it unreachable + position[1] = 100 angles = [x for x in range(-180, 180, config.SIMULATOR.TURN_ANGLE)] angle = np.radians(np.random.choice(angles)) rotation = [0, np.sin(angle / 2), 0, np.cos(angle / 2)] @@ -289,26 +303,4 @@ def test_action_space_shortest_path(): targets = unreachable_targets shortest_path2 = env.sim.action_space_shortest_path(source, targets) assert shortest_path2 == [] - - targets = reachable_targets + unreachable_targets - shortest_path3 = env.sim.action_space_shortest_path(source, targets) - - # shortest_path1 should be identical to shortest_path3 - assert len(shortest_path1) == len(shortest_path3) - for i in range(len(shortest_path1)): - assert np.allclose( - shortest_path1[i].position, shortest_path3[i].position - ) - assert np.allclose( - shortest_path1[i].rotation, shortest_path3[i].rotation - ) - assert shortest_path1[i].action == shortest_path3[i].action - - targets = unreachable_targets + [source] - shortest_path4 = env.sim.action_space_shortest_path(source, targets) - assert len(shortest_path4) == 1 - assert np.allclose(shortest_path4[0].position, source.position) - assert np.allclose(shortest_path4[0].rotation, source.rotation) - assert shortest_path4[0].action is None - env.close() diff --git a/test/test_habitat_example.py b/test/test_habitat_example.py index 11f193446..227014d60 100644 --- a/test/test_habitat_example.py +++ b/test/test_habitat_example.py @@ -9,6 +9,7 @@ import pytest import habitat from examples.example import example from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 +from examples import visualization_examples def test_readme_example(): @@ -17,3 +18,11 @@ def test_readme_example(): ): pytest.skip("Please download Habitat test data to data folder.") example() + + +def test_visualizations_example(): + if not PointNavDatasetV1.check_config_paths_exist( + config=habitat.get_config().DATASET + ): + pytest.skip("Please download Habitat test data to data folder.") + visualization_examples.main() diff --git a/test/test_sensors.py b/test/test_sensors.py new file mode 100644 index 000000000..d6f8d1b50 --- /dev/null +++ b/test/test_sensors.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import os +import pytest +import random + +import habitat +from habitat.config.default import get_config +from habitat.tasks.nav.nav_task import NavigationEpisode + +CFG_TEST = "test/habitat_all_sensors_test.yaml" + + +def test_heading_sensor(): + config = get_config(CFG_TEST) + if not os.path.exists(config.SIMULATOR.SCENE): + pytest.skip("Please download Habitat test data to data folder.") + config = get_config() + config.defrost() + config.TASK.SENSORS = ["HEADING_SENSOR"] + config.freeze() + env = habitat.Env(config=config, dataset=None) + env.reset() + random.seed(1234) + + for _ in range(100): + random_heading = np.random.uniform(-np.pi, np.pi) + random_rotation = [ + 0, + np.sin(random_heading / 2), + 0, + np.cos(random_heading / 2), + ] + env.episodes = [ + NavigationEpisode( + episode_id="0", + scene_id=config.SIMULATOR.SCENE, + start_position=[03.00611, 0.072447, -2.67867], + start_rotation=random_rotation, + goals=[], + ) + ] + + obs = env.reset() + heading = obs["heading"] + assert np.allclose(heading, random_heading) + + env.close() diff --git a/test/test_trajectory_sim.py b/test/test_trajectory_sim.py index e50384d09..9123cbf11 100644 --- a/test/test_trajectory_sim.py +++ b/test/test_trajectory_sim.py @@ -21,7 +21,7 @@ def init_sim(): return make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR) -def test_sim(): +def test_sim_trajectory(): with open("test/data/habitat-sim_trajectory_data.json", "r") as f: test_trajectory = json.load(f) sim = init_sim() @@ -61,3 +61,14 @@ def test_sim(): assert sim.is_episode_active is False sim.close() + + +def test_sim_no_sensors(): + config = get_config() + config.defrost() + config.SIMULATOR.AGENT_0.SENSORS = [] + if not os.path.exists(config.SIMULATOR.SCENE): + pytest.skip("Please download Habitat test data to data folder.") + sim = make_sim(config.SIMULATOR.TYPE, config=config.SIMULATOR) + sim.reset() + sim.close() -- GitLab