Commit 38be0f53 authored by JasonJiazhiZhang, committed by Oleksandr

Add tensorboard and video generation for ppo train and eval (#127)

* Add checkpoint progress tracking for evaluate_ppo: when a checkpoint directory is specified, evaluate_ppo evaluates checkpoints in chronological order and continuously polls for new checkpoints.
* Add tensorboard visualization to both train_ppo and evaluate_ppo
* Add video generation for navigation episode evaluation. Generated videos can be either saved locally or visualized through tensorboard.
* Add shortest path visualization
parent b397408c
......@@ -71,6 +71,7 @@ _C.TASK.TOP_DOWN_MAP.NUM_TOPDOWN_MAP_SAMPLE_POINTS = 20000
_C.TASK.TOP_DOWN_MAP.MAP_RESOLUTION = 1250
_C.TASK.TOP_DOWN_MAP.DRAW_SOURCE_AND_TARGET = True
_C.TASK.TOP_DOWN_MAP.DRAW_BORDER = True
_C.TASK.TOP_DOWN_MAP.DRAW_SHORTEST_PATH = True
# -----------------------------------------------------------------------------
# # COLLISIONS MEASUREMENT
# -----------------------------------------------------------------------------
......
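For reference, a minimal sketch (assuming the standard `get_config` entry point and the default PointNav task config) of how the new `DRAW_SHORTEST_PATH` flag can be toggled together with the `TOP_DOWN_MAP` measure:

```python
from habitat.config.default import get_config

# Minimal sketch: enable the top-down map measure and the new
# shortest-path overlay for a PointNav task config.
config = get_config(config_paths="configs/tasks/pointnav.yaml")
config.defrost()  # defrost in case the loaded config comes back frozen
config.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
config.TASK.TOP_DOWN_MAP.DRAW_SHORTEST_PATH = True  # flag added in this commit
config.TASK.TOP_DOWN_MAP.DRAW_SOURCE_AND_TARGET = True
config.freeze()
```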
......@@ -287,6 +287,10 @@ class RLEnv(gym.Env):
def episodes(self) -> List[Type[Episode]]:
return self._env.episodes
@property
def current_episode(self) -> Type[Episode]:
return self._env.current_episode
@episodes.setter
def episodes(self, episodes: List[Type[Episode]]) -> None:
self._env.episodes = episodes
......
......@@ -27,6 +27,7 @@ CLOSE_COMMAND = "close"
OBSERVATION_SPACE_COMMAND = "observation_space"
ACTION_SPACE_COMMAND = "action_space"
CALL_COMMAND = "call"
EPISODE_COMMAND = "current_episode"
def _make_env_fn(
......@@ -186,6 +187,10 @@ class VectorEnv:
else:
result = getattr(env, function_name)(*function_args)
connection_write_fn(result)
# TODO: update CALL_COMMAND to support getting attributes like this
elif command == EPISODE_COMMAND:
connection_write_fn(env.current_episode)
else:
raise NotImplementedError
......@@ -231,6 +236,16 @@ class VectorEnv:
[p.send for p in parent_connections],
)
def current_episodes(self):
self._is_waiting = True
for write_fn in self._connection_write_fns:
write_fn((EPISODE_COMMAND, None))
results = []
for read_fn in self._connection_read_fns:
results.append(read_fn())
self._is_waiting = False
return results
def reset(self):
r"""Reset all the vectorized environments
......
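A rough usage sketch for the new `current_episodes()` call, assuming `envs` is an already-constructed `habitat.VectorEnv`; it asks each worker for its active episode without stepping the environments:

```python
# Assumes `envs` is an existing habitat.VectorEnv instance.
current_episodes = envs.current_episodes()  # one Episode per worker
for worker_idx, episode in enumerate(current_episodes):
    print(
        "worker {}: episode_id={} scene={}".format(
            worker_idx, episode.episode_id, episode.scene_id
        )
    )
```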
......@@ -393,10 +393,11 @@ class Collisions(Measure):
def update_metric(self, episode, action):
if self._metric is None:
self._metric = 0
self._metric = {"count": 0, "is_collision": False}
self._metric["is_collision"] = False
if self._sim.previous_step_collided:
self._metric += 1
self._metric["count"] += 1
self._metric["is_collision"] = True
@registry.register_measure
......@@ -419,9 +420,13 @@ class TopDownMap(Measure):
self._coordinate_min = maps.COORDINATE_MIN
self._coordinate_max = maps.COORDINATE_MAX
self._top_down_map = None
self._shortest_path_points = None
self._cell_scale = (
self._coordinate_max - self._coordinate_min
) / self._map_resolution[0]
self.line_thickness = int(
np.round(self._map_resolution[0] * 2 / MAP_THICKNESS_SCALAR)
)
super().__init__()
def _get_uuid(self, *args: Any, **kwargs: Any):
......@@ -430,7 +435,7 @@ class TopDownMap(Measure):
def _check_valid_nav_point(self, point: List[float]):
self._sim.is_navigable(point)
def get_original_map(self, episode):
def get_original_map(self):
top_down_map = maps.get_topdown_map(
self._sim,
self._map_resolution,
......@@ -445,43 +450,42 @@ class TopDownMap(Measure):
self._ind_x_max = range_x[-1]
self._ind_y_min = range_y[0]
self._ind_y_max = range_y[-1]
if self._config.DRAW_SOURCE_AND_TARGET:
# mark source point
s_x, s_y = maps.to_grid(
episode.start_position[0],
episode.start_position[2],
self._coordinate_min,
self._coordinate_max,
self._map_resolution,
)
point_padding = 2 * int(
np.ceil(self._map_resolution[0] / MAP_THICKNESS_SCALAR)
)
top_down_map[
s_x - point_padding : s_x + point_padding + 1,
s_y - point_padding : s_y + point_padding + 1,
] = maps.MAP_SOURCE_POINT_INDICATOR
# mark target point
t_x, t_y = maps.to_grid(
episode.goals[0].position[0],
episode.goals[0].position[2],
self._coordinate_min,
self._coordinate_max,
self._map_resolution,
)
top_down_map[
t_x - point_padding : t_x + point_padding + 1,
t_y - point_padding : t_y + point_padding + 1,
] = maps.MAP_TARGET_POINT_INDICATOR
return top_down_map
def draw_source_and_target(self, episode):
# mark source point
s_x, s_y = maps.to_grid(
episode.start_position[0],
episode.start_position[2],
self._coordinate_min,
self._coordinate_max,
self._map_resolution,
)
point_padding = 2 * int(
np.ceil(self._map_resolution[0] / MAP_THICKNESS_SCALAR)
)
self._top_down_map[
s_x - point_padding : s_x + point_padding + 1,
s_y - point_padding : s_y + point_padding + 1,
] = maps.MAP_SOURCE_POINT_INDICATOR
# mark target point
t_x, t_y = maps.to_grid(
episode.goals[0].position[0],
episode.goals[0].position[2],
self._coordinate_min,
self._coordinate_max,
self._map_resolution,
)
self._top_down_map[
t_x - point_padding : t_x + point_padding + 1,
t_y - point_padding : t_y + point_padding + 1,
] = maps.MAP_TARGET_POINT_INDICATOR
def reset_metric(self, episode):
self._step_count = 0
self._metric = None
self._top_down_map = self.get_original_map(episode)
self._top_down_map = self.get_original_map()
agent_position = self._sim.get_agent_state().position
a_x, a_y = maps.to_grid(
agent_position[0],
......@@ -491,6 +495,30 @@ class TopDownMap(Measure):
self._map_resolution,
)
self._previous_xy_location = (a_y, a_x)
if self._config.DRAW_SHORTEST_PATH:
# draw shortest path
self._shortest_path_points = self._sim.get_straight_shortest_path_points(
agent_position, episode.goals[0].position
)
self._shortest_path_points = [
maps.to_grid(
p[0],
p[2],
self._coordinate_min,
self._coordinate_max,
self._map_resolution,
)[::-1]
for p in self._shortest_path_points
]
maps.draw_path(
self._top_down_map,
self._shortest_path_points,
maps.MAP_SHORTEST_PATH_COLOR,
self.line_thickness,
)
# draw source and target points last to avoid overlap
if self._config.DRAW_SOURCE_AND_TARGET:
self.draw_source_and_target(episode)
def update_metric(self, episode, action):
self._step_count += 1
......@@ -515,8 +543,22 @@ class TopDownMap(Measure):
map_agent_x - (self._ind_x_min - self._grid_delta),
map_agent_y - (self._ind_y_min - self._grid_delta),
),
"agent_angle": self.get_polar_angle(),
}
def get_polar_angle(self):
agent_state = self._sim.get_agent_state()
# quaternion is in x, y, z, w format
ref_rotation = agent_state.rotation
heading_vector = quaternion_rotate_vector(
ref_rotation.inverse(), np.array([0, 0, -1])
)
phi = cartesian_to_polar(-heading_vector[2], heading_vector[0])[1]
x_y_flip = -np.pi / 2
return np.array(phi) + x_y_flip
def update_map(self, agent_position):
a_x, a_y = maps.to_grid(
agent_position[0],
......
......@@ -33,7 +33,7 @@ MAP_VALID_POINT = 1
MAP_BORDER_INDICATOR = 2
MAP_SOURCE_POINT_INDICATOR = 4
MAP_TARGET_POINT_INDICATOR = 6
MAP_SHORTEST_PATH_COLOR = 7
TOP_DOWN_MAP_COLORS = np.full((256, 3), 150, dtype=np.uint8)
TOP_DOWN_MAP_COLORS[10:] = cv2.applyColorMap(
np.arange(246, dtype=np.uint8), cv2.COLORMAP_JET
......@@ -43,6 +43,7 @@ TOP_DOWN_MAP_COLORS[MAP_VALID_POINT] = [150, 150, 150]
TOP_DOWN_MAP_COLORS[MAP_BORDER_INDICATOR] = [50, 50, 50]
TOP_DOWN_MAP_COLORS[MAP_SOURCE_POINT_INDICATOR] = [0, 0, 200]
TOP_DOWN_MAP_COLORS[MAP_TARGET_POINT_INDICATOR] = [200, 0, 0]
TOP_DOWN_MAP_COLORS[MAP_SHORTEST_PATH_COLOR] = [0, 200, 0]
def draw_agent(
......@@ -334,3 +335,20 @@ def colorize_topdown_map(top_down_map: np.ndarray) -> np.ndarray:
A colored version of the top-down map.
"""
return TOP_DOWN_MAP_COLORS[top_down_map]
def draw_path(
top_down_map: np.ndarray,
path_points: List[Tuple],
color: int,
thickness: int = 2,
) -> None:
r"""Draw path on top_down_map (in place) with specified color.
Args:
top_down_map: A colored version of the map.
color: color code of the path, from TOP_DOWN_MAP_COLORS.
path_points: list of points that specify the path to be drawn
thickness: thickness of the path.
"""
for prev_pt, next_pt in zip(path_points[:-1], path_points[1:]):
cv2.line(top_down_map, prev_pt, next_pt, color, thickness=thickness)
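A hedged example of how `draw_path` might be called on a raw index map, mirroring the conversion done in `TopDownMap.reset_metric` (world points are mapped to grid coordinates and reversed to (x, y) order for OpenCV):

```python
import numpy as np
from habitat.utils.visualizations import maps

# Sketch with hypothetical world-space waypoints (x, y, z); y is ignored.
map_resolution = (1250, 1250)
top_down_map = np.zeros(map_resolution, dtype=np.uint8)

world_points = [(1.0, 0.1, 2.0), (1.5, 0.1, 2.5)]
grid_points = [
    maps.to_grid(
        p[0], p[2], maps.COORDINATE_MIN, maps.COORDINATE_MAX, map_resolution
    )[::-1]  # (row, col) -> (x, y), the order cv2.line expects
    for p in world_points
]

maps.draw_path(
    top_down_map, grid_points, maps.MAP_SHORTEST_PATH_COLOR, thickness=2
)
colored = maps.colorize_topdown_map(top_down_map)  # path rendered in green
```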
......@@ -5,12 +5,15 @@
# LICENSE file in the root directory of this source tree.
import os
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import cv2
import imageio
import numpy as np
import tqdm
from habitat.utils.visualizations import maps
def paste_overlapping_image(
background: np.ndarray,
......@@ -103,7 +106,7 @@ def images_to_video(
Args:
images: The list of images. Images should be HxWx3 in RGB order.
output_dir: The folder to put the video in.
video_name: The navme for the video.
video_name: The name for the video.
fps: Frames per second for the video. Not all values work with FFMPEG,
use at your own risk.
quality: Default is 5. Uses variable bit rate. Highest quality is 10,
......@@ -125,3 +128,75 @@ def images_to_video(
for im in tqdm.tqdm(images):
writer.append_data(im)
writer.close()
def draw_collision(view: np.ndarray, alpha: float = 0.4) -> np.ndarray:
r"""Draw translucent red strips on the border of input view to indicate
a collision has taken place.
Args:
view: input view of size HxWx3 in RGB order.
alpha: Opacity of red collision strip. 1 is completely non-transparent.
Returns:
A view with collision effect drawn.
"""
size = view.shape[0]
strip_width = size // 20
mask = np.ones((size, size))
mask[strip_width:-strip_width, strip_width:-strip_width] = 0
mask = mask == 1
view[mask] = (alpha * np.array([255, 0, 0]) + (1.0 - alpha) * view)[mask]
return view
def observations_to_image(observation: Dict, info: Dict) -> np.ndarray:
r"""Generate image of single frame from observation and info
returned from a single environment step().
Args:
observation: observation returned from an environment step().
info: info returned from an environment step().
Returns:
generated image of a single frame.
"""
observation_size = observation["rgb"].shape[0]
egocentric_view = observation["rgb"][:, :, :3]
# draw collision
if "collisions" in info and info["collisions"]["is_collision"]:
egocentric_view = draw_collision(egocentric_view)
# draw depth map if observation has depth info
if "depth" in observation:
depth_map = (observation["depth"].squeeze() * 255).astype(np.uint8)
depth_map = np.stack([depth_map for _ in range(3)], axis=2)
egocentric_view = np.concatenate((egocentric_view, depth_map), axis=1)
frame = egocentric_view
if "top_down_map" in info:
top_down_map = info["top_down_map"]["map"]
top_down_map = maps.colorize_topdown_map(top_down_map)
map_agent_pos = info["top_down_map"]["agent_map_coord"]
top_down_map = maps.draw_agent(
image=top_down_map,
agent_center_coord=map_agent_pos,
agent_rotation=info["top_down_map"]["agent_angle"],
agent_radius_px=top_down_map.shape[0] // 16,
)
if top_down_map.shape[0] > top_down_map.shape[1]:
top_down_map = np.rot90(top_down_map, 1)
# scale top down map to align with rgb view
old_h, old_w, _ = top_down_map.shape
top_down_height = observation_size
top_down_width = int(float(top_down_height) / old_h * old_w)
# cv2 resize (dsize is width first)
top_down_map = cv2.resize(
top_down_map,
(top_down_width, top_down_height),
interpolation=cv2.INTER_CUBIC,
)
frame = np.concatenate((egocentric_view, top_down_map), axis=1)
return frame
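Roughly how these utilities fit together during evaluation; a sketch assuming the caller collects `(observations, info)` pairs from `env.step()` with the `TOP_DOWN_MAP` and `COLLISIONS` measures enabled:

```python
from habitat.utils.visualizations.utils import (
    images_to_video,
    observations_to_image,
)


def save_episode_video(step_outputs, output_dir="videos", video_name="episode_0"):
    """Render each (observations, info) pair and write the frames as a video.

    `step_outputs` is assumed to be a list of (observations, info) tuples
    collected over one episode.
    """
    frames = [observations_to_image(obs, info) for obs, info in step_outputs]
    if frames:
        images_to_video(frames, output_dir, video_name)
```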
......@@ -48,12 +48,6 @@ python -u habitat_baselines/train_ppo.py \
```
**single-episode training**:
Algorithms can be trained with a single-episode option. This option can be used as a sanity check since good algorithms should overfit one episode relatively fast. To enable this option, add `DATASET.NUM_EPISODE_SAMPLE 1` *at the end* of the training command, or include the single-episode yaml file in `--task-config` like this:
```
--task-config "configs/tasks/pointnav.yaml,configs/datasets/single_episode.yaml"
```
**test**:
```bash
python -u habitat_baselines/evaluate_ppo.py \
......@@ -78,3 +72,28 @@ Set argument `--task-config` to `tasks/pointnav_mp3d.yaml` for training on [Matt
- [Handcrafted agent baseline](slambased/README.md) adopted from the paper
"Benchmarking Classic and Learned Navigation in Complex 3D Environments".
### Additional Utilities
**single-episode training**:
Algorithms can be trained with a single-episode option. This option can be used as a sanity check since good algorithms should overfit one episode relatively fast. To enable this option, add `DATASET.NUM_EPISODE_SAMPLE 1` *at the end* of the training command, or include the single-episode yaml file in `--task-config` like this:
```
--task-config "configs/tasks/pointnav.yaml,configs/datasets/single_episode.yaml"
```
**tensorboard and video generation support**
Enable tensorboard logging by adding `--tensorboard-dir logdir` when running `train_ppo.py` or `evaluate_ppo.py`.
Enable video generation for `evaluate_ppo.py` with `--video-option`, specifying `tensorboard` (for displaying on tensorboard) or `disk` (for saving videos on disk), for example:
```
python -u habitat_baselines/evaluate_ppo.py \
...
--count-test-episodes 2 \
--video-option tensorboard \
--tensorboard-dir tb_eval \
--model-path data/checkpoints/ckpt.xx.pth
```
The above command should generate navigation episode recordings and display them on tensorboard like this:
<p align="center">
<img src="../res/img/tensorboard_video_demo.gif" height="500">
</p>
......@@ -5,45 +5,83 @@
# LICENSE file in the root directory of this source tree.
import argparse
import glob
import os
import time
from typing import Optional
import torch
import habitat
from config.default import get_config as cfg_baseline
from habitat import logger
from habitat.config.default import get_config
from habitat.utils.visualizations.utils import (
images_to_video,
observations_to_image,
)
from rl.ppo import PPO, Policy
from rl.ppo.utils import batch_obs
from tensorboard_utils import get_tensorboard_writer
from train_ppo import make_env_fn
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--sim-gpu-id", type=int, required=True)
parser.add_argument("--pth-gpu-id", type=int, required=True)
parser.add_argument("--num-processes", type=int, required=True)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--count-test-episodes", type=int, default=100)
parser.add_argument(
"--sensors",
type=str,
default="RGB_SENSOR,DEPTH_SENSOR",
help="comma separated string containing different"
"sensors to use, currently 'RGB_SENSOR' and"
"'DEPTH_SENSOR' are supported",
)
parser.add_argument(
"--task-config",
type=str,
default="configs/tasks/pointnav.yaml",
help="path to config yaml containing information about task",
def poll_checkpoint_folder(
checkpoint_folder: str, previous_ckpt_ind: int
) -> Optional[str]:
r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder
(sorted by time of last modification).
Args:
checkpoint_folder: directory to look for checkpoints.
previous_ckpt_ind: index of checkpoint last returned.
Returns:
return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found
else return None.
"""
assert os.path.isdir(checkpoint_folder), "invalid checkpoint folder path"
models_paths = list(
filter(os.path.isfile, glob.glob(checkpoint_folder + "/*"))
)
args = parser.parse_args()
models_paths.sort(key=os.path.getmtime)
ind = previous_ckpt_ind + 1
if ind < len(models_paths):
return models_paths[ind]
return None
def generate_video(
args, images, episode_id, checkpoint_idx, spl, tb_writer, fps=10
) -> None:
r"""Generate video according to specified information.
device = torch.device("cuda:{}".format(args.pth_gpu_id))
Args:
args: contains args.video_option and args.video_dir.
images: list of images to be converted to video.
episode_id: episode id for video naming.
checkpoint_idx: checkpoint index for video naming.
spl: SPL for this episode for video naming.
tb_writer: tensorboard writer object for uploading video
fps: fps for generated video
Returns:
None
"""
if args.video_option and len(images) > 0:
video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}"
if "disk" in args.video_option:
images_to_video(images, args.video_dir, video_name)
if "tensorboard" in args.video_option:
tb_writer.add_video_from_np_images(
f"episode{episode_id}", checkpoint_idx, images, fps=fps
)
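A hedged illustration of calling `generate_video` outside of `eval_checkpoint`, using a hypothetical bare `Namespace` in place of the parsed CLI args:

```python
import argparse
import os

import numpy as np

# Hypothetical stand-in for the parsed CLI args; only the fields that
# generate_video reads (video_option, video_dir) are filled in.
fake_args = argparse.Namespace(video_option=["disk"], video_dir="demo_videos")
os.makedirs(fake_args.video_dir, exist_ok=True)
dummy_frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(30)]

generate_video(
    fake_args,
    images=dummy_frames,
    episode_id="0",
    checkpoint_idx=3,
    spl=0.87,
    tb_writer=None,  # unused when only "disk" output is requested
    fps=10,
)
```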
def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
env_configs = []
baseline_configs = []
device = torch.device("cuda", args.pth_gpu_id)
for _ in range(args.num_processes):
config_env = get_config(config_paths=args.task_config)
......@@ -54,6 +92,9 @@ def main():
for sensor in agent_sensors:
assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
if args.video_option:
config_env.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
config_env.TASK.MEASUREMENTS.append("COLLISIONS")
config_env.freeze()
env_configs.append(config_env)
......@@ -71,7 +112,7 @@ def main():
),
)
ckpt = torch.load(args.model_path, map_location=device)
ckpt = torch.load(checkpoint_path, map_location=device)
actor_critic = Policy(
observation_space=envs.observation_spaces[0],
......@@ -112,8 +153,16 @@ def main():
args.num_processes, args.hidden_size, device=device
)
not_done_masks = torch.zeros(args.num_processes, 1, device=device)
stats_episodes = set()
rgb_frames = None
if args.video_option:
rgb_frames = [[]] * args.num_processes
os.makedirs(args.video_dir, exist_ok=True)
while episode_counts.sum() < args.count_test_episodes:
current_episodes = envs.current_episodes()
with torch.no_grad():
_, actions, _, test_recurrent_hidden_states = actor_critic.act(
batch,
......@@ -149,13 +198,151 @@ def main():
episode_counts += 1 - not_done_masks
current_episode_reward *= not_done_masks
next_episodes = envs.current_episodes()
envs_to_pause = []
n_envs = envs.num_envs
for i in range(n_envs):
if next_episodes[i].episode_id in stats_episodes:
envs_to_pause.append(i)
# episode ended
if not_done_masks[i].item() == 0:
stats_episodes.add(current_episodes[i].episode_id)
if args.video_option:
generate_video(
args,
rgb_frames[i],
current_episodes[i].episode_id,
cur_ckpt_idx,
infos[i]["spl"],
writer,
)
rgb_frames[i] = []
# episode continues
elif args.video_option:
frame = observations_to_image(observations[i], infos[i])
rgb_frames[i].append(frame)
# stop tracking ended episodes if they exist
if len(envs_to_pause) > 0:
state_index = list(range(envs.num_envs))
for idx in reversed(envs_to_pause):
state_index.pop(idx)
envs.pause_at(idx)
# indexing along the batch dimensions
test_recurrent_hidden_states = test_recurrent_hidden_states[
:, state_index
]
not_done_masks = not_done_masks[state_index]
current_episode_reward = current_episode_reward[state_index]
for k, v in batch.items():
batch[k] = v[state_index]
if args.video_option:
rgb_frames = [rgb_frames[i] for i in state_index]
episode_reward_mean = (episode_rewards / episode_counts).mean().item()
episode_spl_mean = (episode_spls / episode_counts).mean().item()
episode_success_mean = (episode_success / episode_counts).mean().item()
print("Average episode reward: {:.6f}".format(episode_reward_mean))
print("Average episode success: {:.6f}".format(episode_success_mean))
print("Average episode spl: {:.6f}".format(episode_spl_mean))
logger.info("Average episode reward: {:.6f}".format(episode_reward_mean))
logger.info("Average episode success: {:.6f}".format(episode_success_mean))
logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))
writer.add_scalars(
"eval_reward", {"average reward": episode_reward_mean}, cur_ckpt_idx
)
writer.add_scalars(
"eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx
)
writer.add_scalars(
"eval_success", {"average success": episode_success_mean}, cur_ckpt_idx
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str)
parser.add_argument("--tracking-model-dir", type=str)
parser.add_argument("--sim-gpu-id", type=int, required=True)
parser.add_argument("--pth-gpu-id", type=int, required=True)
parser.add_argument("--num-processes", type=int, required=True)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--count-test-episodes", type=int, default=100)
parser.add_argument(
"--sensors",
type=str,
default="RGB_SENSOR,DEPTH_SENSOR",
help="comma separated string containing different"
"sensors to use, currently 'RGB_SENSOR' and"
"'DEPTH_SENSOR' are supported",
)
parser.add_argument(
"--task-config",
type=str,
default="configs/tasks/pointnav.yaml",
help="path to config yaml containing information about task",
)
parser.add_argument(
"--video-option",
type=str,
default="",
choices=["tensorboard", "disk"],
nargs="*",
help="Options for video output, leave empty for no video. "
"Videos can be saved to disk, uploaded to tensorboard, or both.",
)
parser.add_argument(
"--video-dir", type=str, help="directory for storing videos"
)
parser.add_argument(
"--tensorboard-dir",
type=str,
help="directory for storing tensorboard statistics",
)
args = parser.parse_args()
assert (args.model_path is not None) != (
args.tracking_model_dir is not None
), "Must specify a single model or a directory of models, but not both"
if "tensorboard" in args.video_option:
assert (
args.tensorboard_dir is not None
), "Must specify a tensorboard directory for video display"
if "disk" in args.video_option:
assert (
args.video_dir is not None
), "Must specify a directory for storing videos on disk"
with get_tensorboard_writer(
args.tensorboard_dir, purge_step=0, flush_secs=30
) as writer:
if args.model_path is not None:
# evaluate a single checkpoint
eval_checkpoint(args.model_path, args, writer)
else:
# evaluate multiple checkpoints in order
prev_ckpt_ind = -1
while True:
current_ckpt = None
while current_ckpt is None:
current_ckpt = poll_checkpoint_folder(
args.tracking_model_dir, prev_ckpt_ind
)
time.sleep(2) # sleep for 2 seconds before polling again
logger.warning(
"=============current_ckpt: {}=============".format(
current_ckpt
)
)
prev_ckpt_ind += 1
eval_checkpoint(
current_ckpt, args, writer, cur_ckpt_idx=prev_ckpt_ind
)
if __name__ == "__main__":
......
......@@ -422,4 +422,9 @@ def ppo_args():
nargs=argparse.REMAINDER,
help="Modify config options from command line",
)
parser.add_argument(
"--tensorboard-dir",
type=str,
help="path to tensorboard logging directory",
)
return parser
torch==1.1.0
# full tensorflow required for tensorboard video support
tensorflow==1.13.1
tb-nightly
from typing import Optional, Union
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
class TensorboardWriter(SummaryWriter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def add_video_from_np_images(
self, video_name: str, step_idx: int, images: np.ndarray, fps: int = 10
) -> None:
r"""Write video into tensorboard from images frames.
Args:
video_name: name of video string.
step_idx: int of checkpoint index to be displayed.
images: list of n frames. Each frame is an np.ndarray of shape (H, W, 3).
fps: frame per second for output video.
Returns:
None.
"""
# initial shape of np.ndarray list: N * (H, W, 3)
frame_tensors = [
torch.from_numpy(np_arr).unsqueeze(0) for np_arr in images
]
video_tensor = torch.cat(tuple(frame_tensors))
video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0)
# final shape of video tensor: (1, n, 3, H, W)
self.add_video(video_name, video_tensor, fps=fps, global_step=step_idx)
class DummyWriter:
def __init__(self, *args, **kwargs):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def close(self):
pass
def __getattr__(self, item):
return lambda *args, **kwargs: None
def get_tensorboard_writer(
log_dir: str, *args, **kwargs
) -> Union[DummyWriter, TensorboardWriter]:
r"""Get tensorboard writer if log_dir is specified, otherwise,
return dummy writer instead.
Args:
log_dir: log directory path for tensorboard SummaryWriter.
*args: additional positional args.
**kwargs: additional keyword args.
Returns:
Either the created tensorboard writer or a dummy writer.
"""
if log_dir:
return TensorboardWriter(log_dir, *args, **kwargs)
else:
return DummyWriter()
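A short sketch of the intended pattern: `get_tensorboard_writer` hands back a real `TensorboardWriter` when a log directory is given and a no-op `DummyWriter` otherwise, so calling code never has to branch (the import path matches how `train_ppo.py` and `evaluate_ppo.py` use the module):

```python
import numpy as np

from tensorboard_utils import get_tensorboard_writer

# With a log dir: a real TensorboardWriter; with "" or None: a DummyWriter
# whose methods are silent no-ops.
with get_tensorboard_writer("logs/eval", purge_step=0, flush_secs=30) as writer:
    writer.add_scalar("reward", 0.5, 100)
    # 30 dummy RGB frames; add_video_from_np_images expects N x (H, W, 3).
    frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(30)]
    writer.add_video_from_np_images("episode0", step_idx=0, images=frames, fps=10)
```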
......@@ -19,6 +19,7 @@ from habitat.config.default import get_config as cfg_env
from habitat.datasets.registration import make_dataset
from rl.ppo import PPO, Policy, RolloutStorage
from rl.ppo.utils import batch_obs, ppo_args, update_linear_schedule
from tensorboard_utils import get_tensorboard_writer
class NavRLEnv(habitat.RLEnv):
......@@ -162,7 +163,7 @@ def run_training():
random.seed(args.seed)
device = torch.device("cuda:{}".format(args.pth_gpu_id))
device = torch.device("cuda", args.pth_gpu_id)
logger.add_filehandler(args.log_file)
......@@ -218,8 +219,8 @@ def run_training():
episode_rewards = torch.zeros(envs.num_envs, 1)
episode_counts = torch.zeros(envs.num_envs, 1)
current_episode_reward = torch.zeros(envs.num_envs, 1)
window_episode_reward = deque()
window_episode_counts = deque()
window_episode_reward = deque(maxlen=args.reward_window_size)
window_episode_counts = deque(maxlen=args.reward_window_size)
t_start = time()
env_time = 0
......@@ -227,137 +228,169 @@ def run_training():
count_steps = 0
count_checkpoints = 0
for update in range(args.num_updates):
if args.use_linear_lr_decay:
update_linear_schedule(
agent.optimizer, update, args.num_updates, args.lr
)
with (
get_tensorboard_writer(
log_dir=args.tensorboard_dir, purge_step=count_steps, flush_secs=30
)
) as writer:
for update in range(args.num_updates):
if args.use_linear_lr_decay:
update_linear_schedule(
agent.optimizer, update, args.num_updates, args.lr
)
agent.clip_param = args.clip_param * (1 - update / args.num_updates)
agent.clip_param = args.clip_param * (
1 - update / args.num_updates
)
for step in range(args.num_steps):
t_sample_action = time()
# sample actions
with torch.no_grad():
step_observation = {
k: v[step] for k, v in rollouts.observations.items()
}
for step in range(args.num_steps):
t_sample_action = time()
# sample actions
with torch.no_grad():
step_observation = {
k: v[step] for k, v in rollouts.observations.items()
}
(
values,
actions,
actions_log_probs,
recurrent_hidden_states,
) = actor_critic.act(
step_observation,
rollouts.recurrent_hidden_states[step],
rollouts.masks[step],
)
pth_time += time() - t_sample_action
(
values,
actions,
actions_log_probs,
recurrent_hidden_states,
) = actor_critic.act(
step_observation,
rollouts.recurrent_hidden_states[step],
rollouts.masks[step],
)
pth_time += time() - t_sample_action
t_step_env = time()
t_step_env = time()
outputs = envs.step([a[0].item() for a in actions])
observations, rewards, dones, infos = [
list(x) for x in zip(*outputs)
]
outputs = envs.step([a[0].item() for a in actions])
observations, rewards, dones, infos = [
list(x) for x in zip(*outputs)
]
env_time += time() - t_step_env
env_time += time() - t_step_env
t_update_stats = time()
batch = batch_obs(observations)
rewards = torch.tensor(rewards, dtype=torch.float)
rewards = rewards.unsqueeze(1)
t_update_stats = time()
batch = batch_obs(observations)
rewards = torch.tensor(rewards, dtype=torch.float)
rewards = rewards.unsqueeze(1)
masks = torch.tensor(
[[0.0] if done else [1.0] for done in dones],
dtype=torch.float,
)
masks = torch.tensor(
[[0.0] if done else [1.0] for done in dones], dtype=torch.float
)
current_episode_reward += rewards
episode_rewards += (1 - masks) * current_episode_reward
episode_counts += 1 - masks
current_episode_reward *= masks
current_episode_reward += rewards
episode_rewards += (1 - masks) * current_episode_reward
episode_counts += 1 - masks
current_episode_reward *= masks
rollouts.insert(
batch,
recurrent_hidden_states,
actions,
actions_log_probs,
values,
rewards,
masks,
)
rollouts.insert(
batch,
recurrent_hidden_states,
actions,
actions_log_probs,
values,
rewards,
masks,
)
count_steps += envs.num_envs
pth_time += time() - t_update_stats
count_steps += envs.num_envs
pth_time += time() - t_update_stats
if len(window_episode_reward) == args.reward_window_size:
window_episode_reward.popleft()
window_episode_counts.popleft()
window_episode_reward.append(episode_rewards.clone())
window_episode_counts.append(episode_counts.clone())
window_episode_reward.append(episode_rewards.clone())
window_episode_counts.append(episode_counts.clone())
t_update_model = time()
with torch.no_grad():
last_observation = {
k: v[-1] for k, v in rollouts.observations.items()
}
next_value = actor_critic.get_value(
last_observation,
rollouts.recurrent_hidden_states[-1],
rollouts.masks[-1],
).detach()
rollouts.compute_returns(
next_value, args.use_gae, args.gamma, args.tau
)
t_update_model = time()
with torch.no_grad():
last_observation = {
k: v[-1] for k, v in rollouts.observations.items()
}
next_value = actor_critic.get_value(
last_observation,
rollouts.recurrent_hidden_states[-1],
rollouts.masks[-1],
).detach()
rollouts.compute_returns(
next_value, args.use_gae, args.gamma, args.tau
)
value_loss, action_loss, dist_entropy = agent.update(rollouts)
value_loss, action_loss, dist_entropy = agent.update(rollouts)
rollouts.after_update()
pth_time += time() - t_update_model
rollouts.after_update()
pth_time += time() - t_update_model
# log stats
if update > 0 and update % args.log_interval == 0:
logger.info(
"update: {}\tfps: {:.3f}\t".format(
update, count_steps / (time() - t_start)
losses = [value_loss, action_loss]
stats = zip(
["count", "reward"],
[window_episode_counts, window_episode_reward],
)
deltas = {
k: (
(v[-1] - v[0]).sum().item()
if len(v) > 1
else v[0].sum().item()
)
for k, v in stats
}
deltas["count"] = max(deltas["count"], 1.0)
writer.add_scalar(
"reward", deltas["reward"] / deltas["count"], count_steps
)
logger.info(
"update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
"frames: {}".format(update, env_time, pth_time, count_steps)
writer.add_scalars(
"losses",
{k: l for l, k in zip(losses, ["value", "policy"])},
count_steps,
)
window_rewards = (
window_episode_reward[-1] - window_episode_reward[0]
).sum()
window_counts = (
window_episode_counts[-1] - window_episode_counts[0]
).sum()
# log stats
if update > 0 and update % args.log_interval == 0:
logger.info(
"update: {}\tfps: {:.3f}\t".format(
update, count_steps / (time() - t_start)
)
)
if window_counts > 0:
logger.info(
"Average window size {} reward: {:3f}".format(
len(window_episode_reward),
(window_rewards / window_counts).item(),
"update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
"frames: {}".format(
update, env_time, pth_time, count_steps
)
)
else:
logger.info("No episodes finish in current window")
# checkpoint model
if update % args.checkpoint_interval == 0:
checkpoint = {"state_dict": agent.state_dict()}
torch.save(
checkpoint,
os.path.join(
args.checkpoint_folder,
"ckpt.{}.pth".format(count_checkpoints),
),
)
count_checkpoints += 1
window_rewards = (
window_episode_reward[-1] - window_episode_reward[0]
).sum()
window_counts = (
window_episode_counts[-1] - window_episode_counts[0]
).sum()
if window_counts > 0:
logger.info(
"Average window size {} reward: {:3f}".format(
len(window_episode_reward),
(window_rewards / window_counts).item(),
)
)
else:
logger.info("No episodes finish in current window")
# checkpoint model
if update % args.checkpoint_interval == 0:
checkpoint = {"state_dict": agent.state_dict(), "args": args}
torch.save(
checkpoint,
os.path.join(
args.checkpoint_folder,
"ckpt.{}.pth".format(count_checkpoints),
),
)
count_checkpoints += 1
if __name__ == "__main__":
......
res/img/tensorboard_video_demo.gif (new file, 1.68 MiB)
......@@ -127,7 +127,7 @@ def test_collisions():
for _ in range(50):
action = np.random.choice(actions)
env.step(action)
collisions = env.get_metrics()["collisions"]
collisions = env.get_metrics()["collisions"]["count"]
loc = env.sim.get_agent_state().position
if (
......