diff --git a/habitat/config/default.py b/habitat/config/default.py index 8f90919ae65c29ca1fb145f2422edad36d308d91..68634e3e44196792d580ea3a1b03b3bf0a9b2c91 100644 --- a/habitat/config/default.py +++ b/habitat/config/default.py @@ -71,6 +71,7 @@ _C.TASK.TOP_DOWN_MAP.NUM_TOPDOWN_MAP_SAMPLE_POINTS = 20000 _C.TASK.TOP_DOWN_MAP.MAP_RESOLUTION = 1250 _C.TASK.TOP_DOWN_MAP.DRAW_SOURCE_AND_TARGET = True _C.TASK.TOP_DOWN_MAP.DRAW_BORDER = True +_C.TASK.TOP_DOWN_MAP.DRAW_SHORTEST_PATH = True # ----------------------------------------------------------------------------- # # COLLISIONS MEASUREMENT # ----------------------------------------------------------------------------- diff --git a/habitat/core/env.py b/habitat/core/env.py index eaf72b43164ad581b10ce1995fbc8fd1150dc028..721adb34c5429e0cb6c2019ac65e22e721402787 100644 --- a/habitat/core/env.py +++ b/habitat/core/env.py @@ -287,6 +287,10 @@ class RLEnv(gym.Env): def episodes(self) -> List[Type[Episode]]: return self._env.episodes + @property + def current_episode(self) -> Type[Episode]: + return self._env.current_episode + @episodes.setter def episodes(self, episodes: List[Type[Episode]]) -> None: self._env.episodes = episodes diff --git a/habitat/core/vector_env.py b/habitat/core/vector_env.py index be8c735540f32b17707e8c884d07513c1afe46f7..29a4c23c6de8b17cd2fd3665baff3c89fac94cda 100644 --- a/habitat/core/vector_env.py +++ b/habitat/core/vector_env.py @@ -27,6 +27,7 @@ CLOSE_COMMAND = "close" OBSERVATION_SPACE_COMMAND = "observation_space" ACTION_SPACE_COMMAND = "action_space" CALL_COMMAND = "call" +EPISODE_COMMAND = "current_episode" def _make_env_fn( @@ -186,6 +187,10 @@ class VectorEnv: else: result = getattr(env, function_name)(*function_args) connection_write_fn(result) + + # TODO: update CALL_COMMAND for getting attribute like this + elif command == EPISODE_COMMAND: + connection_write_fn(env.current_episode) else: raise NotImplementedError @@ -231,6 +236,16 @@ class VectorEnv: [p.send for p in parent_connections], ) + def current_episodes(self): + self._is_waiting = True + for write_fn in self._connection_write_fns: + write_fn((EPISODE_COMMAND, None)) + results = [] + for read_fn in self._connection_read_fns: + results.append(read_fn()) + self._is_waiting = False + return results + def reset(self): r"""Reset all the vectorized environments diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index 16a3d5d107162cc6f793254daf641bf99b3ec0ab..bc9f9aeef45449a5157761bb8444df6477588aeb 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -393,10 +393,11 @@ class Collisions(Measure): def update_metric(self, episode, action): if self._metric is None: - self._metric = 0 - + self._metric = {"count": 0, "is_collision": False} + self._metric["is_collision"] = False if self._sim.previous_step_collided: - self._metric += 1 + self._metric["count"] += 1 + self._metric["is_collision"] = True @registry.register_measure @@ -419,9 +420,13 @@ class TopDownMap(Measure): self._coordinate_min = maps.COORDINATE_MIN self._coordinate_max = maps.COORDINATE_MAX self._top_down_map = None + self._shortest_path_points = None self._cell_scale = ( self._coordinate_max - self._coordinate_min ) / self._map_resolution[0] + self.line_thickness = int( + np.round(self._map_resolution[0] * 2 / MAP_THICKNESS_SCALAR) + ) super().__init__() def _get_uuid(self, *args: Any, **kwargs: Any): @@ -430,7 +435,7 @@ class TopDownMap(Measure): def _check_valid_nav_point(self, point: List[float]): self._sim.is_navigable(point) - def 
get_original_map(self, episode): + def get_original_map(self): top_down_map = maps.get_topdown_map( self._sim, self._map_resolution, @@ -445,43 +450,42 @@ class TopDownMap(Measure): self._ind_x_max = range_x[-1] self._ind_y_min = range_y[0] self._ind_y_max = range_y[-1] - - if self._config.DRAW_SOURCE_AND_TARGET: - # mark source point - s_x, s_y = maps.to_grid( - episode.start_position[0], - episode.start_position[2], - self._coordinate_min, - self._coordinate_max, - self._map_resolution, - ) - point_padding = 2 * int( - np.ceil(self._map_resolution[0] / MAP_THICKNESS_SCALAR) - ) - top_down_map[ - s_x - point_padding : s_x + point_padding + 1, - s_y - point_padding : s_y + point_padding + 1, - ] = maps.MAP_SOURCE_POINT_INDICATOR - - # mark target point - t_x, t_y = maps.to_grid( - episode.goals[0].position[0], - episode.goals[0].position[2], - self._coordinate_min, - self._coordinate_max, - self._map_resolution, - ) - top_down_map[ - t_x - point_padding : t_x + point_padding + 1, - t_y - point_padding : t_y + point_padding + 1, - ] = maps.MAP_TARGET_POINT_INDICATOR - return top_down_map + def draw_source_and_target(self, episode): + # mark source point + s_x, s_y = maps.to_grid( + episode.start_position[0], + episode.start_position[2], + self._coordinate_min, + self._coordinate_max, + self._map_resolution, + ) + point_padding = 2 * int( + np.ceil(self._map_resolution[0] / MAP_THICKNESS_SCALAR) + ) + self._top_down_map[ + s_x - point_padding : s_x + point_padding + 1, + s_y - point_padding : s_y + point_padding + 1, + ] = maps.MAP_SOURCE_POINT_INDICATOR + + # mark target point + t_x, t_y = maps.to_grid( + episode.goals[0].position[0], + episode.goals[0].position[2], + self._coordinate_min, + self._coordinate_max, + self._map_resolution, + ) + self._top_down_map[ + t_x - point_padding : t_x + point_padding + 1, + t_y - point_padding : t_y + point_padding + 1, + ] = maps.MAP_TARGET_POINT_INDICATOR + def reset_metric(self, episode): self._step_count = 0 self._metric = None - self._top_down_map = self.get_original_map(episode) + self._top_down_map = self.get_original_map() agent_position = self._sim.get_agent_state().position a_x, a_y = maps.to_grid( agent_position[0], @@ -491,6 +495,30 @@ class TopDownMap(Measure): self._map_resolution, ) self._previous_xy_location = (a_y, a_x) + if self._config.DRAW_SHORTEST_PATH: + # draw shortest path + self._shortest_path_points = self._sim.get_straight_shortest_path_points( + agent_position, episode.goals[0].position + ) + self._shortest_path_points = [ + maps.to_grid( + p[0], + p[2], + self._coordinate_min, + self._coordinate_max, + self._map_resolution, + )[::-1] + for p in self._shortest_path_points + ] + maps.draw_path( + self._top_down_map, + self._shortest_path_points, + maps.MAP_SHORTEST_PATH_COLOR, + self.line_thickness, + ) + # draw source and target points last to avoid overlap + if self._config.DRAW_SOURCE_AND_TARGET: + self.draw_source_and_target(episode) def update_metric(self, episode, action): self._step_count += 1 @@ -515,8 +543,22 @@ class TopDownMap(Measure): map_agent_x - (self._ind_x_min - self._grid_delta), map_agent_y - (self._ind_y_min - self._grid_delta), ), + "agent_angle": self.get_polar_angle(), } + def get_polar_angle(self): + agent_state = self._sim.get_agent_state() + # quaternion is in x, y, z, w format + ref_rotation = agent_state.rotation + + heading_vector = quaternion_rotate_vector( + ref_rotation.inverse(), np.array([0, 0, -1]) + ) + + phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1] + x_y_flip = 
-np.pi / 2 + return np.array(phi) + x_y_flip + def update_map(self, agent_position): a_x, a_y = maps.to_grid( agent_position[0], diff --git a/habitat/utils/visualizations/maps.py b/habitat/utils/visualizations/maps.py index f8dbb171becd64c83b2542cf822b9b4be8f78195..9f34c22b20f14db3463c17d8fb5e54a6d5f4a11c 100644 --- a/habitat/utils/visualizations/maps.py +++ b/habitat/utils/visualizations/maps.py @@ -33,7 +33,7 @@ MAP_VALID_POINT = 1 MAP_BORDER_INDICATOR = 2 MAP_SOURCE_POINT_INDICATOR = 4 MAP_TARGET_POINT_INDICATOR = 6 - +MAP_SHORTEST_PATH_COLOR = 7 TOP_DOWN_MAP_COLORS = np.full((256, 3), 150, dtype=np.uint8) TOP_DOWN_MAP_COLORS[10:] = cv2.applyColorMap( np.arange(246, dtype=np.uint8), cv2.COLORMAP_JET @@ -43,6 +43,7 @@ TOP_DOWN_MAP_COLORS[MAP_VALID_POINT] = [150, 150, 150] TOP_DOWN_MAP_COLORS[MAP_BORDER_INDICATOR] = [50, 50, 50] TOP_DOWN_MAP_COLORS[MAP_SOURCE_POINT_INDICATOR] = [0, 0, 200] TOP_DOWN_MAP_COLORS[MAP_TARGET_POINT_INDICATOR] = [200, 0, 0] +TOP_DOWN_MAP_COLORS[MAP_SHORTEST_PATH_COLOR] = [0, 200, 0] def draw_agent( @@ -334,3 +335,20 @@ def colorize_topdown_map(top_down_map: np.ndarray) -> np.ndarray: A colored version of the top-down map. """ return TOP_DOWN_MAP_COLORS[top_down_map] + + +def draw_path( + top_down_map: np.ndarray, + path_points: List[Tuple], + color: int, + thickness: int = 2, +) -> None: + r"""Draw path on top_down_map (in place) with specified color. + Args: + top_down_map: the top-down map to draw the path on. + path_points: list of points that specify the path to be drawn. + color: color code of the path, used as an index into TOP_DOWN_MAP_COLORS. + thickness: thickness of the path. + """ + for prev_pt, next_pt in zip(path_points[:-1], path_points[1:]): + cv2.line(top_down_map, prev_pt, next_pt, color, thickness=thickness) diff --git a/habitat/utils/visualizations/utils.py b/habitat/utils/visualizations/utils.py index a051df99fe9caca27829bc5678961d5cb570eeff..48ffc02ba396c366686bc56ae30ad288c38073d2 100644 --- a/habitat/utils/visualizations/utils.py +++ b/habitat/utils/visualizations/utils.py @@ -5,12 +5,15 @@ # LICENSE file in the root directory of this source tree. import os -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple +import cv2 import imageio import numpy as np import tqdm +from habitat.utils.visualizations import maps + def paste_overlapping_image( background: np.ndarray, @@ -103,7 +106,7 @@ def images_to_video( Args: images: The list of images. Images should be HxWx3 in RGB order. output_dir: The folder to put the video in. - video_name: The navme for the video. + video_name: The name for the video. fps: Frames per second for the video. Not all values work with FFMPEG, use at your own risk. quality: Default is 5. Uses variable bit rate. Highest quality is 10, @@ -125,3 +128,75 @@ def images_to_video( for im in tqdm.tqdm(images): writer.append_data(im) writer.close() + + +def draw_collision(view: np.ndarray, alpha: float = 0.4) -> np.ndarray: + r"""Draw translucent red strips on the border of input view to indicate + a collision has taken place. + Args: + view: input view of size HxWx3 in RGB order. + alpha: Opacity of red collision strip. 1 is completely non-transparent. + Returns: + A view with collision effect drawn.
+ """ + size = view.shape[0] + strip_width = size // 20 + mask = np.ones((size, size)) + mask[strip_width:-strip_width, strip_width:-strip_width] = 0 + mask = mask == 1 + view[mask] = (alpha * np.array([255, 0, 0]) + (1.0 - alpha) * view)[mask] + return view + + +def observations_to_image(observation: Dict, info: Dict) -> np.ndarray: + r"""Generate image of single frame from observation and info + returned from a single environment step(). + + Args: + observation: observation returned from an environment step(). + info: info returned from an environment step(). + + Returns: + generated image of a single frame. + """ + observation_size = observation["rgb"].shape[0] + egocentric_view = observation["rgb"][:, :, :3] + # draw collision + if "collisions" in info and info["collisions"]["is_collision"]: + egocentric_view = draw_collision(egocentric_view) + + # draw depth map if observation has depth info + if "depth" in observation: + depth_map = (observation["depth"].squeeze() * 255).astype(np.uint8) + depth_map = np.stack([depth_map for _ in range(3)], axis=2) + + egocentric_view = np.concatenate((egocentric_view, depth_map), axis=1) + + frame = egocentric_view + + if "top_down_map" in info: + top_down_map = info["top_down_map"]["map"] + top_down_map = maps.colorize_topdown_map(top_down_map) + map_agent_pos = info["top_down_map"]["agent_map_coord"] + top_down_map = maps.draw_agent( + image=top_down_map, + agent_center_coord=map_agent_pos, + agent_rotation=info["top_down_map"]["agent_angle"], + agent_radius_px=top_down_map.shape[0] // 16, + ) + + if top_down_map.shape[0] > top_down_map.shape[1]: + top_down_map = np.rot90(top_down_map, 1) + + # scale top down map to align with rgb view + old_h, old_w, _ = top_down_map.shape + top_down_height = observation_size + top_down_width = int(float(top_down_height) / old_h * old_w) + # cv2 resize (dsize is width first) + top_down_map = cv2.resize( + top_down_map, + (top_down_width, top_down_height), + interpolation=cv2.INTER_CUBIC, + ) + frame = np.concatenate((egocentric_view, top_down_map), axis=1) + return frame diff --git a/habitat_baselines/README.md b/habitat_baselines/README.md index f3150a210a034a5fb4e48116cb10cd287bc8430c..9a29cf4d955dd841ce70033af09bf0d8bf18eb03 100644 --- a/habitat_baselines/README.md +++ b/habitat_baselines/README.md @@ -48,12 +48,6 @@ python -u habitat_baselines/train_ppo.py \ ``` -**single-episode training**: -Algorithms can be trained with a single-episode option. This option can be used as a sanity check since good algorithms should overfit one episode relatively fast. To enable this option, add `DATASET.NUM_EPISODE_SAMPLE 1` *at the end* of the training command, or include the single-episode yaml file in `--task-config` like this: -``` - --task-config "configs/tasks/pointnav.yaml,configs/datasets/single_episode.yaml" -``` - **test**: ```bash python -u habitat_baselines/evaluate_ppo.py \ @@ -78,3 +72,28 @@ Set argument `--task-config` to `tasks/pointnav_mp3d.yaml` for training on [Matt - [Handcrafted agent baseline](slambased/README.md) adopted from the paper "Benchmarking Classic and Learned Navigation in Complex 3D Environments". +### Additional Utilities + +**single-episode training**: +Algorithms can be trained with a single-episode option. This option can be used as a sanity check since good algorithms should overfit one episode relatively fast. 
To enable this option, add `DATASET.NUM_EPISODE_SAMPLE 1` *at the end* of the training command, or include the single-episode yaml file in `--task-config` like this: +``` + --task-config "configs/tasks/pointnav.yaml,configs/datasets/single_episode.yaml" +``` + +**tensorboard and video generation support** + +Enable tensorboard logging by adding `--tensorboard-dir logdir` when running `train_ppo.py` or `evaluate_ppo.py`. + +Enable video generation for `evaluate_ppo.py` using `--video-option`: specify `tensorboard` (for displaying on tensorboard), `disk` (for saving videos on disk), or both, for example: +``` +python -u habitat_baselines/evaluate_ppo.py +... +--count-test-episodes 2 \ +--video-option tensorboard \ +--tensorboard-dir tb_eval \ +--model-path data/checkpoints/ckpt.xx.pth +``` +The above command should generate navigation episode recordings and display them on tensorboard like this: +<p align="center"> + <img src="../res/img/tensorboard_video_demo.gif" height="500"> +</p> diff --git a/habitat_baselines/evaluate_ppo.py b/habitat_baselines/evaluate_ppo.py index 0757f61483308123399e9223d47f2593d5f11e62..dddca866ae7aa72ef04be6d5ec1317a6d7877c1a 100644 --- a/habitat_baselines/evaluate_ppo.py +++ b/habitat_baselines/evaluate_ppo.py @@ -5,45 +5,83 @@ # LICENSE file in the root directory of this source tree. import argparse +import glob +import os +import time +from typing import Optional import torch import habitat from config.default import get_config as cfg_baseline +from habitat import logger from habitat.config.default import get_config +from habitat.utils.visualizations.utils import ( + images_to_video, + observations_to_image, +) from rl.ppo import PPO, Policy from rl.ppo.utils import batch_obs +from tensorboard_utils import get_tensorboard_writer from train_ppo import make_env_fn -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-path", type=str, required=True) - parser.add_argument("--sim-gpu-id", type=int, required=True) - parser.add_argument("--pth-gpu-id", type=int, required=True) - parser.add_argument("--num-processes", type=int, required=True) - parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--count-test-episodes", type=int, default=100) - parser.add_argument( - "--sensors", - type=str, - default="RGB_SENSOR,DEPTH_SENSOR", - help="comma separated string containing different" - "sensors to use, currently 'RGB_SENSOR' and" - "'DEPTH_SENSOR' are supported", - ) - parser.add_argument( - "--task-config", - type=str, - default="configs/tasks/pointnav.yaml", - help="path to config yaml containing information about task", +def poll_checkpoint_folder( + checkpoint_folder: str, previous_ckpt_ind: int +) -> Optional[str]: + r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder + (sorted by time of last modification). + + Args: + checkpoint_folder: directory to look for checkpoints. + previous_ckpt_ind: index of checkpoint last returned. + + Returns: + return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found + else return None.
+ """ + assert os.path.isdir(checkpoint_folder), "invalid checkpoint folder path" + models_paths = list( + filter(os.path.isfile, glob.glob(checkpoint_folder + "/*")) ) - args = parser.parse_args() + models_paths.sort(key=os.path.getmtime) + ind = previous_ckpt_ind + 1 + if ind < len(models_paths): + return models_paths[ind] + return None + + +def generate_video( + args, images, episode_id, checkpoint_idx, spl, tb_writer, fps=10 +) -> None: + r"""Generate video according to specified information. - device = torch.device("cuda:{}".format(args.pth_gpu_id)) + Args: + args: contains args.video_option and args.video_dir. + images: list of images to be converted to video. + episode_id: episode id for video naming. + checkpoint_idx: checkpoint index for video naming. + spl: SPL for this episode for video naming. + tb_writer: tensorboard writer object for uploading video + fps: fps for generated video + Returns: + None + """ + if args.video_option and len(images) > 0: + video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}" + if "disk" in args.video_option: + images_to_video(images, args.video_dir, video_name) + if "tensorboard" in args.video_option: + tb_writer.add_video_from_np_images( + f"episode{episode_id}", checkpoint_idx, images, fps=fps + ) + + +def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0): env_configs = [] baseline_configs = [] + device = torch.device("cuda", args.pth_gpu_id) for _ in range(args.num_processes): config_env = get_config(config_paths=args.task_config) @@ -54,6 +92,9 @@ def main(): for sensor in agent_sensors: assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"] config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors + if args.video_option: + config_env.TASK.MEASUREMENTS.append("TOP_DOWN_MAP") + config_env.TASK.MEASUREMENTS.append("COLLISIONS") config_env.freeze() env_configs.append(config_env) @@ -71,7 +112,7 @@ def main(): ), ) - ckpt = torch.load(args.model_path, map_location=device) + ckpt = torch.load(checkpoint_path, map_location=device) actor_critic = Policy( observation_space=envs.observation_spaces[0], @@ -112,8 +153,16 @@ def main(): args.num_processes, args.hidden_size, device=device ) not_done_masks = torch.zeros(args.num_processes, 1, device=device) + stats_episodes = set() + + rgb_frames = None + if args.video_option: + rgb_frames = [[]] * args.num_processes + os.makedirs(args.video_dir, exist_ok=True) while episode_counts.sum() < args.count_test_episodes: + current_episodes = envs.current_episodes() + with torch.no_grad(): _, actions, _, test_recurrent_hidden_states = actor_critic.act( batch, @@ -149,13 +198,151 @@ def main(): episode_counts += 1 - not_done_masks current_episode_reward *= not_done_masks + next_episodes = envs.current_episodes() + envs_to_pause = [] + n_envs = envs.num_envs + for i in range(n_envs): + if next_episodes[i].episode_id in stats_episodes: + envs_to_pause.append(i) + + # episode ended + if not_done_masks[i].item() == 0: + stats_episodes.add(current_episodes[i].episode_id) + if args.video_option: + generate_video( + args, + rgb_frames[i], + current_episodes[i].episode_id, + cur_ckpt_idx, + infos[i]["spl"], + writer, + ) + rgb_frames[i] = [] + + # episode continues + elif args.video_option: + frame = observations_to_image(observations[i], infos[i]) + rgb_frames[i].append(frame) + + # stop tracking ended episodes if they exist + if len(envs_to_pause) > 0: + state_index = list(range(envs.num_envs)) + for idx in reversed(envs_to_pause): + state_index.pop(idx) + envs.pause_at(idx) + + # indexing along the batch 
dimensions + test_recurrent_hidden_states = test_recurrent_hidden_states[ + :, state_index + ] + not_done_masks = not_done_masks[state_index] + current_episode_reward = current_episode_reward[state_index] + + for k, v in batch.items(): + batch[k] = v[state_index] + + if args.video_option: + rgb_frames = [rgb_frames[i] for i in state_index] + episode_reward_mean = (episode_rewards / episode_counts).mean().item() episode_spl_mean = (episode_spls / episode_counts).mean().item() episode_success_mean = (episode_success / episode_counts).mean().item() - print("Average episode reward: {:.6f}".format(episode_reward_mean)) - print("Average episode success: {:.6f}".format(episode_success_mean)) - print("Average episode spl: {:.6f}".format(episode_spl_mean)) + logger.info("Average episode reward: {:.6f}".format(episode_reward_mean)) + logger.info("Average episode success: {:.6f}".format(episode_success_mean)) + logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean)) + + writer.add_scalars( + "eval_reward", {"average reward": episode_reward_mean}, cur_ckpt_idx + ) + writer.add_scalars( + "eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx + ) + writer.add_scalars( + "eval_success", {"average success": episode_success_mean}, cur_ckpt_idx + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str) + parser.add_argument("--tracking-model-dir", type=str) + parser.add_argument("--sim-gpu-id", type=int, required=True) + parser.add_argument("--pth-gpu-id", type=int, required=True) + parser.add_argument("--num-processes", type=int, required=True) + parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--count-test-episodes", type=int, default=100) + parser.add_argument( + "--sensors", + type=str, + default="RGB_SENSOR,DEPTH_SENSOR", + help="comma separated string containing different" + "sensors to use, currently 'RGB_SENSOR' and" + "'DEPTH_SENSOR' are supported", + ) + parser.add_argument( + "--task-config", + type=str, + default="configs/tasks/pointnav.yaml", + help="path to config yaml containing information about task", + ) + parser.add_argument( + "--video-option", + type=str, + default="", + choices=["tensorboard", "disk"], + nargs="*", + help="Options for video output, leave empty for no video. 
" + "Videos can be saved to disk, uploaded to tensorboard, or both.", + ) + parser.add_argument( + "--video-dir", type=str, help="directory for storing videos" + ) + parser.add_argument( + "--tensorboard-dir", + type=str, + help="directory for storing tensorboard statistics", + ) + + args = parser.parse_args() + + assert (args.model_path is not None) != ( + args.tracking_model_dir is not None + ), "Must specify a single model or a directory of models, but not both" + if "tensorboard" in args.video_option: + assert ( + args.tensorboard_dir is not None + ), "Must specify a tensorboard directory for video display" + if "disk" in args.video_option: + assert ( + args.video_dir is not None + ), "Must specify a directory for storing videos on disk" + + with get_tensorboard_writer( + args.tensorboard_dir, purge_step=0, flush_secs=30 + ) as writer: + if args.model_path is not None: + # evaluate singe checkpoint + eval_checkpoint(args.model_path, args, writer) + else: + # evaluate multiple checkpoints in order + prev_ckpt_ind = -1 + while True: + current_ckpt = None + while current_ckpt is None: + current_ckpt = poll_checkpoint_folder( + args.tracking_model_dir, prev_ckpt_ind + ) + time.sleep(2) # sleep for 2 seconds before polling again + logger.warning( + "=============current_ckpt: {}=============".format( + current_ckpt + ) + ) + prev_ckpt_ind += 1 + eval_checkpoint( + current_ckpt, args, writer, cur_ckpt_idx=prev_ckpt_ind + ) if __name__ == "__main__": diff --git a/habitat_baselines/rl/ppo/utils.py b/habitat_baselines/rl/ppo/utils.py index 0fed7445359b06b09d3491ba2d99451da92aadf1..8cd0e87095ac35eaea370277b01b38a8c40551bf 100644 --- a/habitat_baselines/rl/ppo/utils.py +++ b/habitat_baselines/rl/ppo/utils.py @@ -422,4 +422,9 @@ def ppo_args(): nargs=argparse.REMAINDER, help="Modify config options from command line", ) + parser.add_argument( + "--tensorboard-dir", + type=str, + help="path to tensorboard logging directory", + ) return parser diff --git a/habitat_baselines/rl/requirements.txt b/habitat_baselines/rl/requirements.txt index 5340339a18439ebc99f365caf4821f191143ae4a..1f31a714d5a20cec3a7a9adcb58d7c8d75372df9 100644 --- a/habitat_baselines/rl/requirements.txt +++ b/habitat_baselines/rl/requirements.txt @@ -1 +1,4 @@ torch==1.1.0 +# full tensorflow required for tensorboard video support +tensorflow==1.13.1 +tb-nightly diff --git a/habitat_baselines/tensorboard_utils.py b/habitat_baselines/tensorboard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7449dd64be0c06897394f687cce8afd3edea9c56 --- /dev/null +++ b/habitat_baselines/tensorboard_utils.py @@ -0,0 +1,70 @@ +from typing import Optional, Union + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + + +class TensorboardWriter(SummaryWriter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def add_video_from_np_images( + self, video_name: str, step_idx: int, images: np.ndarray, fps: int = 10 + ) -> None: + r"""Write video into tensorboard from images frames. + + Args: + video_name: name of video string. + step_idx: int of checkpoint index to be displayed. + images: list of n frames. Each frame is a np.ndarray of shape. + fps: frame per second for output video. + + Returns: + None. 
+ """ + # initial shape of np.ndarray list: N * (H, W, 3) + frame_tensors = [ + torch.from_numpy(np_arr).unsqueeze(0) for np_arr in images + ] + video_tensor = torch.cat(tuple(frame_tensors)) + video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0) + # final shape of video tensor: (1, n, 3, H, W) + self.add_video(video_name, video_tensor, fps=fps, global_step=step_idx) + + +class DummyWriter: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def close(self): + pass + + def __getattr__(self, item): + return lambda *args, **kwargs: None + + +def get_tensorboard_writer( + log_dir: str, *args, **kwargs +) -> Union[DummyWriter, TensorboardWriter]: + r"""Get tensorboard writer if log_dir is specified, otherwise, + return dummy writer instead. + + Args: + log_dir: log directory path for tensorboard SummaryWriter. + *args: additional positional args. + **kwargs: additional keyword args. + + Returns: + Either the created tensorboard writer or a dummy writer. + """ + if log_dir: + return TensorboardWriter(log_dir, *args, **kwargs) + else: + return DummyWriter() diff --git a/habitat_baselines/train_ppo.py b/habitat_baselines/train_ppo.py index a991c17467a60e4402626262a2f04990c1560fbf..733d78b66b64830931b1fc5deae2a3973e251fa5 100644 --- a/habitat_baselines/train_ppo.py +++ b/habitat_baselines/train_ppo.py @@ -19,6 +19,7 @@ from habitat.config.default import get_config as cfg_env from habitat.datasets.registration import make_dataset from rl.ppo import PPO, Policy, RolloutStorage from rl.ppo.utils import batch_obs, ppo_args, update_linear_schedule +from tensorboard_utils import get_tensorboard_writer class NavRLEnv(habitat.RLEnv): @@ -162,7 +163,7 @@ def run_training(): random.seed(args.seed) - device = torch.device("cuda:{}".format(args.pth_gpu_id)) + device = torch.device("cuda", args.pth_gpu_id) logger.add_filehandler(args.log_file) @@ -218,8 +219,8 @@ def run_training(): episode_rewards = torch.zeros(envs.num_envs, 1) episode_counts = torch.zeros(envs.num_envs, 1) current_episode_reward = torch.zeros(envs.num_envs, 1) - window_episode_reward = deque() - window_episode_counts = deque() + window_episode_reward = deque(maxlen=args.reward_window_size) + window_episode_counts = deque(maxlen=args.reward_window_size) t_start = time() env_time = 0 @@ -227,137 +228,169 @@ def run_training(): count_steps = 0 count_checkpoints = 0 - for update in range(args.num_updates): - if args.use_linear_lr_decay: - update_linear_schedule( - agent.optimizer, update, args.num_updates, args.lr - ) + with ( + get_tensorboard_writer( + log_dir=args.tensorboard_dir, purge_step=count_steps, flush_secs=30 + ) + ) as writer: + for update in range(args.num_updates): + if args.use_linear_lr_decay: + update_linear_schedule( + agent.optimizer, update, args.num_updates, args.lr + ) - agent.clip_param = args.clip_param * (1 - update / args.num_updates) + agent.clip_param = args.clip_param * ( + 1 - update / args.num_updates + ) - for step in range(args.num_steps): - t_sample_action = time() - # sample actions - with torch.no_grad(): - step_observation = { - k: v[step] for k, v in rollouts.observations.items() - } + for step in range(args.num_steps): + t_sample_action = time() + # sample actions + with torch.no_grad(): + step_observation = { + k: v[step] for k, v in rollouts.observations.items() + } + + ( + values, + actions, + actions_log_probs, + recurrent_hidden_states, + ) = actor_critic.act( + step_observation, + 
rollouts.recurrent_hidden_states[step], + rollouts.masks[step], + ) + pth_time += time() - t_sample_action - ( - values, - actions, - actions_log_probs, - recurrent_hidden_states, - ) = actor_critic.act( - step_observation, - rollouts.recurrent_hidden_states[step], - rollouts.masks[step], - ) - pth_time += time() - t_sample_action + t_step_env = time() - t_step_env = time() + outputs = envs.step([a[0].item() for a in actions]) + observations, rewards, dones, infos = [ + list(x) for x in zip(*outputs) + ] - outputs = envs.step([a[0].item() for a in actions]) - observations, rewards, dones, infos = [ - list(x) for x in zip(*outputs) - ] + env_time += time() - t_step_env - env_time += time() - t_step_env + t_update_stats = time() + batch = batch_obs(observations) + rewards = torch.tensor(rewards, dtype=torch.float) + rewards = rewards.unsqueeze(1) - t_update_stats = time() - batch = batch_obs(observations) - rewards = torch.tensor(rewards, dtype=torch.float) - rewards = rewards.unsqueeze(1) + masks = torch.tensor( + [[0.0] if done else [1.0] for done in dones], + dtype=torch.float, + ) - masks = torch.tensor( - [[0.0] if done else [1.0] for done in dones], dtype=torch.float - ) + current_episode_reward += rewards + episode_rewards += (1 - masks) * current_episode_reward + episode_counts += 1 - masks + current_episode_reward *= masks - current_episode_reward += rewards - episode_rewards += (1 - masks) * current_episode_reward - episode_counts += 1 - masks - current_episode_reward *= masks - - rollouts.insert( - batch, - recurrent_hidden_states, - actions, - actions_log_probs, - values, - rewards, - masks, - ) + rollouts.insert( + batch, + recurrent_hidden_states, + actions, + actions_log_probs, + values, + rewards, + masks, + ) - count_steps += envs.num_envs - pth_time += time() - t_update_stats + count_steps += envs.num_envs + pth_time += time() - t_update_stats - if len(window_episode_reward) == args.reward_window_size: - window_episode_reward.popleft() - window_episode_counts.popleft() - window_episode_reward.append(episode_rewards.clone()) - window_episode_counts.append(episode_counts.clone()) + window_episode_reward.append(episode_rewards.clone()) + window_episode_counts.append(episode_counts.clone()) - t_update_model = time() - with torch.no_grad(): - last_observation = { - k: v[-1] for k, v in rollouts.observations.items() - } - next_value = actor_critic.get_value( - last_observation, - rollouts.recurrent_hidden_states[-1], - rollouts.masks[-1], - ).detach() - - rollouts.compute_returns( - next_value, args.use_gae, args.gamma, args.tau - ) + t_update_model = time() + with torch.no_grad(): + last_observation = { + k: v[-1] for k, v in rollouts.observations.items() + } + next_value = actor_critic.get_value( + last_observation, + rollouts.recurrent_hidden_states[-1], + rollouts.masks[-1], + ).detach() + + rollouts.compute_returns( + next_value, args.use_gae, args.gamma, args.tau + ) - value_loss, action_loss, dist_entropy = agent.update(rollouts) + value_loss, action_loss, dist_entropy = agent.update(rollouts) - rollouts.after_update() - pth_time += time() - t_update_model + rollouts.after_update() + pth_time += time() - t_update_model - # log stats - if update > 0 and update % args.log_interval == 0: - logger.info( - "update: {}\tfps: {:.3f}\t".format( - update, count_steps / (time() - t_start) + losses = [value_loss, action_loss] + stats = zip( + ["count", "reward"], + [window_episode_counts, window_episode_reward], + ) + deltas = { + k: ( + (v[-1] - v[0]).sum().item() + if len(v) > 1 + 
else v[0].sum().item() ) + for k, v in stats + } + deltas["count"] = max(deltas["count"], 1.0) + + writer.add_scalar( + "reward", deltas["reward"] / deltas["count"], count_steps ) - logger.info( - "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" - "frames: {}".format(update, env_time, pth_time, count_steps) + writer.add_scalars( + "losses", + {k: l for l, k in zip(losses, ["value", "policy"])}, + count_steps, ) - window_rewards = ( - window_episode_reward[-1] - window_episode_reward[0] - ).sum() - window_counts = ( - window_episode_counts[-1] - window_episode_counts[0] - ).sum() + # log stats + if update > 0 and update % args.log_interval == 0: + logger.info( + "update: {}\tfps: {:.3f}\t".format( + update, count_steps / (time() - t_start) + ) + ) - if window_counts > 0: logger.info( - "Average window size {} reward: {:3f}".format( - len(window_episode_reward), - (window_rewards / window_counts).item(), + "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" + "frames: {}".format( + update, env_time, pth_time, count_steps ) ) - else: - logger.info("No episodes finish in current window") - - # checkpoint model - if update % args.checkpoint_interval == 0: - checkpoint = {"state_dict": agent.state_dict()} - torch.save( - checkpoint, - os.path.join( - args.checkpoint_folder, - "ckpt.{}.pth".format(count_checkpoints), - ), - ) - count_checkpoints += 1 + + window_rewards = ( + window_episode_reward[-1] - window_episode_reward[0] + ).sum() + window_counts = ( + window_episode_counts[-1] - window_episode_counts[0] + ).sum() + + if window_counts > 0: + logger.info( + "Average window size {} reward: {:3f}".format( + len(window_episode_reward), + (window_rewards / window_counts).item(), + ) + ) + else: + logger.info("No episodes finish in current window") + + # checkpoint model + if update % args.checkpoint_interval == 0: + checkpoint = {"state_dict": agent.state_dict(), "args": args} + torch.save( + checkpoint, + os.path.join( + args.checkpoint_folder, + "ckpt.{}.pth".format(count_checkpoints), + ), + ) + count_checkpoints += 1 if __name__ == "__main__": diff --git a/res/img/tensorboard_video_demo.gif b/res/img/tensorboard_video_demo.gif new file mode 100644 index 0000000000000000000000000000000000000000..e7ed86ffcf602076f3f0e807281277bf5aee13f6 Binary files /dev/null and b/res/img/tensorboard_video_demo.gif differ diff --git a/test/test_sensors.py b/test/test_sensors.py index 5e34e6b5f75a62501fa227e2c150c4e6614fbd43..2b4d5971bbd067df19e3b85ddbce073cb6f73b57 100644 --- a/test/test_sensors.py +++ b/test/test_sensors.py @@ -127,7 +127,7 @@ def test_collisions(): for _ in range(50): action = np.random.choice(actions) env.step(action) - collisions = env.get_metrics()["collisions"] + collisions = env.get_metrics()["collisions"]["count"] loc = env.sim.get_agent_state().position if (