From b6a3d649d04a81fb9a7569db679d34d6801a4480 Mon Sep 17 00:00:00 2001 From: Erik Wijmans <etw@gatech.edu> Date: Mon, 10 Jun 2019 14:41:40 -0700 Subject: [PATCH] Fixes and upgrades the collision logic (#117) Fixes and upgrades the collision logic. --- habitat/core/simulator.py | 10 ++++++++ habitat/sims/habitat_simulator.py | 24 +++++++++++++++++++- habitat/tasks/nav/nav_task.py | 8 +------ habitat_baselines/slambased/path_planners.py | 2 +- test/test_dataset.py | 1 + test/test_sensors.py | 22 ++++++------------ 6 files changed, 43 insertions(+), 24 deletions(-) diff --git a/habitat/core/simulator.py b/habitat/core/simulator.py index e3325e47f..b199ba5f3 100644 --- a/habitat/core/simulator.py +++ b/habitat/core/simulator.py @@ -366,3 +366,13 @@ class Simulator: def close(self) -> None: raise NotImplementedError + + @property + def previous_step_collided(self): + r"""Whether or not the previous step resulted in a collision + + Returns: + bool: True if the previous step resulted in a collision, false otherwise + + """ + raise NotImplementedError diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py index 2b98b6d84..32da8af5b 100644 --- a/habitat/sims/habitat_simulator.py +++ b/habitat/sims/habitat_simulator.py @@ -69,9 +69,10 @@ class HabitatSimRGBSensor(RGBSensor): def get_observation(self, sim_obs): obs = sim_obs.get(self.uuid, None) + check_sim_obs(obs, self) + # remove alpha channel obs = obs[:, :, :RGBSENSOR_DIMENSION] - check_sim_obs(obs, self) return obs @@ -262,6 +263,7 @@ class HabitatSim(Simulator): if self._update_agents_state(): sim_obs = self._sim.get_sensor_observations() + self._prev_sim_obs = sim_obs self._is_episode_active = True return self._sensor_suite.get_observations(sim_obs) @@ -277,6 +279,8 @@ class HabitatSim(Simulator): else: sim_obs = self._sim.step(action) + self._prev_sim_obs = sim_obs + observations = self._sensor_suite.get_observations(sim_obs) return observations @@ -471,6 +475,9 @@ class HabitatSim(Simulator): success = self.set_agent_state(position, rotation, reset_sensors=False) if success: sim_obs = self._sim.get_sensor_observations() + + self._prev_sim_obs = sim_obs + observations = self._sensor_suite.get_observations(sim_obs) if not keep_agent_at_new_pose: self.set_agent_state( @@ -496,3 +503,18 @@ class HabitatSim(Simulator): def island_radius(self, position): return self._sim.pathfinder.island_radius(position) + + @property + def previous_step_collided(self): + r"""Whether or not the previous step resulted in a collision + + Returns: + bool: True if the previous step resulted in a collision, false otherwise + + Warning: + This feild is only updated when :meth:`step`, :meth:`reset`, or :meth:`get_observations_at` are + called. It does not update when the agent is moved to a new loction. Furthermore, it + will _always_ be false after :meth:`reset` or :meth:`get_observations_at` as neither of those + result in an action (step) being taken. + """ + return self._prev_sim_obs.get("collided", False) diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index a6bf8df37..16a3d5d10 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -26,7 +26,6 @@ from habitat.core.utils import not_none_validator from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector from habitat.utils.visualizations import maps -COLLISION_PROXIMITY_TOLERANCE: float = 1e-3 MAP_THICKNESS_SCALAR: int = 1250 @@ -396,12 +395,7 @@ class Collisions(Measure): if self._metric is None: self._metric = 0 - current_position = self._sim.get_agent_state().position - if ( - action == self._sim.index_forward_action - and self._sim.distance_to_closest_obstacle(current_position) - < COLLISION_PROXIMITY_TOLERANCE - ): + if self._sim.previous_step_collided: self._metric += 1 diff --git a/habitat_baselines/slambased/path_planners.py b/habitat_baselines/slambased/path_planners.py index f56776df4..9e15e1574 100644 --- a/habitat_baselines/slambased/path_planners.py +++ b/habitat_baselines/slambased/path_planners.py @@ -1,9 +1,9 @@ +import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import matplotlib.pyplot as plt from habitat_baselines.slambased.utils import generate_2dgrid diff --git a/test/test_dataset.py b/test/test_dataset.py index 4edc866c6..01d59f8da 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import pytest + from habitat.core.dataset import Dataset, Episode diff --git a/test/test_sensors.py b/test/test_sensors.py index 86dc34d13..c99d4e826 100644 --- a/test/test_sensors.py +++ b/test/test_sensors.py @@ -11,15 +11,9 @@ import numpy as np import pytest import habitat -import numpy as np -import pytest from habitat.config.default import get_config from habitat.sims.habitat_simulator import SimulatorActions -from habitat.tasks.nav.nav_task import ( - COLLISION_PROXIMITY_TOLERANCE, - NavigationEpisode, - NavigationGoal, -) +from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal NON_STOP_ACTIONS = [ v for v in range(len(SimulatorActions)) if v != SimulatorActions.STOP.value @@ -95,7 +89,6 @@ def test_tactile(): pytest.skip("Please download Habitat test data to data folder.") config.defrost() config.TASK.SENSORS = ["PROXIMITY_SENSOR"] - config.TASK.MEASUREMENTS = ["COLLISIONS"] config.freeze() env = habitat.Env(config=config, dataset=None) env.reset() @@ -104,18 +97,13 @@ def test_tactile(): for _ in range(20): _random_episode(env, config) env.reset() - assert env.get_metrics()["collisions"] is None - my_collisions_count = 0 action = env._sim.index_forward_action for _ in range(10): obs = env.step(action) - collisions = env.get_metrics()["collisions"] proximity = obs["proximity"] - if proximity < COLLISION_PROXIMITY_TOLERANCE: - my_collisions_count += 1 - - assert my_collisions_count == collisions + assert 0.0 <= proximity + assert 2.0 >= proximity env.close() @@ -161,6 +149,10 @@ def test_collisions(): # all the same collisions as the old method assert collisions == prev_collisions + 1 + # We can _never_ collide with standard turn actions + if action != actions[0]: + assert collisions == prev_collisions + prev_loc = loc prev_collisions = collisions -- GitLab