diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..4855b1d584c79becb91e33780584c99730242daa --- /dev/null +++ b/.editorconfig @@ -0,0 +1,11 @@ +# See https://editorconfig.org/ for more info :) + +[*] +indent_style = space +indent_size = 2 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.py] +indent_size = 4 +max_line_length = 79 diff --git a/examples/pointnav_episode_gen_example.py b/examples/pointnav_episode_gen_example.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/habitat/core/benchmark.py b/habitat/core/benchmark.py index a3dd090ede0f78a651bbea87e9aa673a027d5bd1..037b6155be2c78b229f0e0bf0fbc06a5d5871753 100644 --- a/habitat/core/benchmark.py +++ b/habitat/core/benchmark.py @@ -7,7 +7,7 @@ from collections import defaultdict from typing import Dict, Optional -from habitat.config.default import get_config, DEFAULT_CONFIG_DIR +from habitat.config.default import get_config from habitat.core.agent import Agent from habitat.core.env import Env diff --git a/habitat/datasets/pointnav/pointnav_generator.py b/habitat/datasets/pointnav/pointnav_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..99cb542a2ba445b43e0f84b9395208b2bdd43ba7 --- /dev/null +++ b/habitat/datasets/pointnav/pointnav_generator.py @@ -0,0 +1,167 @@ +from typing import Optional + +import numpy as np + +from habitat.core.simulator import Simulator +from habitat.datasets.utils import get_action_shortest_path +from habitat.tasks.nav.nav_task import NavigationGoal, NavigationEpisode + +""" + A minimum radius of a plane that a point should be part of to be + considered as a target or source location. Used to filter isolated points + that aren't part of a floor. +""" +ISLAND_RADIUS_LIMIT = 1.5 + + +def _ratio_sample_rate(ratio: float, ratio_threshold: float) -> float: + """ + Sampling function for aggressive filtering of straight-line + episodes with shortest path geodesic distance to Euclid distance ratio + threshold. + + :param ratio: geodesic distance ratio to Euclid distance + :param ratio_threshold: geodesic shortest path to Euclid + distance ratio upper limit till aggressive sampling is applied. + :return: value between 0.008 and 0.144 for ratio [1, 1.1] + """ + assert ratio < ratio_threshold + return 20 * (ratio - 0.98) ** 2 + + +def is_compatible_episode( + s, t, sim, near_dist, far_dist, geodesic_to_euclid_ratio +): + euclid_dist = np.power(np.power(np.array(s) - np.array(t), 2).sum(0), 0.5) + if np.abs(s[1] - t[1]) > 0.5: # check height difference to assure s and + # t are from same floor + return False, 0 + d_separation = sim.geodesic_distance(s, t) + if d_separation == np.inf: + return False, 0 + if not near_dist <= d_separation <= far_dist: + return False, 0 + distances_ratio = d_separation / euclid_dist + if distances_ratio < geodesic_to_euclid_ratio and ( + np.random.rand() + > _ratio_sample_rate(distances_ratio, geodesic_to_euclid_ratio) + ): + return False, 0 + if sim.island_radius(s) < ISLAND_RADIUS_LIMIT: + return False, 0 + return True, d_separation + + +def _create_episode( + episode_id, + scene_id, + start_position, + start_rotation, + target_position, + shortest_paths=None, + radius=None, + info=None, +) -> Optional[NavigationEpisode]: + goals = [NavigationGoal(position=target_position, radius=radius)] + return NavigationEpisode( + episode_id=str(episode_id), + goals=goals, + scene_id=scene_id, + start_position=start_position, + start_rotation=start_rotation, + shortest_paths=shortest_paths, + info=info, + ) + + +def generate_pointnav_episode( + sim: Simulator, + num_episodes: int = -1, + is_gen_shortest_path: bool = True, + shortest_path_success_distance: float = 0.2, + shortest_path_max_steps: int = 500, + closest_dist_limit: float = 1, + furthest_dist_limit: float = 30, + geodesic_to_euclid_min_ratio: float = 1.1, + number_retries_per_target: int = 10, +) -> NavigationEpisode: + """ + Generator function that generates PointGoal navigation episodes. + + An episode is trivial if there is an obstacle-free, straight line between + the start and goal positions. A good measure of the navigation + complexity of an episode is the ratio of + geodesic shortest path position to Euclidean distance between start and + goal positions to the corresponding Euclidean distance. + If the ratio is nearly 1, it indicates there are few obstacles, and the + episode is easy; if the ratio is larger than 1, the + episode is difficult because strategic navigation is required. + To keep the navigation complexity of the precomputed episodes reasonably + high, we perform aggressive rejection sampling for episodes with the above + ratio falling in the range [1, 1.1]. + Following this, there is a significant decrease in the number of + straight-line episodes. + + + :param sim: simulator with loaded scene for generation. + :param num_episodes: number of episodes needed to generate + :param is_gen_shortest_path: option to generate shortest paths + :param shortest_path_success_distance: success distance when agent should + stop during shortest path generation + :param shortest_path_max_steps maximum number of steps shortest path + expected to be + :param closest_dist_limit episode geodesic distance lowest limit + :param furthest_dist_limit episode geodesic distance highest limit + :param geodesic_to_euclid_min_ratio geodesic shortest path to Euclid + distance ratio upper limit till aggressive sampling is applied. + :return: navigation episode that satisfy specified distribution for + currently loaded into simulator scene. + """ + episode_count = 0 + while episode_count < num_episodes or num_episodes < 0: + target_position = sim.sample_navigable_point() + + if sim.island_radius(target_position) < ISLAND_RADIUS_LIMIT: + continue + + for retry in range(number_retries_per_target): + source_position = sim.sample_navigable_point() + + is_compatible, dist = is_compatible_episode( + source_position, + target_position, + sim, + near_dist=closest_dist_limit, + far_dist=furthest_dist_limit, + geodesic_to_euclid_ratio=geodesic_to_euclid_min_ratio, + ) + if is_compatible: + angle = np.random.uniform(0, 2 * np.pi) + source_rotation = [0, np.sin(angle / 2), 0, np.cos(angle / 2)] + + shortest_paths = None + if is_gen_shortest_path: + shortest_paths = [ + get_action_shortest_path( + sim, + source_position=source_position, + source_rotation=source_rotation, + goal_position=target_position, + success_distance=shortest_path_success_distance, + max_episode_steps=shortest_path_max_steps, + ) + ] + + episode = _create_episode( + episode_id=episode_count, + scene_id=sim.config.SCENE, + start_position=source_position, + start_rotation=source_rotation, + target_position=target_position, + shortest_paths=shortest_paths, + radius=shortest_path_success_distance, + info={"geodesic_distance": dist}, + ) + + episode_count += 1 + yield episode diff --git a/habitat/datasets/utils.py b/habitat/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..918096a53588c8c1c0a128b4614b3f0f974a72e8 --- /dev/null +++ b/habitat/datasets/utils.py @@ -0,0 +1,41 @@ +from typing import List + +from habitat.core.logging import logger +from habitat.core.simulator import ShortestPathPoint +from habitat.sims.habitat_simulator import SimulatorActions +from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower +from habitat.utils.geometry_utils import quaternion_to_list + + +def get_action_shortest_path( + sim, + source_position, + source_rotation, + goal_position, + success_distance=0.05, + max_episode_steps=500, + shortest_path_mode="greedy", +) -> List[ShortestPathPoint]: + sim.reset() + sim.set_agent_state(source_position, source_rotation) + follower = ShortestPathFollower(sim, success_distance, False) + follower.mode = shortest_path_mode + + shortest_path = [] + action = None + step_count = 0 + while action != SimulatorActions.STOP and step_count < max_episode_steps: + action = follower.get_next_action(goal_position) + state = sim.get_agent_state() + shortest_path.append( + ShortestPathPoint( + state.position.tolist(), + quaternion_to_list(state.rotation), + action.value, + ) + ) + sim.step(action.value) + step_count += 1 + if step_count == max_episode_steps: + logger.warning("Shortest path wasn't found.") + return shortest_path diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py index 0a70cee0a0557c3caad5f55c378550af25be82fc..7986d67730380a37cd67c321d742d60571dd16e9 100644 --- a/habitat/sims/habitat_simulator.py +++ b/habitat/sims/habitat_simulator.py @@ -317,7 +317,8 @@ class HabitatSim(habitat.Simulator): will be None. """ raise NotImplementedError( - "This function is no longer implemented. Please use the greedy follower instead" + "This function is no longer implemented. Please use the greedy " + "follower instead" ) @property @@ -421,10 +422,11 @@ class HabitatSim(habitat.Simulator): state.position = position state.rotation = rotation - # NB: The agent state also contains the sensor states in _absolute_ coordinates. - # In order to set the agent's body to a specific location and have the sensors follow, - # we must not provide any state for the sensors. - # This will cause them to follow the agent's body + # NB: The agent state also contains the sensor states in _absolute_ + # coordinates. In order to set the agent's body to a specific + # location and have the sensors follow, we must not provide any + # state for the sensors. This will cause them to follow the agent's + # body state.sensor_states = dict() agent.set_state(state, reset_sensors) @@ -440,3 +442,6 @@ class HabitatSim(habitat.Simulator): return self._sim.pathfinder.distance_to_closest_obstacle( position, max_search_radius ) + + def island_radius(self, position): + return self._sim.pathfinder.island_radius(position) diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index 95fbda44e4585174eb013e14caf3cde1df51fecf..d6c8382cb29ec12540a28c2d6a03c59ecac8ddca 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -20,11 +20,7 @@ from habitat.core.simulator import ( SensorTypes, SensorSuite, ) -from habitat.tasks.utils import ( - quaternion_to_rotation, - cartesian_to_polar, - quaternion_rotate_vector, -) +from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector from habitat.utils.visualizations import maps COLLISION_PROXIMITY_TOLERANCE: float = 1e-3 diff --git a/habitat/tasks/utils.py b/habitat/tasks/utils.py index 077da1d52fd3e112d507af2832947c2a30c859a9..d41f9b7842f43c1dbf47dce6a8f9cc0c57d4223a 100644 --- a/habitat/tasks/utils.py +++ b/habitat/tasks/utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np -import quaternion +import quaternion # noqa # pylint: disable=unused-import def quaternion_to_rotation(q_r, q_i, q_j, q_k): diff --git a/habitat/utils/geometry_utils.py b/habitat/utils/geometry_utils.py index 9bbf53e0cf82aa7274032a6b5e6f852025ade673..7249333bb860db2c65175527a24c2b5b963aa675 100644 --- a/habitat/utils/geometry_utils.py +++ b/habitat/utils/geometry_utils.py @@ -48,3 +48,9 @@ def quaternion_xyzw_to_wxyz(v: np.array): def quaternion_wxyz_to_xyzw(v: np.array): return np.quaternion(*v[1:4], v[0]) + + +def quaternion_to_list(q: np.quaternion): + return quaternion.as_float_array( + quaternion_wxyz_to_xyzw(quaternion.as_float_array(q)) + ).tolist() diff --git a/habitat/utils/visualizations/utils.py b/habitat/utils/visualizations/utils.py index 7d9142fb5930530cdf35596e4bfff861dc4bf3ae..474b980a587bb1bb3b37b548fbd4ee09998b37eb 100644 --- a/habitat/utils/visualizations/utils.py +++ b/habitat/utils/visualizations/utils.py @@ -108,8 +108,9 @@ def images_to_video( use at your own risk. quality: Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to prevent variable bitrate flags to - FFMPEG so you can manually specify them using output_params instead. - Specifying a fixed bitrate using ‘bitrate’ disables this parameter. + FFMPEG so you can manually specify them using output_params + instead. Specifying a fixed bitrate using ‘bitrate’ disables + this parameter. """ assert 0 <= quality <= 10 if not os.path.exists(output_dir): diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py index a1800e216f9ce28f14281c1afe2b3b5f8e731da3..fcb6283f06cdc1d0f87994711b2d686969e29762 100644 --- a/test/test_baseline_agents.py +++ b/test/test_baseline_agents.py @@ -10,7 +10,7 @@ import pytest from habitat_baselines.agents import simple_agents try: - import torch + import torch # noqa # pylint: disable=unused-import has_torch = True except ImportError: diff --git a/test/test_pointnav_dataset.py b/test/test_pointnav_dataset.py index 8778204f01e2a836cdfbf668940df8e6e1abebd1..dbc4b9053dcc16f8aed9b29747c10d0072fbcc71 100644 --- a/test/test_pointnav_dataset.py +++ b/test/test_pointnav_dataset.py @@ -5,11 +5,14 @@ # LICENSE file in the root directory of this source tree. import os +import random import time +import numpy as np import pytest import habitat +import habitat.datasets.pointnav.pointnav_generator as pointnav_generator from habitat.config.default import get_config from habitat.core.embodied_task import Episode from habitat.core.logging import logger @@ -18,9 +21,12 @@ from habitat.datasets.pointnav.pointnav_dataset import ( PointNavDatasetV1, DEFAULT_SCENE_PATH_PREFIX, ) +from habitat.utils.geometry_utils import quaternion_xyzw_to_wxyz -CFG_TEST = "configs/datasets/pointnav/gibson.yaml" +CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" +CFG_MULTI_TEST = "configs/datasets/pointnav/gibson.yaml" PARTIAL_LOAD_SCENES = 3 +NUM_EPISODES = 10 def check_json_serializaiton(dataset: habitat.Dataset): @@ -56,7 +62,7 @@ def test_single_pointnav_dataset(): def test_multiple_files_scene_path(): - dataset_config = get_config(CFG_TEST).DATASET + dataset_config = get_config(CFG_MULTI_TEST).DATASET if not PointNavDatasetV1.check_config_paths_exist(dataset_config): pytest.skip("Test skipped as dataset files are missing.") scenes = PointNavDatasetV1.get_scenes_to_load(config=dataset_config) @@ -84,7 +90,7 @@ def test_multiple_files_scene_path(): def test_multiple_files_pointnav_dataset(): - dataset_config = get_config(CFG_TEST).DATASET + dataset_config = get_config(CFG_MULTI_TEST).DATASET if not PointNavDatasetV1.check_config_paths_exist(dataset_config): pytest.skip("Test skipped as dataset files are missing.") scenes = PointNavDatasetV1.get_scenes_to_load(config=dataset_config) @@ -101,3 +107,67 @@ def test_multiple_files_pointnav_dataset(): len(partial_dataset.scene_ids) == PARTIAL_LOAD_SCENES ), "Number of loaded scenes doesn't correspond." check_json_serializaiton(partial_dataset) + + +def check_shortest_path(env, episode): + def check_state(agent_state, position, rotation): + assert np.allclose( + agent_state.rotation, quaternion_xyzw_to_wxyz(rotation) + ), "Agent's rotation diverges from the shortest path." + + assert np.allclose( + agent_state.position, position + ), "Agent's position position diverges from the shortest path's one." + + assert len(episode.goals) == 1, "Episode has no goals or more than one." + assert ( + len(episode.shortest_paths) == 1 + ), "Episode has no shortest paths or more than one." + + env.episodes = [episode] + env.reset() + start_state = env.sim.get_agent_state() + check_state(start_state, episode.start_position, episode.start_rotation) + + for step_id, point in enumerate(episode.shortest_paths[0]): + cur_state = env.sim.get_agent_state() + check_state(cur_state, point.position, point.rotation) + env.step(point.action) + + +def test_pointnav_episode_generator(): + config = get_config(CFG_TEST) + config.defrost() + config.DATASET.SPLIT = "val" + config.ENVIRONMENT.MAX_EPISODE_STEPS = 500 + config.freeze() + env = habitat.Env(config) + env.seed(config.SEED) + random.seed(config.SEED) + generator = pointnav_generator.generate_pointnav_episode( + sim=env.sim, + shortest_path_success_distance=config.TASK.SUCCESS_DISTANCE, + shortest_path_max_steps=config.ENVIRONMENT.MAX_EPISODE_STEPS, + ) + episodes = [] + for i in range(NUM_EPISODES): + episode = next(generator) + episodes.append(episode) + + for episode in pointnav_generator.generate_pointnav_episode( + sim=env.sim, + num_episodes=NUM_EPISODES, + shortest_path_success_distance=config.TASK.SUCCESS_DISTANCE, + shortest_path_max_steps=config.ENVIRONMENT.MAX_EPISODE_STEPS, + geodesic_to_euclid_min_ratio=0, + ): + episodes.append(episode) + assert len(episodes) == 2 * NUM_EPISODES + env.episodes = episodes + + for episode in episodes: + check_shortest_path(env, episode) + + dataset = habitat.Dataset() + dataset.episodes = episodes + assert dataset.to_json(), "Generated episodes aren't json serializable." diff --git a/test/test_sensors.py b/test/test_sensors.py index 15343f3ae521a790941238540aae3fdd259e9600..3dc5e5afb47e1e0d615e800bf52ab70fe6e0cea6 100644 --- a/test/test_sensors.py +++ b/test/test_sensors.py @@ -148,8 +148,8 @@ def test_collisions(): < 0.9 * config.SIMULATOR.FORWARD_STEP_SIZE and action == actions[0] ): - # Check to see if the new method of doing collisions catches all the same - # collisions as the old method + # Check to see if the new method of doing collisions catches + # all the same collisions as the old method assert collisions == prev_collisions + 1 prev_loc = loc