Skip to content
Snippets Groups Projects
Unverified Commit e703e610 authored by Oleksandr's avatar Oleksandr Committed by GitHub
Browse files

Added episode generation code for PointGoal task. (#81)

* Added episode generation code for PointGoal task.
* Removed usage of internal methods, added island_radius method to HabitatSim
* Added generator test with shortest path proper test and episode serialization testing.
* Fixed formatting issues and unused imports across the codebase.
* Added editor config style formatter properties.
parent ac166fa8
No related branches found
No related tags found
No related merge requests found
# See https://editorconfig.org/ for more info :)
[*]
indent_style = space
indent_size = 2
trim_trailing_whitespace = true
insert_final_newline = true
[*.py]
indent_size = 4
max_line_length = 79
......@@ -7,7 +7,7 @@
from collections import defaultdict
from typing import Dict, Optional
from habitat.config.default import get_config, DEFAULT_CONFIG_DIR
from habitat.config.default import get_config
from habitat.core.agent import Agent
from habitat.core.env import Env
......
from typing import Optional
import numpy as np
from habitat.core.simulator import Simulator
from habitat.datasets.utils import get_action_shortest_path
from habitat.tasks.nav.nav_task import NavigationGoal, NavigationEpisode
"""
A minimum radius of a plane that a point should be part of to be
considered as a target or source location. Used to filter isolated points
that aren't part of a floor.
"""
ISLAND_RADIUS_LIMIT = 1.5
def _ratio_sample_rate(ratio: float, ratio_threshold: float) -> float:
"""
Sampling function for aggressive filtering of straight-line
episodes with shortest path geodesic distance to Euclid distance ratio
threshold.
:param ratio: geodesic distance ratio to Euclid distance
:param ratio_threshold: geodesic shortest path to Euclid
distance ratio upper limit till aggressive sampling is applied.
:return: value between 0.008 and 0.144 for ratio [1, 1.1]
"""
assert ratio < ratio_threshold
return 20 * (ratio - 0.98) ** 2
def is_compatible_episode(
s, t, sim, near_dist, far_dist, geodesic_to_euclid_ratio
):
euclid_dist = np.power(np.power(np.array(s) - np.array(t), 2).sum(0), 0.5)
if np.abs(s[1] - t[1]) > 0.5: # check height difference to assure s and
# t are from same floor
return False, 0
d_separation = sim.geodesic_distance(s, t)
if d_separation == np.inf:
return False, 0
if not near_dist <= d_separation <= far_dist:
return False, 0
distances_ratio = d_separation / euclid_dist
if distances_ratio < geodesic_to_euclid_ratio and (
np.random.rand()
> _ratio_sample_rate(distances_ratio, geodesic_to_euclid_ratio)
):
return False, 0
if sim.island_radius(s) < ISLAND_RADIUS_LIMIT:
return False, 0
return True, d_separation
def _create_episode(
episode_id,
scene_id,
start_position,
start_rotation,
target_position,
shortest_paths=None,
radius=None,
info=None,
) -> Optional[NavigationEpisode]:
goals = [NavigationGoal(position=target_position, radius=radius)]
return NavigationEpisode(
episode_id=str(episode_id),
goals=goals,
scene_id=scene_id,
start_position=start_position,
start_rotation=start_rotation,
shortest_paths=shortest_paths,
info=info,
)
def generate_pointnav_episode(
sim: Simulator,
num_episodes: int = -1,
is_gen_shortest_path: bool = True,
shortest_path_success_distance: float = 0.2,
shortest_path_max_steps: int = 500,
closest_dist_limit: float = 1,
furthest_dist_limit: float = 30,
geodesic_to_euclid_min_ratio: float = 1.1,
number_retries_per_target: int = 10,
) -> NavigationEpisode:
"""
Generator function that generates PointGoal navigation episodes.
An episode is trivial if there is an obstacle-free, straight line between
the start and goal positions. A good measure of the navigation
complexity of an episode is the ratio of
geodesic shortest path position to Euclidean distance between start and
goal positions to the corresponding Euclidean distance.
If the ratio is nearly 1, it indicates there are few obstacles, and the
episode is easy; if the ratio is larger than 1, the
episode is difficult because strategic navigation is required.
To keep the navigation complexity of the precomputed episodes reasonably
high, we perform aggressive rejection sampling for episodes with the above
ratio falling in the range [1, 1.1].
Following this, there is a significant decrease in the number of
straight-line episodes.
:param sim: simulator with loaded scene for generation.
:param num_episodes: number of episodes needed to generate
:param is_gen_shortest_path: option to generate shortest paths
:param shortest_path_success_distance: success distance when agent should
stop during shortest path generation
:param shortest_path_max_steps maximum number of steps shortest path
expected to be
:param closest_dist_limit episode geodesic distance lowest limit
:param furthest_dist_limit episode geodesic distance highest limit
:param geodesic_to_euclid_min_ratio geodesic shortest path to Euclid
distance ratio upper limit till aggressive sampling is applied.
:return: navigation episode that satisfy specified distribution for
currently loaded into simulator scene.
"""
episode_count = 0
while episode_count < num_episodes or num_episodes < 0:
target_position = sim.sample_navigable_point()
if sim.island_radius(target_position) < ISLAND_RADIUS_LIMIT:
continue
for retry in range(number_retries_per_target):
source_position = sim.sample_navigable_point()
is_compatible, dist = is_compatible_episode(
source_position,
target_position,
sim,
near_dist=closest_dist_limit,
far_dist=furthest_dist_limit,
geodesic_to_euclid_ratio=geodesic_to_euclid_min_ratio,
)
if is_compatible:
angle = np.random.uniform(0, 2 * np.pi)
source_rotation = [0, np.sin(angle / 2), 0, np.cos(angle / 2)]
shortest_paths = None
if is_gen_shortest_path:
shortest_paths = [
get_action_shortest_path(
sim,
source_position=source_position,
source_rotation=source_rotation,
goal_position=target_position,
success_distance=shortest_path_success_distance,
max_episode_steps=shortest_path_max_steps,
)
]
episode = _create_episode(
episode_id=episode_count,
scene_id=sim.config.SCENE,
start_position=source_position,
start_rotation=source_rotation,
target_position=target_position,
shortest_paths=shortest_paths,
radius=shortest_path_success_distance,
info={"geodesic_distance": dist},
)
episode_count += 1
yield episode
from typing import List
from habitat.core.logging import logger
from habitat.core.simulator import ShortestPathPoint
from habitat.sims.habitat_simulator import SimulatorActions
from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower
from habitat.utils.geometry_utils import quaternion_to_list
def get_action_shortest_path(
sim,
source_position,
source_rotation,
goal_position,
success_distance=0.05,
max_episode_steps=500,
shortest_path_mode="greedy",
) -> List[ShortestPathPoint]:
sim.reset()
sim.set_agent_state(source_position, source_rotation)
follower = ShortestPathFollower(sim, success_distance, False)
follower.mode = shortest_path_mode
shortest_path = []
action = None
step_count = 0
while action != SimulatorActions.STOP and step_count < max_episode_steps:
action = follower.get_next_action(goal_position)
state = sim.get_agent_state()
shortest_path.append(
ShortestPathPoint(
state.position.tolist(),
quaternion_to_list(state.rotation),
action.value,
)
)
sim.step(action.value)
step_count += 1
if step_count == max_episode_steps:
logger.warning("Shortest path wasn't found.")
return shortest_path
......@@ -317,7 +317,8 @@ class HabitatSim(habitat.Simulator):
will be None.
"""
raise NotImplementedError(
"This function is no longer implemented. Please use the greedy follower instead"
"This function is no longer implemented. Please use the greedy "
"follower instead"
)
@property
......@@ -421,10 +422,11 @@ class HabitatSim(habitat.Simulator):
state.position = position
state.rotation = rotation
# NB: The agent state also contains the sensor states in _absolute_ coordinates.
# In order to set the agent's body to a specific location and have the sensors follow,
# we must not provide any state for the sensors.
# This will cause them to follow the agent's body
# NB: The agent state also contains the sensor states in _absolute_
# coordinates. In order to set the agent's body to a specific
# location and have the sensors follow, we must not provide any
# state for the sensors. This will cause them to follow the agent's
# body
state.sensor_states = dict()
agent.set_state(state, reset_sensors)
......@@ -440,3 +442,6 @@ class HabitatSim(habitat.Simulator):
return self._sim.pathfinder.distance_to_closest_obstacle(
position, max_search_radius
)
def island_radius(self, position):
return self._sim.pathfinder.island_radius(position)
......@@ -20,11 +20,7 @@ from habitat.core.simulator import (
SensorTypes,
SensorSuite,
)
from habitat.tasks.utils import (
quaternion_to_rotation,
cartesian_to_polar,
quaternion_rotate_vector,
)
from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector
from habitat.utils.visualizations import maps
COLLISION_PROXIMITY_TOLERANCE: float = 1e-3
......
......@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import numpy as np
import quaternion
import quaternion # noqa # pylint: disable=unused-import
def quaternion_to_rotation(q_r, q_i, q_j, q_k):
......
......@@ -48,3 +48,9 @@ def quaternion_xyzw_to_wxyz(v: np.array):
def quaternion_wxyz_to_xyzw(v: np.array):
return np.quaternion(*v[1:4], v[0])
def quaternion_to_list(q: np.quaternion):
return quaternion.as_float_array(
quaternion_wxyz_to_xyzw(quaternion.as_float_array(q))
).tolist()
......@@ -108,8 +108,9 @@ def images_to_video(
use at your own risk.
quality: Default is 5. Uses variable bit rate. Highest quality is 10,
lowest is 0. Set to None to prevent variable bitrate flags to
FFMPEG so you can manually specify them using output_params instead.
Specifying a fixed bitrate using ‘bitrate’ disables this parameter.
FFMPEG so you can manually specify them using output_params
instead. Specifying a fixed bitrate using ‘bitrate’ disables
this parameter.
"""
assert 0 <= quality <= 10
if not os.path.exists(output_dir):
......
......@@ -10,7 +10,7 @@ import pytest
from habitat_baselines.agents import simple_agents
try:
import torch
import torch # noqa # pylint: disable=unused-import
has_torch = True
except ImportError:
......
......@@ -5,11 +5,14 @@
# LICENSE file in the root directory of this source tree.
import os
import random
import time
import numpy as np
import pytest
import habitat
import habitat.datasets.pointnav.pointnav_generator as pointnav_generator
from habitat.config.default import get_config
from habitat.core.embodied_task import Episode
from habitat.core.logging import logger
......@@ -18,9 +21,12 @@ from habitat.datasets.pointnav.pointnav_dataset import (
PointNavDatasetV1,
DEFAULT_SCENE_PATH_PREFIX,
)
from habitat.utils.geometry_utils import quaternion_xyzw_to_wxyz
CFG_TEST = "configs/datasets/pointnav/gibson.yaml"
CFG_TEST = "configs/test/habitat_all_sensors_test.yaml"
CFG_MULTI_TEST = "configs/datasets/pointnav/gibson.yaml"
PARTIAL_LOAD_SCENES = 3
NUM_EPISODES = 10
def check_json_serializaiton(dataset: habitat.Dataset):
......@@ -56,7 +62,7 @@ def test_single_pointnav_dataset():
def test_multiple_files_scene_path():
dataset_config = get_config(CFG_TEST).DATASET
dataset_config = get_config(CFG_MULTI_TEST).DATASET
if not PointNavDatasetV1.check_config_paths_exist(dataset_config):
pytest.skip("Test skipped as dataset files are missing.")
scenes = PointNavDatasetV1.get_scenes_to_load(config=dataset_config)
......@@ -84,7 +90,7 @@ def test_multiple_files_scene_path():
def test_multiple_files_pointnav_dataset():
dataset_config = get_config(CFG_TEST).DATASET
dataset_config = get_config(CFG_MULTI_TEST).DATASET
if not PointNavDatasetV1.check_config_paths_exist(dataset_config):
pytest.skip("Test skipped as dataset files are missing.")
scenes = PointNavDatasetV1.get_scenes_to_load(config=dataset_config)
......@@ -101,3 +107,67 @@ def test_multiple_files_pointnav_dataset():
len(partial_dataset.scene_ids) == PARTIAL_LOAD_SCENES
), "Number of loaded scenes doesn't correspond."
check_json_serializaiton(partial_dataset)
def check_shortest_path(env, episode):
def check_state(agent_state, position, rotation):
assert np.allclose(
agent_state.rotation, quaternion_xyzw_to_wxyz(rotation)
), "Agent's rotation diverges from the shortest path."
assert np.allclose(
agent_state.position, position
), "Agent's position position diverges from the shortest path's one."
assert len(episode.goals) == 1, "Episode has no goals or more than one."
assert (
len(episode.shortest_paths) == 1
), "Episode has no shortest paths or more than one."
env.episodes = [episode]
env.reset()
start_state = env.sim.get_agent_state()
check_state(start_state, episode.start_position, episode.start_rotation)
for step_id, point in enumerate(episode.shortest_paths[0]):
cur_state = env.sim.get_agent_state()
check_state(cur_state, point.position, point.rotation)
env.step(point.action)
def test_pointnav_episode_generator():
config = get_config(CFG_TEST)
config.defrost()
config.DATASET.SPLIT = "val"
config.ENVIRONMENT.MAX_EPISODE_STEPS = 500
config.freeze()
env = habitat.Env(config)
env.seed(config.SEED)
random.seed(config.SEED)
generator = pointnav_generator.generate_pointnav_episode(
sim=env.sim,
shortest_path_success_distance=config.TASK.SUCCESS_DISTANCE,
shortest_path_max_steps=config.ENVIRONMENT.MAX_EPISODE_STEPS,
)
episodes = []
for i in range(NUM_EPISODES):
episode = next(generator)
episodes.append(episode)
for episode in pointnav_generator.generate_pointnav_episode(
sim=env.sim,
num_episodes=NUM_EPISODES,
shortest_path_success_distance=config.TASK.SUCCESS_DISTANCE,
shortest_path_max_steps=config.ENVIRONMENT.MAX_EPISODE_STEPS,
geodesic_to_euclid_min_ratio=0,
):
episodes.append(episode)
assert len(episodes) == 2 * NUM_EPISODES
env.episodes = episodes
for episode in episodes:
check_shortest_path(env, episode)
dataset = habitat.Dataset()
dataset.episodes = episodes
assert dataset.to_json(), "Generated episodes aren't json serializable."
......@@ -148,8 +148,8 @@ def test_collisions():
< 0.9 * config.SIMULATOR.FORWARD_STEP_SIZE
and action == actions[0]
):
# Check to see if the new method of doing collisions catches all the same
# collisions as the old method
# Check to see if the new method of doing collisions catches
# all the same collisions as the old method
assert collisions == prev_collisions + 1
prev_loc = loc
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment