Commit b5f2b00a authored by Erik Wijmans, committed by Oleksandr

Correct the naming of PointGoalNav Sensors (#180)

This PR brings habitat-API in line with the corrected sensor naming: StaticPointGoalSensor is now simply PointGoalSensor, and the old PointGoalSensor is now PointGoalWithGPSCompassSensor. There are also explicit GPS and Compass sensors, so that PointGoalWithGPSCompassSensor can fairly easily be computed from [PointGoalSensor, GPSSensor, CompassSensor]; see the sketch below.
The PointGoal* sensors now also support {2D, 3D} x {Cartesian, Polar} goal formats.
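The relationship the new test asserts is simple: in Cartesian, 3-dimensional mode the integrated observation is the episodic goal offset rotated into the agent's current frame by the compass heading. A minimal sketch of that combination (the helper name is illustrative and not part of this commit; it assumes all three sensors are configured with GOAL_FORMAT: "CARTESIAN" and DIMENSIONALITY: 3):

    import numpy as np
    import quaternion  # numpy-quaternion

    from habitat.tasks.utils import quaternion_rotate_vector


    def combine_pointgoal_gps_compass(pointgoal, gps, compass):
        """Recover the pointgoal_with_gps_compass observation from the
        episodic pointgoal, gps, and compass observations."""
        # The compass is the agent's heading about the gravity (y) axis in the
        # episode frame; turn it back into a rotation quaternion.
        agent_rotation = quaternion.from_rotation_vector(
            compass * np.array([0, 1, 0])
        )
        # Rotate the goal offset (episode frame) into the agent's frame.
        return quaternion_rotate_vector(agent_rotation.inverse(), pointgoal - gps)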
parent 6d76d42f
@@ -14,12 +14,14 @@ SIMULATOR:
 TASK:
   TYPE: Nav-v0
   SUCCESS_DISTANCE: 0.2
-  SENSORS: ['POINTGOAL_SENSOR']
-  POINTGOAL_SENSOR:
-    TYPE: PointGoalSensor
-    GOAL_FORMAT: POLAR
+  SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR']
+  POINTGOAL_WITH_GPS_COMPASS_SENSOR:
+    GOAL_FORMAT: "POLAR"
+    DIMENSIONALITY: 2
+  GOAL_SENSOR_UUID: pointgoal_with_gps_compass
   MEASUREMENTS: ['SPL']
-  GOAL_SENSOR_UUID: pointgoal
   SPL:
     TYPE: SPL
     SUCCESS_DISTANCE: 0.2
@@ -14,10 +14,13 @@ SIMULATOR:
 TASK:
   TYPE: Nav-v0
   SUCCESS_DISTANCE: 0.2
-  SENSORS: ['POINTGOAL_SENSOR']
-  POINTGOAL_SENSOR:
-    TYPE: PointGoalSensor
-    GOAL_FORMAT: POLAR
+  SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR']
+  POINTGOAL_WITH_GPS_COMPASS_SENSOR:
+    GOAL_FORMAT: "POLAR"
+    DIMENSIONALITY: 2
+  GOAL_SENSOR_UUID: pointgoal_with_gps_compass
   MEASUREMENTS: ['SPL']
   SPL:
     TYPE: SPL
...
@@ -14,10 +14,13 @@ SIMULATOR:
 TASK:
   TYPE: Nav-v0
   SUCCESS_DISTANCE: 0.2
-  SENSORS: ['POINTGOAL_SENSOR']
-  POINTGOAL_SENSOR:
-    TYPE: PointGoalSensor
-    GOAL_FORMAT: POLAR
+  SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR']
+  POINTGOAL_WITH_GPS_COMPASS_SENSOR:
+    GOAL_FORMAT: "POLAR"
+    DIMENSIONALITY: 2
+  GOAL_SENSOR_UUID: pointgoal_with_gps_compass
   MEASUREMENTS: ['SPL']
   SPL:
     TYPE: SPL
...
@@ -14,10 +14,13 @@ SIMULATOR:
 TASK:
   TYPE: Nav-v0
   SUCCESS_DISTANCE: 0.2
-  SENSORS: ['POINTGOAL_SENSOR']
-  POINTGOAL_SENSOR:
-    TYPE: PointGoalSensor
-    GOAL_FORMAT: POLAR
+  SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR']
+  POINTGOAL_WITH_GPS_COMPASS_SENSOR:
+    GOAL_FORMAT: "POLAR"
+    DIMENSIONALITY: 2
+  GOAL_SENSOR_UUID: pointgoal_with_gps_compass
   MEASUREMENTS: ['SPL']
   SPL:
     TYPE: SPL
...
@@ -16,12 +16,14 @@ DATASET:
 TASK:
   TYPE: Nav-v0
   SUCCESS_DISTANCE: 0.2
-  SENSORS: ['POINTGOAL_SENSOR']
-  POINTGOAL_SENSOR:
-    TYPE: PointGoalSensor
-    GOAL_FORMAT: POLAR
+  SENSORS: ['POINTGOAL_WITH_GPS_COMPASS_SENSOR']
+  POINTGOAL_WITH_GPS_COMPASS_SENSOR:
+    GOAL_FORMAT: "POLAR"
+    DIMENSIONALITY: 2
+  GOAL_SENSOR_UUID: pointgoal_with_gps_compass
   MEASUREMENTS: ['SPL']
-  GOAL_SENSOR_UUID: pointgoal
   SPL:
     TYPE: SPL
     SUCCESS_DISTANCE: 0.2
@@ -43,18 +43,31 @@ _C.TASK.GOAL_SENSOR_UUID = "pointgoal"
 _C.TASK.POINTGOAL_SENSOR = CN()
 _C.TASK.POINTGOAL_SENSOR.TYPE = "PointGoalSensor"
 _C.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "POLAR"
+_C.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 2
 # -----------------------------------------------------------------------------
-# # STATIC POINTGOAL SENSOR
+# # POINTGOAL WITH GPS+COMPASS SENSOR
 # -----------------------------------------------------------------------------
-_C.TASK.STATIC_POINTGOAL_SENSOR = CN()
-_C.TASK.STATIC_POINTGOAL_SENSOR.TYPE = "StaticPointGoalSensor"
-_C.TASK.STATIC_POINTGOAL_SENSOR.GOAL_FORMAT = "CARTESIAN"
+_C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR = _C.TASK.POINTGOAL_SENSOR.clone()
+_C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.TYPE = (
+    "PointGoalWithGPSCompassSensor"
+)
 # -----------------------------------------------------------------------------
 # # HEADING SENSOR
 # -----------------------------------------------------------------------------
 _C.TASK.HEADING_SENSOR = CN()
 _C.TASK.HEADING_SENSOR.TYPE = "HeadingSensor"
 # -----------------------------------------------------------------------------
+# # COMPASS SENSOR
+# -----------------------------------------------------------------------------
+_C.TASK.COMPASS_SENSOR = CN()
+_C.TASK.COMPASS_SENSOR.TYPE = "CompassSensor"
+# -----------------------------------------------------------------------------
+# # GPS SENSOR
+# -----------------------------------------------------------------------------
+_C.TASK.GPS_SENSOR = CN()
+_C.TASK.GPS_SENSOR.TYPE = "GPSSensor"
+_C.TASK.GPS_SENSOR.DIMENSIONALITY = 2
+# -----------------------------------------------------------------------------
 # # PROXIMITY SENSOR
 # -----------------------------------------------------------------------------
 _C.TASK.PROXIMITY_SENSOR = CN()
...
@@ -23,7 +23,11 @@ from habitat.core.simulator import (
     Simulator,
 )
 from habitat.core.utils import not_none_validator
-from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector
+from habitat.tasks.utils import (
+    cartesian_to_polar,
+    quaternion_from_coeff,
+    quaternion_rotate_vector,
+)
 from habitat.utils.visualizations import fog_of_war, maps
 
 MAP_THICKNESS_SCALAR: int = 1250
@@ -107,7 +111,8 @@ class NavigationEpisode(Episode):
 @registry.register_sensor
 class PointGoalSensor(Sensor):
-    r"""Sensor for PointGoal observations which are used in the PointNav task.
+    r"""Sensor for PointGoal observations which are used in PointGoal Navigation.
+
     For the agent in simulator the forward direction is along negative-z.
     In polar coordinate format the angle returned is azimuth to the goal.
@@ -118,9 +123,13 @@ class PointGoalSensor(Sensor):
             the pointgoal is specified. Current options for goal format are
             cartesian and polar.
+
+            Also contains a DIMENSIONALITY field which specifies the number
+            of dimensions used to specify the goal, must be in [2, 3]
 
     Attributes:
         _goal_format: format for specifying the goal which can be done
             in cartesian or polar coordinates.
+        _dimensionality: number of dimensions used to specify the goal
     """
 
     def __init__(self, sim: Simulator, config: Config):
@@ -129,6 +138,9 @@ class PointGoalSensor(Sensor):
         self._goal_format = getattr(config, "GOAL_FORMAT", "CARTESIAN")
         assert self._goal_format in ["CARTESIAN", "POLAR"]
 
+        self._dimensionality = getattr(config, "DIMENSIONALITY", 2)
+        assert self._dimensionality in [2, 3]
+
         super().__init__(config=config)
 
     def _get_uuid(self, *args: Any, **kwargs: Any):
@@ -138,10 +150,8 @@ class PointGoalSensor(Sensor):
         return SensorTypes.PATH
 
     def _get_observation_space(self, *args: Any, **kwargs: Any):
-        if self._goal_format == "CARTESIAN":
-            sensor_shape = (3,)
-        else:
-            sensor_shape = (2,)
+        sensor_shape = (self._dimensionality,)
+
         return spaces.Box(
             low=np.finfo(np.float32).min,
             high=np.finfo(np.float32).max,
@@ -149,98 +159,85 @@ class PointGoalSensor(Sensor):
             dtype=np.float32,
         )
 
-    def get_observation(self, observations, episode):
-        agent_state = self._sim.get_agent_state()
-        ref_position = agent_state.position
-        rotation_world_agent = agent_state.rotation
-
-        direction_vector = (
-            np.array(episode.goals[0].position, dtype=np.float32)
-            - ref_position
-        )
+    def _compute_pointgoal(
+        self, source_position, source_rotation, goal_position
+    ):
+        direction_vector = goal_position - source_position
         direction_vector_agent = quaternion_rotate_vector(
-            rotation_world_agent.inverse(), direction_vector
+            source_rotation.inverse(), direction_vector
         )
 
         if self._goal_format == "POLAR":
-            rho, phi = cartesian_to_polar(
-                -direction_vector_agent[2], direction_vector_agent[0]
-            )
-            direction_vector_agent = np.array([rho, -phi], dtype=np.float32)
-
-        return direction_vector_agent
+            if self._dimensionality == 2:
+                rho, phi = cartesian_to_polar(
+                    -direction_vector_agent[2], direction_vector_agent[0]
+                )
+                return np.array([rho, -phi], dtype=np.float32)
+            else:
+                _, phi = cartesian_to_polar(
+                    -direction_vector_agent[2], direction_vector_agent[0]
+                )
+                theta = np.arccos(
+                    direction_vector_agent[1]
+                    / np.linalg.norm(direction_vector_agent)
+                )
+                rho = np.linalg.norm(direction_vector_agent)
+
+                return np.array([rho, -phi, theta], dtype=np.float32)
+        else:
+            if self._dimensionality == 2:
+                return np.array(
+                    [-direction_vector_agent[2], direction_vector_agent[0]],
+                    dtype=np.float32,
+                )
+            else:
+                return direction_vector_agent
+
+    def get_observation(self, observations, episode: Episode):
+        source_position = np.array(episode.start_position, dtype=np.float32)
+        rotation_world_start = quaternion_from_coeff(episode.start_rotation)
+        goal_position = np.array(episode.goals[0].position, dtype=np.float32)
+
+        return self._compute_pointgoal(
+            source_position, rotation_world_start, goal_position
+        )
 
 
-@registry.register_sensor
-class StaticPointGoalSensor(Sensor):
-    r"""Sensor for PointGoal observations which are used in the StaticPointNav
-    task.
+@registry.register_sensor(name="PointGoalWithGPSCompassSensor")
+class IntegratedPointGoalGPSAndCompassSensor(PointGoalSensor):
+    r"""Sensor that integrates PointGoal observations (which are used in
+    PointGoal Navigation) and GPS+Compass.
+
     For the agent in simulator the forward direction is along negative-z.
     In polar coordinate format the angle returned is azimuth to the goal.
 
     Args:
         sim: reference to the simulator for calculating task observations.
         config: config for the PointGoal sensor. Can contain field for
             GOAL_FORMAT which can be used to specify the format in which
             the pointgoal is specified. Current options for goal format are
             cartesian and polar.
+
+            Also contains a DIMENSIONALITY field which specifies the number
+            of dimensions used to specify the goal, must be in [2, 3]
+
     Attributes:
         _goal_format: format for specifying the goal which can be done
             in cartesian or polar coordinates.
+        _dimensionality: number of dimensions used to specify the goal
     """
 
-    def __init__(self, sim: Simulator, config: Config):
-        self._sim = sim
-
-        self._goal_format = getattr(config, "GOAL_FORMAT", "CARTESIAN")
-        assert self._goal_format in ["CARTESIAN", "POLAR"]
-
-        super().__init__(sim, config)
-        self._initial_vector = None
-        self.current_episode_id = None
-
     def _get_uuid(self, *args: Any, **kwargs: Any):
-        return "static_pointgoal"
-
-    def _get_sensor_type(self, *args: Any, **kwargs: Any):
-        return SensorTypes.PATH
-
-    def _get_observation_space(self, *args: Any, **kwargs: Any):
-        if self._goal_format == "CARTESIAN":
-            sensor_shape = (3,)
-        else:
-            sensor_shape = (2,)
-
-        return spaces.Box(
-            low=np.finfo(np.float32).min,
-            high=np.finfo(np.float32).max,
-            shape=sensor_shape,
-            dtype=np.float32,
-        )
+        return "pointgoal_with_gps_compass"
 
     def get_observation(self, observations, episode):
-        episode_id = (episode.episode_id, episode.scene_id)
-        if self.current_episode_id != episode_id:
-            # Only compute the direction vector when a new episode is started.
-            self.current_episode_id = episode_id
-            agent_state = self._sim.get_agent_state()
-            ref_position = agent_state.position
-            rotation_world_agent = agent_state.rotation
-
-            direction_vector = (
-                np.array(episode.goals[0].position, dtype=np.float32)
-                - ref_position
-            )
-            direction_vector_agent = quaternion_rotate_vector(
-                rotation_world_agent.inverse(), direction_vector
-            )
-
-            if self._goal_format == "POLAR":
-                rho, phi = cartesian_to_polar(
-                    -direction_vector_agent[2], direction_vector_agent[0]
-                )
-                direction_vector_agent = np.array(
-                    [rho, -phi], dtype=np.float32
-                )
-
-            self._initial_vector = direction_vector_agent
-        return self._initial_vector
+        agent_state = self._sim.get_agent_state()
+        agent_position = agent_state.position
+        rotation_world_agent = agent_state.rotation
+        goal_position = np.array(episode.goals[0].position, dtype=np.float32)
+
+        return self._compute_pointgoal(
+            agent_position, rotation_world_agent, goal_position
+        )
 
 
 @registry.register_sensor
@@ -266,18 +263,91 @@ class HeadingSensor(Sensor):
     def _get_observation_space(self, *args: Any, **kwargs: Any):
         return spaces.Box(low=-np.pi, high=np.pi, shape=(1,), dtype=np.float)
 
+    def _quat_to_xy_heading(self, quat):
+        direction_vector = np.array([0, 0, -1])
+
+        heading_vector = quaternion_rotate_vector(quat, direction_vector)
+
+        phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1]
+        return np.array(phi)
+
     def get_observation(self, observations, episode):
         agent_state = self._sim.get_agent_state()
         rotation_world_agent = agent_state.rotation
 
-        direction_vector = np.array([0, 0, -1])
-
-        heading_vector = quaternion_rotate_vector(
-            rotation_world_agent.inverse(), direction_vector
-        )
-
-        phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1]
-        return np.array(phi)
+        return self._quat_to_xy_heading(rotation_world_agent.inverse())
+
+
+@registry.register_sensor(name="CompassSensor")
+class EpisodicCompassSensor(HeadingSensor):
+    r"""The agent's heading in the coordinate frame defined by the episode;
+    theta=0 is defined by the agent's state at t=0
+    """
+
+    def _get_uuid(self, *args: Any, **kwargs: Any):
+        return "compass"
+
+    def get_observation(self, observations, episode):
+        agent_state = self._sim.get_agent_state()
+        rotation_world_agent = agent_state.rotation
+        rotation_world_start = quaternion_from_coeff(episode.start_rotation)
+
+        return self._quat_to_xy_heading(
+            rotation_world_agent.inverse() * rotation_world_start
+        )
+
+
+@registry.register_sensor(name="GPSSensor")
+class EpisodicGPSSensor(Sensor):
+    r"""The agent's current location in the coordinate frame defined by the
+    episode, i.e. the axis it faces along and the origin is defined by its
+    state at t=0
+
+    Args:
+        sim: reference to the simulator for calculating task observations.
+        config: Contains the DIMENSIONALITY field for the number of dimensions
+            to express the agent's position
+    Attributes:
+        _dimensionality: number of dimensions used to specify the agent's
+            position
+    """
+
+    def __init__(self, sim: Simulator, config: Config):
+        self._sim = sim
+
+        self._dimensionality = getattr(config, "DIMENSIONALITY", 2)
+        assert self._dimensionality in [2, 3]
+        super().__init__(config=config)
+
+    def _get_uuid(self, *args: Any, **kwargs: Any):
+        return "gps"
+
+    def _get_sensor_type(self, *args: Any, **kwargs: Any):
+        return SensorTypes.POSITION
+
+    def _get_observation_space(self, *args: Any, **kwargs: Any):
+        sensor_shape = (self._dimensionality,)
+
+        return spaces.Box(
+            low=np.finfo(np.float32).min,
+            high=np.finfo(np.float32).max,
+            shape=sensor_shape,
+            dtype=np.float32,
+        )
+
+    def get_observation(self, observations, episode):
+        agent_state = self._sim.get_agent_state()
+
+        origin = np.array(episode.start_position, dtype=np.float32)
+        rotation_world_start = quaternion_from_coeff(episode.start_rotation)
+
+        agent_position = agent_state.position
+
+        agent_position = quaternion_rotate_vector(
+            rotation_world_start.inverse(), agent_position - origin
+        )
+        if self._dimensionality == 2:
+            return np.array(
+                [-agent_position[2], agent_position[0]], dtype=np.float32
+            )
+        else:
+            return agent_position.astype(np.float32)
 
 
 @registry.register_sensor
...
@@ -51,6 +51,15 @@ def quaternion_rotate_vector(quat: np.quaternion, v: np.array) -> np.array:
     return (quat * vq * quat.inverse()).imag
 
 
+def quaternion_from_coeff(coeffs: np.ndarray) -> np.quaternion:
+    r"""Creates a quaternion from coeffs in [x, y, z, w] format
+    """
+    quat = np.quaternion(0, 0, 0, 0)
+    quat.real = coeffs[3]
+    quat.imag = coeffs[0:3]
+    return quat
+
+
 def cartesian_to_polar(x, y):
     rho = np.sqrt(x ** 2 + y ** 2)
     phi = np.arctan2(y, x)
...
@@ -28,7 +28,7 @@ def get_default_config():
     c.HIDDEN_SIZE = 512
     c.RANDOM_SEED = 7
     c.PTH_GPU_ID = 0
-    c.GOAL_SENSOR_UUID = "pointgoal"
+    c.GOAL_SENSOR_UUID = "pointgoal_with_gps_compass"
     return c
...
@@ -36,6 +36,8 @@ from habitat_baselines.slambased.reprojection import (
 )
 from habitat_baselines.slambased.utils import generate_2dgrid
 
+GOAL_SENSOR_UUID = "pointgoal_with_gps_compass"
+
 
 def download(url, filename):
     with open(filename, "wb") as f:
@@ -107,7 +109,7 @@ class RandomAgent(object):
         return
 
     def is_goal_reached(self):
-        dist = self.obs["pointgoal"][0]
+        dist = self.obs[GOAL_SENSOR_UUID][0]
         return dist <= self.dist_threshold_to_stop
 
     def act(self, habitat_observation=None, random_prob=1.0):
@@ -130,8 +132,8 @@ class BlindAgent(RandomAgent):
         return
 
     def decide_what_to_do(self):
-        distance_to_goal = self.obs["pointgoal"][0]
-        angle_to_goal = norm_ang(np.array(self.obs["pointgoal"][1]))
+        distance_to_goal = self.obs[GOAL_SENSOR_UUID][0]
+        angle_to_goal = norm_ang(np.array(self.obs[GOAL_SENSOR_UUID][1]))
         command = SimulatorActions.STOP
         if distance_to_goal <= self.pos_th:
             return command
@@ -400,7 +402,9 @@ class ORBSLAM2Agent(RandomAgent):
     def set_offset_to_goal(self, observation):
         self.offset_to_goal = (
-            torch.from_numpy(observation["pointgoal"]).float().to(self.device)
+            torch.from_numpy(observation[GOAL_SENSOR_UUID])
+            .float()
+            .to(self.device)
         )
         self.estimatedGoalPos2D = habitat_goalpos_to_mapgoal_pos(
             self.offset_to_goal,
...
@@ -9,11 +9,13 @@ import random
 
 import numpy as np
 import pytest
+import quaternion
 
 import habitat
 from habitat.config.default import get_config
 from habitat.core.simulator import SimulatorActions
 from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal
+from habitat.tasks.utils import quaternion_rotate_vector
 
 
 def _random_episode(env, config):
@@ -38,12 +40,12 @@ def _random_episode(env, config):
     )
 
 
-def test_heading_sensor():
+def test_state_sensors():
     config = get_config()
     if not os.path.exists(config.SIMULATOR.SCENE):
         pytest.skip("Please download Habitat test data to data folder.")
     config.defrost()
-    config.TASK.SENSORS = ["HEADING_SENSOR"]
+    config.TASK.SENSORS = ["HEADING_SENSOR", "COMPASS_SENSOR", "GPS_SENSOR"]
     config.freeze()
     env = habitat.Env(config=config, dataset=None)
     env.reset()
@@ -73,6 +75,8 @@ def test_heading_sensor():
         obs = env.reset()
         heading = obs["heading"]
         assert np.allclose(heading, random_heading)
+        assert np.allclose(obs["compass"], [0.0], atol=1e-5)
+        assert np.allclose(obs["gps"], [0.0, 0.0], atol=1e-5)
 
     env.close()
@@ -153,19 +157,21 @@ def test_collisions():
     env.close()
 
 
-def test_static_pointgoal_sensor():
+def test_pointgoal_sensor():
     config = get_config()
     if not os.path.exists(config.SIMULATOR.SCENE):
         pytest.skip("Please download Habitat test data to data folder.")
     config.defrost()
-    config.TASK.SENSORS = ["STATIC_POINTGOAL_SENSOR"]
+    config.TASK.SENSORS = ["POINTGOAL_SENSOR"]
+    config.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 3
+    config.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "CARTESIAN"
    config.freeze()
     env = habitat.Env(config=config, dataset=None)
 
     # start position is checked for validity for the specific test scene
     valid_start_position = [-1.3731, 0.08431, 8.60692]
-    expected_static_pointgoal = [0.1, 0.2, 0.3]
-    goal_position = np.add(valid_start_position, expected_static_pointgoal)
+    expected_pointgoal = [0.1, 0.2, 0.3]
+    goal_position = np.add(valid_start_position, expected_pointgoal)
 
     # starting quaternion is rotated 180 degree along z-axis, which
     # corresponds to simulator using z-negative as forward action
@@ -191,9 +197,78 @@ def test_static_pointgoal_sensor():
     env.reset()
     for _ in range(100):
         obs = env.step(np.random.choice(non_stop_actions))
-        static_pointgoal = obs["static_pointgoal"]
+        pointgoal = obs["pointgoal"]
         # check to see if taking non-stop actions will affect static point_goal
-        assert np.allclose(static_pointgoal, expected_static_pointgoal)
+        assert np.allclose(pointgoal, expected_pointgoal)
+
+    env.close()
+
+
+def test_pointgoal_with_gps_compass_sensor():
+    config = get_config()
+    if not os.path.exists(config.SIMULATOR.SCENE):
+        pytest.skip("Please download Habitat test data to data folder.")
+    config.defrost()
+    config.TASK.SENSORS = [
+        "POINTGOAL_WITH_GPS_COMPASS_SENSOR",
+        "COMPASS_SENSOR",
+        "GPS_SENSOR",
+        "POINTGOAL_SENSOR",
+    ]
+    config.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.DIMENSIONALITY = 3
+    config.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.GOAL_FORMAT = "CARTESIAN"
+
+    config.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 3
+    config.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "CARTESIAN"
+
+    config.TASK.GPS_SENSOR.DIMENSIONALITY = 3
+
+    config.freeze()
+    env = habitat.Env(config=config, dataset=None)
+
+    # start position is checked for validity for the specific test scene
+    valid_start_position = [-1.3731, 0.08431, 8.60692]
+    expected_pointgoal = [0.1, 0.2, 0.3]
+    goal_position = np.add(valid_start_position, expected_pointgoal)
+
+    # starting quaternion is rotated 180 degree along z-axis, which
+    # corresponds to simulator using z-negative as forward action
+    start_rotation = [0, 0, 0, 1]
+
+    env.episode_iterator = iter(
+        [
+            NavigationEpisode(
+                episode_id="0",
+                scene_id=config.SIMULATOR.SCENE,
+                start_position=valid_start_position,
+                start_rotation=start_rotation,
+                goals=[NavigationGoal(position=goal_position)],
+            )
+        ]
+    )
+
+    non_stop_actions = [
+        act
+        for act in range(env.action_space.n)
+        if act != SimulatorActions.STOP
+    ]
+    env.reset()
+    for _ in range(100):
+        obs = env.step(np.random.choice(non_stop_actions))
+        pointgoal = obs["pointgoal"]
+        pointgoal_with_gps_compass = obs["pointgoal_with_gps_compass"]
+        compass = obs["compass"]
+        gps = obs["gps"]
+        # check to see if taking non-stop actions will affect static point_goal
+        assert np.allclose(
+            pointgoal_with_gps_compass,
+            quaternion_rotate_vector(
+                quaternion.from_rotation_vector(
+                    compass * np.array([0, 1, 0])
+                ).inverse(),
+                pointgoal - gps,
+            ),
+        )
 
     env.close()
@@ -210,8 +285,8 @@ def test_get_observations_at():
 
     # start position is checked for validity for the specific test scene
     valid_start_position = [-1.3731, 0.08431, 8.60692]
-    expected_static_pointgoal = [0.1, 0.2, 0.3]
-    goal_position = np.add(valid_start_position, expected_static_pointgoal)
+    expected_pointgoal = [0.1, 0.2, 0.3]
+    goal_position = np.add(valid_start_position, expected_pointgoal)
 
     # starting quaternion is rotated 180 degree along z-axis, which
     # corresponds to simulator using z-negative as forward action
...