From b5f2b00a25627ecb52b43b13ea96b05998d9a121 Mon Sep 17 00:00:00 2001 From: Erik Wijmans <etw@gatech.edu> Date: Thu, 15 Aug 2019 18:18:58 -0400 Subject: [PATCH] Correct the naming of PointGoalNav Sensors (#180) This PR seeks to bring habitat-API inline with this. StaticPointGoalSensor is now just a PointGoalSensor and the old PointGoalSensor is now a PointGoalWithGPSCompassSensor. There are also now explicit GPS and Compass sensors, such that you can fairly easily compute PointGoalWithGPSCompassSensor given [PointGoalSensor, GPSSensor, CompassSensor]. The PointGoal* sensors now also include {2d, 3d} x {Cartesian, Polar} support. --- configs/tasks/pointnav.yaml | 12 +- configs/tasks/pointnav_gibson.yaml | 11 +- configs/tasks/pointnav_mp3d.yaml | 11 +- configs/tasks/pointnav_rgbd.yaml | 11 +- configs/test/habitat_all_sensors_test.yaml | 12 +- habitat/config/default.py | 21 +- habitat/tasks/nav/nav_task.py | 230 ++++++++++++++------- habitat/tasks/utils.py | 9 + habitat_baselines/agents/ppo_agents.py | 2 +- habitat_baselines/agents/slam_agents.py | 12 +- test/test_sensors.py | 95 ++++++++- 11 files changed, 305 insertions(+), 121 deletions(-) diff --git a/configs/tasks/pointnav.yaml b/configs/tasks/pointnav.yaml index 8a9e07239..bcf81eaec 100644 --- a/configs/tasks/pointnav.yaml +++ b/configs/tasks/pointnav.yaml @@ -14,12 +14,14 @@ SIMULATOR: TASK: TYPE: Nav-v0 SUCCESS_DISTANCE: 0.2 - SENSORS: ['POINTGOAL_SENSOR'] - POINTGOAL_SENSOR: - TYPE: PointGoalSensor - GOAL_FORMAT: POLAR + + SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR'] + POINTGOAL_WITH_GPS_COMPASS_SENSOR: + GOAL_FORMAT: "POLAR" + DIMENSIONALITY: 2 + GOAL_SENSOR_UUID: pointgoal_with_gps_compass + MEASUREMENTS: ['SPL'] - GOAL_SENSOR_UUID: pointgoal SPL: TYPE: SPL SUCCESS_DISTANCE: 0.2 diff --git a/configs/tasks/pointnav_gibson.yaml b/configs/tasks/pointnav_gibson.yaml index 17681b46f..52d5ae708 100644 --- a/configs/tasks/pointnav_gibson.yaml +++ b/configs/tasks/pointnav_gibson.yaml @@ -14,10 +14,13 @@ SIMULATOR: TASK: TYPE: Nav-v0 SUCCESS_DISTANCE: 0.2 - SENSORS: ['POINTGOAL_SENSOR'] - POINTGOAL_SENSOR: - TYPE: PointGoalSensor - GOAL_FORMAT: POLAR + + SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR'] + POINTGOAL_WITH_GPS_COMPASS_SENSOR: + GOAL_FORMAT: "POLAR" + DIMENSIONALITY: 2 + GOAL_SENSOR_UUID: pointgoal_with_gps_compass + MEASUREMENTS: ['SPL'] SPL: TYPE: SPL diff --git a/configs/tasks/pointnav_mp3d.yaml b/configs/tasks/pointnav_mp3d.yaml index b22c32984..341cac7e6 100644 --- a/configs/tasks/pointnav_mp3d.yaml +++ b/configs/tasks/pointnav_mp3d.yaml @@ -14,10 +14,13 @@ SIMULATOR: TASK: TYPE: Nav-v0 SUCCESS_DISTANCE: 0.2 - SENSORS: ['POINTGOAL_SENSOR'] - POINTGOAL_SENSOR: - TYPE: PointGoalSensor - GOAL_FORMAT: POLAR + + SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR'] + POINTGOAL_WITH_GPS_COMPASS_SENSOR: + GOAL_FORMAT: "POLAR" + DIMENSIONALITY: 2 + GOAL_SENSOR_UUID: pointgoal_with_gps_compass + MEASUREMENTS: ['SPL'] SPL: TYPE: SPL diff --git a/configs/tasks/pointnav_rgbd.yaml b/configs/tasks/pointnav_rgbd.yaml index 9635a93b0..e1660792c 100644 --- a/configs/tasks/pointnav_rgbd.yaml +++ b/configs/tasks/pointnav_rgbd.yaml @@ -14,10 +14,13 @@ SIMULATOR: TASK: TYPE: Nav-v0 SUCCESS_DISTANCE: 0.2 - SENSORS: ['POINTGOAL_SENSOR'] - POINTGOAL_SENSOR: - TYPE: PointGoalSensor - GOAL_FORMAT: POLAR + + SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR'] + POINTGOAL_WITH_GPS_COMPASS_SENSOR: + GOAL_FORMAT: "POLAR" + DIMENSIONALITY: 2 + GOAL_SENSOR_UUID: pointgoal_with_gps_compass + MEASUREMENTS: ['SPL'] SPL: TYPE: SPL diff --git a/configs/test/habitat_all_sensors_test.yaml b/configs/test/habitat_all_sensors_test.yaml index 944310b6a..df365281e 100644 --- a/configs/test/habitat_all_sensors_test.yaml +++ b/configs/test/habitat_all_sensors_test.yaml @@ -16,12 +16,14 @@ DATASET: TASK: TYPE: Nav-v0 SUCCESS_DISTANCE: 0.2 - SENSORS: ['POINTGOAL_SENSOR'] - POINTGOAL_SENSOR: - TYPE: PointGoalSensor - GOAL_FORMAT: POLAR + + SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR'] + POINTGOAL_WITH_GPS_COMPASS_SENSOR: + GOAL_FORMAT: "POLAR" + DIMENSIONALITY: 2 + GOAL_SENSOR_UUID: pointgoal_with_gps_compass + MEASUREMENTS: ['SPL'] - GOAL_SENSOR_UUID: pointgoal SPL: TYPE: SPL SUCCESS_DISTANCE: 0.2 diff --git a/habitat/config/default.py b/habitat/config/default.py index a433e997c..24014465f 100644 --- a/habitat/config/default.py +++ b/habitat/config/default.py @@ -43,18 +43,31 @@ _C.TASK.GOAL_SENSOR_UUID = "pointgoal" _C.TASK.POINTGOAL_SENSOR = CN() _C.TASK.POINTGOAL_SENSOR.TYPE = "PointGoalSensor" _C.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "POLAR" +_C.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 2 # ----------------------------------------------------------------------------- -# # STATIC POINTGOAL SENSOR +# # POINTGOAL WITH GPS+COMPASS SENSOR # ----------------------------------------------------------------------------- -_C.TASK.STATIC_POINTGOAL_SENSOR = CN() -_C.TASK.STATIC_POINTGOAL_SENSOR.TYPE = "StaticPointGoalSensor" -_C.TASK.STATIC_POINTGOAL_SENSOR.GOAL_FORMAT = "CARTESIAN" +_C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR = _C.TASK.POINTGOAL_SENSOR.clone() +_C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.TYPE = ( + "PointGoalWithGPSCompassSensor" +) # ----------------------------------------------------------------------------- # # HEADING SENSOR # ----------------------------------------------------------------------------- _C.TASK.HEADING_SENSOR = CN() _C.TASK.HEADING_SENSOR.TYPE = "HeadingSensor" # ----------------------------------------------------------------------------- +# # COMPASS SENSOR +# ----------------------------------------------------------------------------- +_C.TASK.COMPASS_SENSOR = CN() +_C.TASK.COMPASS_SENSOR.TYPE = "CompassSensor" +# ----------------------------------------------------------------------------- +# # GPS SENSOR +# ----------------------------------------------------------------------------- +_C.TASK.GPS_SENSOR = CN() +_C.TASK.GPS_SENSOR.TYPE = "GPSSensor" +_C.TASK.GPS_SENSOR.DIMENSIONALITY = 2 +# ----------------------------------------------------------------------------- # # PROXIMITY SENSOR # ----------------------------------------------------------------------------- _C.TASK.PROXIMITY_SENSOR = CN() diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index 27b145152..3aac5d9aa 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -23,7 +23,11 @@ from habitat.core.simulator import ( Simulator, ) from habitat.core.utils import not_none_validator -from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector +from habitat.tasks.utils import ( + cartesian_to_polar, + quaternion_from_coeff, + quaternion_rotate_vector, +) from habitat.utils.visualizations import fog_of_war, maps MAP_THICKNESS_SCALAR: int = 1250 @@ -107,7 +111,8 @@ class NavigationEpisode(Episode): @registry.register_sensor class PointGoalSensor(Sensor): - r"""Sensor for PointGoal observations which are used in the PointNav task. + r"""Sensor for PointGoal observations which are used in PointGoal Navigation. + For the agent in simulator the forward direction is along negative-z. In polar coordinate format the angle returned is azimuth to the goal. @@ -118,9 +123,13 @@ class PointGoalSensor(Sensor): the pointgoal is specified. Current options for goal format are cartesian and polar. + Also contains a DIMENSIONALITY field which specifes the number + of dimensions ued to specify the goal, must be in [2, 3] + Attributes: _goal_format: format for specifying the goal which can be done in cartesian or polar coordinates. + _dimensionality: number of dimensions used to specify the goal """ def __init__(self, sim: Simulator, config: Config): @@ -129,6 +138,9 @@ class PointGoalSensor(Sensor): self._goal_format = getattr(config, "GOAL_FORMAT", "CARTESIAN") assert self._goal_format in ["CARTESIAN", "POLAR"] + self._dimensionality = getattr(config, "DIMENSIONALITY", 2) + assert self._dimensionality in [2, 3] + super().__init__(config=config) def _get_uuid(self, *args: Any, **kwargs: Any): @@ -138,10 +150,8 @@ class PointGoalSensor(Sensor): return SensorTypes.PATH def _get_observation_space(self, *args: Any, **kwargs: Any): - if self._goal_format == "CARTESIAN": - sensor_shape = (3,) - else: - sensor_shape = (2,) + sensor_shape = (self._dimensionality,) + return spaces.Box( low=np.finfo(np.float32).min, high=np.finfo(np.float32).max, @@ -149,98 +159,85 @@ class PointGoalSensor(Sensor): dtype=np.float32, ) - def get_observation(self, observations, episode): - agent_state = self._sim.get_agent_state() - ref_position = agent_state.position - rotation_world_agent = agent_state.rotation - - direction_vector = ( - np.array(episode.goals[0].position, dtype=np.float32) - - ref_position - ) + def _compute_pointgoal( + self, source_position, source_rotation, goal_position + ): + direction_vector = goal_position - source_position direction_vector_agent = quaternion_rotate_vector( - rotation_world_agent.inverse(), direction_vector + source_rotation.inverse(), direction_vector ) if self._goal_format == "POLAR": - rho, phi = cartesian_to_polar( - -direction_vector_agent[2], direction_vector_agent[0] - ) - direction_vector_agent = np.array([rho, -phi], dtype=np.float32) + if self._dimensionality == 2: + rho, phi = cartesian_to_polar( + -direction_vector_agent[2], direction_vector_agent[0] + ) + return np.array([rho, -phi], dtype=np.float32) + else: + _, phi = cartesian_to_polar( + -direction_vector_agent[2], direction_vector_agent[0] + ) + theta = np.arccos( + direction_vector_agent[1] + / np.linalg.norm(direction_vector_agent) + ) + rho = np.linalg.norm(direction_vector_agent) + + return np.array([rho, -phi, theta], dtype=np.float32) + else: + if self._dimensionality == 2: + return np.array( + [-direction_vector_agent[2], direction_vector_agent[0]], + dtype=np.float32, + ) + else: + return direction_vector_agent + + def get_observation(self, observations, episode: Episode): + source_position = np.array(episode.start_position, dtype=np.float32) + rotation_world_start = quaternion_from_coeff(episode.start_rotation) + goal_position = np.array(episode.goals[0].position, dtype=np.float32) - return direction_vector_agent + return self._compute_pointgoal( + source_position, rotation_world_start, goal_position + ) -@registry.register_sensor -class StaticPointGoalSensor(Sensor): - r"""Sensor for PointGoal observations which are used in the StaticPointNav - task. For the agent in simulator the forward direction is along negative-z. +@registry.register_sensor(name="PointGoalWithGPSCompassSensor") +class IntegratedPointGoalGPSAndCompassSensor(PointGoalSensor): + r"""Sensor that integrates PointGoals observations (which are used PointGoal Navigation) and GPS+Compass. + + For the agent in simulator the forward direction is along negative-z. In polar coordinate format the angle returned is azimuth to the goal. + Args: sim: reference to the simulator for calculating task observations. config: config for the PointGoal sensor. Can contain field for GOAL_FORMAT which can be used to specify the format in which the pointgoal is specified. Current options for goal format are cartesian and polar. + + Also contains a DIMENSIONALITY field which specifes the number + of dimensions ued to specify the goal, must be in [2, 3] + Attributes: _goal_format: format for specifying the goal which can be done in cartesian or polar coordinates. + _dimensionality: number of dimensions used to specify the goal """ - def __init__(self, sim: Simulator, config: Config): - self._sim = sim - self._goal_format = getattr(config, "GOAL_FORMAT", "CARTESIAN") - assert self._goal_format in ["CARTESIAN", "POLAR"] - - super().__init__(sim, config) - self._initial_vector = None - self.current_episode_id = None - def _get_uuid(self, *args: Any, **kwargs: Any): - return "static_pointgoal" - - def _get_sensor_type(self, *args: Any, **kwargs: Any): - return SensorTypes.PATH - - def _get_observation_space(self, *args: Any, **kwargs: Any): - if self._goal_format == "CARTESIAN": - sensor_shape = (3,) - else: - sensor_shape = (2,) - return spaces.Box( - low=np.finfo(np.float32).min, - high=np.finfo(np.float32).max, - shape=sensor_shape, - dtype=np.float32, - ) + return "pointgoal_with_gps_compass" def get_observation(self, observations, episode): - episode_id = (episode.episode_id, episode.scene_id) - if self.current_episode_id != episode_id: - # Only compute the direction vector when a new episode is started. - self.current_episode_id = episode_id - agent_state = self._sim.get_agent_state() - ref_position = agent_state.position - rotation_world_agent = agent_state.rotation - - direction_vector = ( - np.array(episode.goals[0].position, dtype=np.float32) - - ref_position - ) - direction_vector_agent = quaternion_rotate_vector( - rotation_world_agent.inverse(), direction_vector - ) - - if self._goal_format == "POLAR": - rho, phi = cartesian_to_polar( - -direction_vector_agent[2], direction_vector_agent[0] - ) - direction_vector_agent = np.array( - [rho, -phi], dtype=np.float32 - ) + agent_state = self._sim.get_agent_state() + agent_position = agent_state.position + rotation_world_agent = agent_state.rotation + goal_position = np.array(episode.goals[0].position, dtype=np.float32) - self._initial_vector = direction_vector_agent - return self._initial_vector + return self._compute_pointgoal( + agent_position, rotation_world_agent, goal_position + ) @registry.register_sensor @@ -266,18 +263,91 @@ class HeadingSensor(Sensor): def _get_observation_space(self, *args: Any, **kwargs: Any): return spaces.Box(low=-np.pi, high=np.pi, shape=(1,), dtype=np.float) + def _quat_to_xy_heading(self, quat): + direction_vector = np.array([0, 0, -1]) + + heading_vector = quaternion_rotate_vector(quat, direction_vector) + + phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1] + return np.array(phi) + def get_observation(self, observations, episode): agent_state = self._sim.get_agent_state() rotation_world_agent = agent_state.rotation - direction_vector = np.array([0, 0, -1]) + return self._quat_to_xy_heading(rotation_world_agent.inverse()) - heading_vector = quaternion_rotate_vector( - rotation_world_agent.inverse(), direction_vector + +@registry.register_sensor(name="CompassSensor") +class EpisodicCompassSensor(HeadingSensor): + r"""The agents heading in the coordinate frame defined by the epiosde, + theta=0 is defined by the agents state at t=0 + """ + + def _get_uuid(self, *args: Any, **kwargs: Any): + return "compass" + + def get_observation(self, observations, episode): + agent_state = self._sim.get_agent_state() + rotation_world_agent = agent_state.rotation + rotation_world_start = quaternion_from_coeff(episode.start_rotation) + + return self._quat_to_xy_heading( + rotation_world_agent.inverse() * rotation_world_start ) - phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1] - return np.array(phi) + +@registry.register_sensor(name="GPSSensor") +class EpisodicGPSSensor(Sensor): + r"""The agents current location in the coordinate frame defined by the episode, + i.e. the axis it faces along and the origin is defined by its state at t=0 + + Args: + sim: reference to the simulator for calculating task observations. + config: Contains the DIMENSIONALITY field for the number of dimensions to express the agents position + Attributes: + _dimensionality: number of dimensions used to specify the agents position + """ + + def __init__(self, sim: Simulator, config: Config): + self._sim = sim + + self._dimensionality = getattr(config, "DIMENSIONALITY", 2) + assert self._dimensionality in [2, 3] + super().__init__(config=config) + + def _get_uuid(self, *args: Any, **kwargs: Any): + return "gps" + + def _get_sensor_type(self, *args: Any, **kwargs: Any): + return SensorTypes.POSITION + + def _get_observation_space(self, *args: Any, **kwargs: Any): + sensor_shape = (self._dimensionality,) + return spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=sensor_shape, + dtype=np.float32, + ) + + def get_observation(self, observations, episode): + agent_state = self._sim.get_agent_state() + + origin = np.array(episode.start_position, dtype=np.float32) + rotation_world_start = quaternion_from_coeff(episode.start_rotation) + + agent_position = agent_state.position + + agent_position = quaternion_rotate_vector( + rotation_world_start.inverse(), agent_position - origin + ) + if self._dimensionality == 2: + return np.array( + [-agent_position[2], agent_position[0]], dtype=np.float32 + ) + else: + return agent_position.astype(np.float32) @registry.register_sensor diff --git a/habitat/tasks/utils.py b/habitat/tasks/utils.py index ceb36600d..c32d4b4f2 100644 --- a/habitat/tasks/utils.py +++ b/habitat/tasks/utils.py @@ -51,6 +51,15 @@ def quaternion_rotate_vector(quat: np.quaternion, v: np.array) -> np.array: return (quat * vq * quat.inverse()).imag +def quaternion_from_coeff(coeffs: np.ndarray) -> np.quaternion: + r"""Creates a quaternions from coeffs in [x, y, z, w] format + """ + quat = np.quaternion(0, 0, 0, 0) + quat.real = coeffs[3] + quat.imag = coeffs[0:3] + return quat + + def cartesian_to_polar(x, y): rho = np.sqrt(x ** 2 + y ** 2) phi = np.arctan2(y, x) diff --git a/habitat_baselines/agents/ppo_agents.py b/habitat_baselines/agents/ppo_agents.py index 551be4ff3..07f6f9237 100644 --- a/habitat_baselines/agents/ppo_agents.py +++ b/habitat_baselines/agents/ppo_agents.py @@ -28,7 +28,7 @@ def get_default_config(): c.HIDDEN_SIZE = 512 c.RANDOM_SEED = 7 c.PTH_GPU_ID = 0 - c.GOAL_SENSOR_UUID = "pointgoal" + c.GOAL_SENSOR_UUID = "pointgoal_with_gps_compass" return c diff --git a/habitat_baselines/agents/slam_agents.py b/habitat_baselines/agents/slam_agents.py index 59e2b73d3..843b09fbc 100644 --- a/habitat_baselines/agents/slam_agents.py +++ b/habitat_baselines/agents/slam_agents.py @@ -36,6 +36,8 @@ from habitat_baselines.slambased.reprojection import ( ) from habitat_baselines.slambased.utils import generate_2dgrid +GOAL_SENSOR_UUID = "pointgoal_with_gps_compass" + def download(url, filename): with open(filename, "wb") as f: @@ -107,7 +109,7 @@ class RandomAgent(object): return def is_goal_reached(self): - dist = self.obs["pointgoal"][0] + dist = self.obs[GOAL_SENSOR_UUID][0] return dist <= self.dist_threshold_to_stop def act(self, habitat_observation=None, random_prob=1.0): @@ -130,8 +132,8 @@ class BlindAgent(RandomAgent): return def decide_what_to_do(self): - distance_to_goal = self.obs["pointgoal"][0] - angle_to_goal = norm_ang(np.array(self.obs["pointgoal"][1])) + distance_to_goal = self.obs[GOAL_SENSOR_UUID][0] + angle_to_goal = norm_ang(np.array(self.obs[GOAL_SENSOR_UUID][1])) command = SimulatorActions.STOP if distance_to_goal <= self.pos_th: return command @@ -400,7 +402,9 @@ class ORBSLAM2Agent(RandomAgent): def set_offset_to_goal(self, observation): self.offset_to_goal = ( - torch.from_numpy(observation["pointgoal"]).float().to(self.device) + torch.from_numpy(observation[GOAL_SENSOR_UUID]) + .float() + .to(self.device) ) self.estimatedGoalPos2D = habitat_goalpos_to_mapgoal_pos( self.offset_to_goal, diff --git a/test/test_sensors.py b/test/test_sensors.py index 9372dacf9..b1cd266ba 100644 --- a/test/test_sensors.py +++ b/test/test_sensors.py @@ -9,11 +9,13 @@ import random import numpy as np import pytest +import quaternion import habitat from habitat.config.default import get_config from habitat.core.simulator import SimulatorActions from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal +from habitat.tasks.utils import quaternion_rotate_vector def _random_episode(env, config): @@ -38,12 +40,12 @@ def _random_episode(env, config): ) -def test_heading_sensor(): +def test_state_sensors(): config = get_config() if not os.path.exists(config.SIMULATOR.SCENE): pytest.skip("Please download Habitat test data to data folder.") config.defrost() - config.TASK.SENSORS = ["HEADING_SENSOR"] + config.TASK.SENSORS = ["HEADING_SENSOR", "COMPASS_SENSOR", "GPS_SENSOR"] config.freeze() env = habitat.Env(config=config, dataset=None) env.reset() @@ -73,6 +75,8 @@ def test_heading_sensor(): obs = env.reset() heading = obs["heading"] assert np.allclose(heading, random_heading) + assert np.allclose(obs["compass"], [0.0], atol=1e-5) + assert np.allclose(obs["gps"], [0.0, 0.0], atol=1e-5) env.close() @@ -153,19 +157,21 @@ def test_collisions(): env.close() -def test_static_pointgoal_sensor(): +def test_pointgoal_sensor(): config = get_config() if not os.path.exists(config.SIMULATOR.SCENE): pytest.skip("Please download Habitat test data to data folder.") config.defrost() - config.TASK.SENSORS = ["STATIC_POINTGOAL_SENSOR"] + config.TASK.SENSORS = ["POINTGOAL_SENSOR"] + config.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 3 + config.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "CARTESIAN" config.freeze() env = habitat.Env(config=config, dataset=None) # start position is checked for validity for the specific test scene valid_start_position = [-1.3731, 0.08431, 8.60692] - expected_static_pointgoal = [0.1, 0.2, 0.3] - goal_position = np.add(valid_start_position, expected_static_pointgoal) + expected_pointgoal = [0.1, 0.2, 0.3] + goal_position = np.add(valid_start_position, expected_pointgoal) # starting quaternion is rotated 180 degree along z-axis, which # corresponds to simulator using z-negative as forward action @@ -191,9 +197,78 @@ def test_static_pointgoal_sensor(): env.reset() for _ in range(100): obs = env.step(np.random.choice(non_stop_actions)) - static_pointgoal = obs["static_pointgoal"] + pointgoal = obs["pointgoal"] # check to see if taking non-stop actions will affect static point_goal - assert np.allclose(static_pointgoal, expected_static_pointgoal) + assert np.allclose(pointgoal, expected_pointgoal) + + env.close() + + +def test_pointgoal_with_gps_compass_sensor(): + config = get_config() + if not os.path.exists(config.SIMULATOR.SCENE): + pytest.skip("Please download Habitat test data to data folder.") + config.defrost() + config.TASK.SENSORS = [ + "POINTGOAL_WITH_GPS_COMPASS_SENSOR", + "COMPASS_SENSOR", + "GPS_SENSOR", + "POINTGOAL_SENSOR", + ] + config.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.DIMENSIONALITY = 3 + config.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.GOAL_FORMAT = "CARTESIAN" + + config.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 3 + config.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "CARTESIAN" + + config.TASK.GPS_SENSOR.DIMENSIONALITY = 3 + + config.freeze() + env = habitat.Env(config=config, dataset=None) + + # start position is checked for validity for the specific test scene + valid_start_position = [-1.3731, 0.08431, 8.60692] + expected_pointgoal = [0.1, 0.2, 0.3] + goal_position = np.add(valid_start_position, expected_pointgoal) + + # starting quaternion is rotated 180 degree along z-axis, which + # corresponds to simulator using z-negative as forward action + start_rotation = [0, 0, 0, 1] + + env.episode_iterator = iter( + [ + NavigationEpisode( + episode_id="0", + scene_id=config.SIMULATOR.SCENE, + start_position=valid_start_position, + start_rotation=start_rotation, + goals=[NavigationGoal(position=goal_position)], + ) + ] + ) + + non_stop_actions = [ + act + for act in range(env.action_space.n) + if act != SimulatorActions.STOP + ] + env.reset() + for _ in range(100): + obs = env.step(np.random.choice(non_stop_actions)) + pointgoal = obs["pointgoal"] + pointgoal_with_gps_compass = obs["pointgoal_with_gps_compass"] + comapss = obs["compass"] + gps = obs["gps"] + # check to see if taking non-stop actions will affect static point_goal + assert np.allclose( + pointgoal_with_gps_compass, + quaternion_rotate_vector( + quaternion.from_rotation_vector( + comapss * np.array([0, 1, 0]) + ).inverse(), + pointgoal - gps, + ), + ) env.close() @@ -210,8 +285,8 @@ def test_get_observations_at(): # start position is checked for validity for the specific test scene valid_start_position = [-1.3731, 0.08431, 8.60692] - expected_static_pointgoal = [0.1, 0.2, 0.3] - goal_position = np.add(valid_start_position, expected_static_pointgoal) + expected_pointgoal = [0.1, 0.2, 0.3] + goal_position = np.add(valid_start_position, expected_pointgoal) # starting quaternion is rotated 180 degree along z-axis, which # corresponds to simulator using z-negative as forward action -- GitLab