diff --git a/.editorconfig b/.editorconfig index 31ce00827d608174d170d0ad874e23301c2d9436..8e32cffced05094e281b67a7422411ddef77634d 100644 --- a/.editorconfig +++ b/.editorconfig @@ -20,4 +20,4 @@ include_trailing_comma = true ensure_newline_before_comments=true use_parentheses = true known_first_party = habitat,habitat_sim,habitat_baselines,version -known_third_party = PIL,attr,conf,gym,imageio,matplotlib,mock,numba,numpy,orbslam2,pyrobot,pytest,quaternion,requests,scipy,setuptools,torch,torchvision,tqdm,yacs +known_third_party = PIL,attr,conf,gym,ifcfg,imageio,matplotlib,mock,numba,numpy,orbslam2,pyrobot,pytest,quaternion,requests,scipy,setuptools,torch,torchvision,tqdm,yacs diff --git a/habitat_baselines/__init__.py b/habitat_baselines/__init__.py index 4fd2c3c3c3111911db3f57269d0b824af5584585..a5d02f3e552323c6b1d35caef81455a107395ad6 100644 --- a/habitat_baselines/__init__.py +++ b/habitat_baselines/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from habitat_baselines.common.base_trainer import BaseRLTrainer, BaseTrainer +from habitat_baselines.rl.ddppo import DDPPOTrainer from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer, RolloutStorage __all__ = ["BaseTrainer", "BaseRLTrainer", "PPOTrainer", "RolloutStorage"] diff --git a/habitat_baselines/common/rollout_storage.py b/habitat_baselines/common/rollout_storage.py index fadd74017f74408d48a1c5c67a47c7c0587c96ac..24cc19bb852fe5375ac47b5f0c0a496e737a5cfc 100644 --- a/habitat_baselines/common/rollout_storage.py +++ b/habitat_baselines/common/rollout_storage.py @@ -55,7 +55,7 @@ class RolloutStorage: self.actions = self.actions.long() self.prev_actions = self.prev_actions.long() - self.masks = torch.ones(num_steps + 1, num_envs, 1) + self.masks = torch.zeros(num_steps + 1, num_envs, 1) self.num_steps = num_steps self.step = 0 @@ -97,21 +97,26 @@ class RolloutStorage: self.rewards[self.step].copy_(rewards) self.masks[self.step + 1].copy_(masks) - self.step = (self.step + 1) % self.num_steps + self.step = self.step + 1 def after_update(self): for sensor in self.observations: - self.observations[sensor][0].copy_(self.observations[sensor][-1]) + self.observations[sensor][0].copy_( + self.observations[sensor][self.step] + ) - self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) - self.masks[0].copy_(self.masks[-1]) - self.prev_actions[0].copy_(self.prev_actions[-1]) + self.recurrent_hidden_states[0].copy_( + self.recurrent_hidden_states[self.step] + ) + self.masks[0].copy_(self.masks[self.step]) + self.prev_actions[0].copy_(self.prev_actions[self.step]) + self.step = 0 def compute_returns(self, next_value, use_gae, gamma, tau): if use_gae: - self.value_preds[-1] = next_value + self.value_preds[self.step] = next_value gae = 0 - for step in reversed(range(self.rewards.size(0))): + for step in reversed(range(self.step)): delta = ( self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] @@ -120,8 +125,8 @@ class RolloutStorage: gae = delta + gamma * tau * self.masks[step + 1] * gae self.returns[step] = gae + self.value_preds[step] else: - self.returns[-1] = next_value - for step in reversed(range(self.rewards.size(0))): + self.returns[self.step] = next_value + for step in reversed(range(self.step)): self.returns[step] = ( self.returns[step + 1] * gamma * self.masks[step + 1] + self.rewards[step] @@ -153,25 +158,25 @@ class RolloutStorage: for sensor in self.observations: observations_batch[sensor].append( - self.observations[sensor][:-1, ind] + 
self.observations[sensor][: self.step, ind] ) recurrent_hidden_states_batch.append( self.recurrent_hidden_states[0, :, ind] ) - actions_batch.append(self.actions[:, ind]) - prev_actions_batch.append(self.prev_actions[:-1, ind]) - value_preds_batch.append(self.value_preds[:-1, ind]) - return_batch.append(self.returns[:-1, ind]) - masks_batch.append(self.masks[:-1, ind]) + actions_batch.append(self.actions[: self.step, ind]) + prev_actions_batch.append(self.prev_actions[: self.step, ind]) + value_preds_batch.append(self.value_preds[: self.step, ind]) + return_batch.append(self.returns[: self.step, ind]) + masks_batch.append(self.masks[: self.step, ind]) old_action_log_probs_batch.append( - self.action_log_probs[:, ind] + self.action_log_probs[: self.step, ind] ) - adv_targ.append(advantages[:, ind]) + adv_targ.append(advantages[: self.step, ind]) - T, N = self.num_steps, num_envs_per_batch + T, N = self.step, num_envs_per_batch # These are all tensors of size (T, N, -1) for sensor in observations_batch: diff --git a/habitat_baselines/config/default.py b/habitat_baselines/config/default.py index e28e5c962a586cf114b0f786797a748b73e63b37..b4a3abbcf789678f14ad9bf1f553f86851716ade 100644 --- a/habitat_baselines/config/default.py +++ b/habitat_baselines/config/default.py @@ -62,13 +62,32 @@ _C.RL.PPO.lr = 7e-4 _C.RL.PPO.eps = 1e-5 _C.RL.PPO.max_grad_norm = 0.5 _C.RL.PPO.num_steps = 5 -_C.RL.PPO.hidden_size = 512 _C.RL.PPO.use_gae = True _C.RL.PPO.use_linear_lr_decay = False _C.RL.PPO.use_linear_clip_decay = False _C.RL.PPO.gamma = 0.99 _C.RL.PPO.tau = 0.95 _C.RL.PPO.reward_window_size = 50 +_C.RL.PPO.use_normalized_advantage = True +_C.RL.PPO.hidden_size = 512 +# ----------------------------------------------------------------------------- +# DECENTRALIZED DISTRIBUTED PROXIMAL POLICY OPTIMIZATION (DD-PPO) +# ----------------------------------------------------------------------------- +_C.RL.DDPPO = CN() +_C.RL.DDPPO.sync_frac = 0.6 +_C.RL.DDPPO.distrib_backend = "GLOO" +_C.RL.DDPPO.rnn_type = "LSTM" +_C.RL.DDPPO.num_recurrent_layers = 2 +_C.RL.DDPPO.backbone = "resnet50" +_C.RL.DDPPO.pretrained_weights = "data/ddppo-models/gibson-2plus-resnet50.pth" +# Loads pretrained weights +_C.RL.DDPPO.pretrained = False +# Loads just the visual encoder backbone weights +_C.RL.DDPPO.pretrained_encoder = False +# Whether or not the visual encoder backbone will be trained +_C.RL.DDPPO.train_encoder = True +# Whether or not to reset the critic linear layer +_C.RL.DDPPO.reset_critic = True # ----------------------------------------------------------------------------- # ORBSLAM2 BASELINE # ----------------------------------------------------------------------------- diff --git a/habitat_baselines/config/pointnav/ddppo_pointnav.yaml b/habitat_baselines/config/pointnav/ddppo_pointnav.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6f56c62d5d4d2db259d57643d810df6768adb26 --- /dev/null +++ b/habitat_baselines/config/pointnav/ddppo_pointnav.yaml @@ -0,0 +1,60 @@ +BASE_TASK_CONFIG_PATH: "configs/tasks/pointnav_gibson.yaml" +TRAINER_NAME: "ddppo" +ENV_NAME: "NavRLEnv" +SIMULATOR_GPU_ID: 0 +TORCH_GPU_ID: 0 +VIDEO_OPTION: [] +TENSORBOARD_DIR: "tb" +VIDEO_DIR: "video_dir" +TEST_EPISODE_COUNT: 994 +EVAL_CKPT_PATH_DIR: "data/new_checkpoints" +NUM_PROCESSES: 4 +SENSORS: ["DEPTH_SENSOR"] +CHECKPOINT_FOLDER: "data/new_checkpoints" +NUM_UPDATES: 10000 +LOG_INTERVAL: 10 +CHECKPOINT_INTERVAL: 50 + +RL: + SUCCESS_REWARD: 2.5 + PPO: + # ppo params + clip_param: 0.2 + ppo_epoch: 2 + num_mini_batch: 2 + 
value_loss_coef: 0.5 + entropy_coef: 0.01 + lr: 2.5e-4 + eps: 1e-5 + max_grad_norm: 0.2 + num_steps: 128 + use_gae: True + gamma: 0.99 + tau: 0.95 + use_linear_clip_decay: False + use_linear_lr_decay: False + reward_window_size: 50 + + use_normalized_advantage: False + + hidden_size: 512 + + DDPPO: + sync_frac: 0.6 + # The PyTorch distributed backend to use + distrib_backend: GLOO + # Visual encoder backbone + pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth + # Initialize with pretrained weights + pretrained: False + # Initialize just the visual encoder backbone with pretrained weights + pretrained_encoder: False + # Whether or not the visual encoder backbone will be trained. + train_encoder: True + # Whether or not to reset the critic linear layer + reset_critic: True + + # Model parameters + backbone: resnet50 + rnn_type: LSTM + num_recurrent_layers: 2 diff --git a/habitat_baselines/config/test/ddppo_pointnav_test.yaml b/habitat_baselines/config/test/ddppo_pointnav_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9d161f7efafc8cb2e71823c37b85367b32cd73e --- /dev/null +++ b/habitat_baselines/config/test/ddppo_pointnav_test.yaml @@ -0,0 +1,57 @@ +BASE_TASK_CONFIG_PATH: "configs/tasks/pointnav.yaml" +TRAINER_NAME: "ddppo" +ENV_NAME: "NavRLEnv" +SIMULATOR_GPU_ID: 0 +TORCH_GPU_ID: 0 +VIDEO_OPTION: [] +TENSORBOARD_DIR: "" +EVAL_CKPT_PATH_DIR: "data/test_checkpoints/ddppo/pointnav/ckpt.0.pth" +NUM_PROCESSES: 1 +CHECKPOINT_FOLDER: "data/test_checkpoints/ddppo/pointnav/" +NUM_UPDATES: 2 +LOG_INTERVAL: 100 +CHECKPOINT_INTERVAL: 1 + +RL: + SUCCESS_REWARD: 2.5 + PPO: + # ppo params + clip_param: 0.2 + ppo_epoch: 2 + num_mini_batch: 1 + value_loss_coef: 0.5 + entropy_coef: 0.01 + lr: 2.5e-4 + eps: 1e-5 + max_grad_norm: 0.2 + num_steps: 16 + use_gae: True + gamma: 0.99 + tau: 0.95 + use_linear_clip_decay: False + use_linear_lr_decay: False + reward_window_size: 50 + + use_normalized_advantage: False + + hidden_size: 512 + + DDPPO: + sync_frac: 0.6 + # The PyTorch distributed backend to use + distrib_backend: GLOO + # Visual encoder backbone + pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth + # Initialize with pretrained weights + pretrained: False + # Initialize just the visual encoder backbone with pretrained weights + pretrained_encoder: False + # Whether or not the visual encoder backbone will be trained. + train_encoder: True + # Whether or not to reset the critic linear layer + reset_critic: True + + # Model parameters + backbone: resnet50 + rnn_type: LSTM + num_recurrent_layers: 2 diff --git a/habitat_baselines/rl/ddppo/README.md b/habitat_baselines/rl/ddppo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5336ae67d968deab6528212565f2d30cace88164 --- /dev/null +++ b/habitat_baselines/rl/ddppo/README.md @@ -0,0 +1,64 @@ +# Decentralized Distributed PPO + +Provides changes to the core baseline ppo algorithm and training script to implemented Decentralized Distributed PPO (DD-PPO). +DD-PPO leverages distributed data parallelism to seamlessly scale PPO to hundreds of GPUs with no centralized server. + +See the [paper](https://arxiv.org/abs/1911.00357) for more detail. + +## Running + +There are two example scripts to run provided. A single node script that leverages `torch.distributed.launch` to create multiple workers: +`single_node.sh`, and a multi-node script that leverages [SLURM](https://slurm.schedmd.com/documentation.html) to create all the works on multiple nodes: `multi_node_slurm.sh`. 
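+
+For reference, `single_node.sh` reduces to a `torch.distributed.launch` invocation along the lines of the sketch below; `--nproc_per_node` is typically set to the number of GPUs available on the machine:
+
+    # quiet simulator/engine logging
+    export GLOG_minloglevel=2
+    export MAGNUM_LOG=quiet
+
+    python -u -m torch.distributed.launch \
+        --use_env \
+        --nproc_per_node 1 \
+        habitat_baselines/run.py \
+        --exp-config habitat_baselines/config/pointnav/ddppo_pointnav.yaml \
+        --run-type train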
+ +The two recommended backends are GLOO and NCCL. Use NCCL if your system has it, and GLOO if otherwise. + +See [pytorch's distributed docs](https://pytorch.org/docs/stable/distributed.html#backends-that-come-with-pytorch) +and [pytorch's distributed tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html) for more information. + +## Pretrained Models (PointGoal Navigation with GPS+Compass) + + +All weights available as a zip [here](https://drive.google.com/open?id=1ueXuIqP2HZ0oxhpDytpc3hpciXSd8H16). + +### Depth models + +| Architecture | Training Data | Val SPL | Test SPL | URL | +| ------------ | ------------- | ------- | -------- | --- | +| ResNet50 + LSTM512 | Gibson 4+ | 0.922 | 0.917 | | +| ResNet50 + LSTM512 | Gibson 4+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | 0.956 | 0.941 | +| ResNet50 + LSTM512 | Gibson 2+ | 0.956 | 0.944 | | +| SE-ResNeXt50 + LSTM512 | Gibson 2+ | 0.959 | 0.943 | | +| SE-ResNeXt101 + LSTM1024 | Gibson 2+ | 0.969 | 0.948 | | + +### RGB models + +| Architecture | Training Data | Val SPL | Test SPL | URL | +| ------------ | ------------- | ------- | -------- | --- | +| ResNet50 + LSTM512 | Gibson 2+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | | | +| SE-ResNeXt50 + LSTM512 | Gibson 2+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | 0.933 | 0.920 | + + +### Blind Models + +| Architecture | Training Data | Val SPL | Test SPL | URL | +| ------------ | ------------- | ------- | -------- | --- | +| LSTM512 | Gibson 0+ and MP3D(train/val/test)<br/> **Caution:** Trained on MP3D val and test | 0.729 | 0.676 | + + + + +**Note:** Evaluation was done with *sampled* actions. + +All model weights are subject to [Matterport3D Terms-of-Use](http://dovahkiin.stanford.edu/matterport/public/MP_TOS.pdf). + + +## Citing + +If you use DD-PPO or the model-weights in your research, please cite the following [paper](https://arxiv.org/abs/1911.00357): + + @article{wijmans2020ddppo, + title = {{D}ecentralized {D}istributed {PPO}: {S}olving {P}oint{G}oal {N}avigation}, + author = {Erik Wijmans and Abhishek Kadian and Ari Morcos and Stefan Lee and Irfan Essa and Devi Parikh and Manolis Savva and Dhruv Batra}, + journal = {International Conference on Learning Representations (ICLR)}, + year = {2020} + } diff --git a/habitat_baselines/rl/ddppo/__init__.py b/habitat_baselines/rl/ddppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0f29b5b20bab51f2658a49660c8f68a45c0502 --- /dev/null +++ b/habitat_baselines/rl/ddppo/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from habitat_baselines.rl.ddppo.algo import DDPPOTrainer diff --git a/habitat_baselines/rl/ddppo/algo/__init__.py b/habitat_baselines/rl/ddppo/algo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..090eb6e0461f7e3cb1789d988aee9b9b200e3edf --- /dev/null +++ b/habitat_baselines/rl/ddppo/algo/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from habitat_baselines.rl.ddppo.algo.ddppo_trainer import DDPPOTrainer diff --git a/habitat_baselines/rl/ddppo/algo/ddp_utils.py b/habitat_baselines/rl/ddppo/algo/ddp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..892a1ea422137efae27083da2263c40c07628ebf --- /dev/null +++ b/habitat_baselines/rl/ddppo/algo/ddp_utils.py @@ -0,0 +1,165 @@ +import os +import os.path as osp +import shlex +import signal +import subprocess +import threading +from typing import Any, Optional, Tuple + +import ifcfg +import torch +import torch.distributed as distrib + +from habitat import logger + +EXIT = threading.Event() +EXIT.clear() +REQUEUE = threading.Event() +REQUEUE.clear() + + +# Default port to initialized the TCP store on +DEFAULT_PORT = 8738 +# Default address of world rank 0 +DEFAULT_MASTER_ADDR = "127.0.0.1" + +SLURM_JOBID = os.environ.get("SLURM_JOB_ID", None) +INTERRUPTED_STATE_FILE = osp.join( + os.environ["HOME"], ".interrupted_states", f"{SLURM_JOBID}.pth" +) + + +def _clean_exit_handler(signum, frame): + EXIT.set() + print("Exiting cleanly", flush=True) + + +def _requeue_handler(signal, frame): + EXIT.set() + REQUEUE.set() + + +def add_signal_handlers(): + signal.signal(signal.SIGINT, _clean_exit_handler) + signal.signal(signal.SIGTERM, _clean_exit_handler) + + # SIGUSR2 can be sent to all processes to have them cleanup + # and exit nicely. This is nice to use with SLURM as scancel <job_id> + # sets a 30 second timer for the job to exit, and it can take more than + # 30 seconds for the job to cleanup and exit nicely. When using NCCL, + # forcing the job to exit without cleaning up can be bad. + # scancel --signal SIGUSR2 <job_id> will set no such timer and will give + # the job ample time to cleanup and exit. + signal.signal(signal.SIGUSR2, _clean_exit_handler) + + signal.signal(signal.SIGUSR1, _requeue_handler) + + +def save_interrupted_state(state: Any, filename: str = None): + r"""Saves the interrupted job state to the specified filename. + This is useful when working with preemptable job partitions. + + This method will do nothing if SLURM is not currently being used and the filename is the default + + :param state: The state to save + :param filename: The filename. Defaults to "${HOME}/.interrupted_states/${SLURM_JOBID}.pth" + """ + if SLURM_JOBID is None and filename is None: + logger.warn("SLURM_JOBID is none, not saving interrupted state") + return + + if filename is None: + filename = INTERRUPTED_STATE_FILE + + torch.save(state, filename) + + +def load_interrupted_state(filename: str = None) -> Optional[Any]: + r"""Loads the saved interrupted state + + :param filename: The filename of the saved state. 
+ Defaults to "${HOME}/.interrupted_states/${SLURM_JOBID}.pth" + + :return: The saved state if the file exists, else none + """ + if SLURM_JOBID is None and filename is None: + return None + + if filename is None: + filename = INTERRUPTED_STATE_FILE + + if not osp.exists(filename): + return None + + return torch.load(filename, map_location="cpu") + + +def requeue_job(): + r"""Requeues the job by calling `scontrol requeue ${SLURM_JOBID}` + """ + if SLURM_JOBID is None: + return + + if not REQUEUE.is_set(): + return + + distrib.barrier() + + if distrib.get_rank() == 0: + logger.info(f"Requeueing job {SLURM_JOBID}") + subprocess.check_call(shlex.split("scontrol requeue {SLURM_JOBID}")) + + +def get_ifname(): + return ifcfg.default_interface()["device"] + + +def init_distrib_slurm( + backend: str = "nccl", +) -> Tuple[int, torch.distributed.TCPStore]: + r"""Initializes torch.distributed by parsing environment variables set + by SLURM when `srun` is used or by parsing environment variables set + by torch.distributed.launch + + :param backend: Which torch.distributed backend to use + + :returns: Tuple of the local_rank (aka which GPU to use for this process) + and the TCPStore used for the rendezvous + """ + assert ( + torch.distributed.is_available() + ), "torch.distributed must be available" + + if "GLOO_SOCKET_IFNAME" not in os.environ: + os.environ["GLOO_SOCKET_IFNAME"] = get_ifname() + + if "NCCL_SOCKET_IFNAME" not in os.environ: + os.environ["NCCL_SOCKET_IFNAME"] = get_ifname() + + master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT)) + master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR) + + # Check to see if we should parse from torch.distributed.launch + if os.environ.get("LOCAL_RANK", None) is not None: + local_rank = int(os.environ["LOCAL_RANK"]) + world_rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + # Else parse from SLURM is using SLURM + elif os.environ.get("SLURM_JOBID", None) is not None: + local_rank = int(os.environ["SLURM_LOCALID"]) + world_rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + # Otherwise setup for just 1 process, this is nice for testing + else: + local_rank = 0 + world_rank = 0 + world_size = 1 + + tcp_store = distrib.TCPStore( + master_addr, master_port, world_size, world_rank == 0 + ) + distrib.init_process_group( + backend, store=tcp_store, rank=world_rank, world_size=world_size + ) + + return local_rank, tcp_store diff --git a/habitat_baselines/rl/ddppo/algo/ddppo.py b/habitat_baselines/rl/ddppo/algo/ddppo.py new file mode 100644 index 0000000000000000000000000000000000000000..4da9e8d6c3c61c20f0d097710f667b05f8a3734d --- /dev/null +++ b/habitat_baselines/rl/ddppo/algo/ddppo.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.distributed as distrib + +from habitat_baselines.common.rollout_storage import RolloutStorage +from habitat_baselines.rl.ppo import PPO + +EPS_PPO = 1e-5 + + +def distributed_mean_and_var( + values: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Computes the mean and variances of a tensor over multiple workers. 
+ + This method is equivalent to first collecting all versions of values and + then computing the mean and variance locally over that + + :param values: (*,) shaped tensors to compute mean and variance over. Assumed + to be solely the workers local copy of this tensor, + the resultant mean and variance will be computed + over _all_ workers version of this tensor. + """ + assert distrib.is_initialized(), "Distributed must be initialized" + + world_size = distrib.get_world_size() + mean = values.mean() + distrib.all_reduce(mean) + mean /= world_size + + sq_diff = (values - mean).pow(2).mean() + distrib.all_reduce(sq_diff) + var = sq_diff / world_size + + return mean, var + + +class DecentralizedDistributedMixin: + def _get_advantages_distributed( + self, rollouts: RolloutStorage + ) -> torch.Tensor: + advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] + if not self.use_normalized_advantage: + return advantages + + mean, var = distributed_mean_and_var(advantages) + + return (advantages - mean) / (var.sqrt() + EPS_PPO) + + def init_distributed(self, find_unused_params: bool = True) -> None: + r"""Initializes distributed training for the model + + 1. Broadcasts the model weights from world_rank 0 to all other workers + 2. Adds gradient hooks to the model + + :param find_unused_params: Whether or not to filter out unused parameters + before gradient reduction. This *must* be True if + there are any parameters in the model that where unused in the + forward pass, otherwise the gradient reduction + will not work correctly. + """ + # NB: Used to hide the hooks from the nn.Module, + # so they don't show up in the state_dict + class Guard: + def __init__(self, model, device): + if torch.cuda.is_available(): + self.ddp = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], output_device=device + ) + else: + self.ddp = torch.nn.parallel.DistributedDataParallel(model) + + self._ddp_hooks = Guard(self.actor_critic, self.device) + self.get_advantages = self._get_advantages_distributed + + self.reducer = self._ddp_hooks.ddp.reducer + self.find_unused_params = find_unused_params + + def before_backward(self, loss): + super().before_backward(loss) + + if self.find_unused_params: + self.reducer.prepare_for_backward([loss]) + else: + self.reducer.prepare_for_backward([]) + + +# Mixin goes second that way the PPO __init__ will still be called +class DDPPO(PPO, DecentralizedDistributedMixin): + pass diff --git a/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py b/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..838f08ca66bef3ba7cb861fce6cd4dc802ec5a91 --- /dev/null +++ b/habitat_baselines/rl/ddppo/algo/ddppo_trainer.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
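+
+# Trainer for Decentralized Distributed PPO (DD-PPO). Each process runs its
+# own simulator and policy replica; gradients are synchronized through
+# torch.distributed (see DDPPO / DecentralizedDistributedMixin), and workers
+# that would otherwise straggle preempt their rollouts early (see
+# SHORT_ROLLOUT_THRESHOLD and RL.DDPPO.sync_frac below).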
+ +import contextlib +import os +import random +import time +from collections import deque +from typing import Dict, List + +import numpy as np +import torch +import torch.distributed as distrib +import torch.nn as nn +from gym import spaces +from gym.spaces.dict_space import Dict as SpaceDict +from torch.optim.lr_scheduler import LambdaLR + +from habitat import Config, logger +from habitat_baselines.common.baseline_registry import baseline_registry +from habitat_baselines.common.env_utils import construct_envs +from habitat_baselines.common.environments import get_env_class +from habitat_baselines.common.rollout_storage import RolloutStorage +from habitat_baselines.common.tensorboard_utils import TensorboardWriter +from habitat_baselines.common.utils import batch_obs, linear_decay +from habitat_baselines.rl.ddppo.algo.ddp_utils import ( + EXIT, + REQUEUE, + add_signal_handlers, + init_distrib_slurm, + load_interrupted_state, + requeue_job, + save_interrupted_state, +) +from habitat_baselines.rl.ddppo.algo.ddppo import DDPPO +from habitat_baselines.rl.ddppo.policy.resnet_policy import ( + PointNavResNetPolicy, +) +from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer + + +@baseline_registry.register_trainer(name="ddppo") +class DDPPOTrainer(PPOTrainer): + # DD-PPO cuts rollouts short to mitigate the straggler effect + # This, in theory, can cause some rollouts to be very short. + # All rollouts contributed equally to the loss/model-update, + # thus very short rollouts can be problematic. This threshold + # limits the how short a short rollout can be as a fraction of the + # max rollout length + SHORT_ROLLOUT_THRESHOLD: float = 0.25 + + def __init__(self, config=None): + interrupted_state = load_interrupted_state() + if interrupted_state is not None: + config = interrupted_state["config"] + + super().__init__(config) + + def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: + r"""Sets up actor critic and agent for DD-PPO. + + Args: + ppo_cfg: config node with relevant params + + Returns: + None + """ + logger.add_filehandler(self.config.LOG_FILE) + + self.actor_critic = PointNavResNetPolicy( + observation_space=self.envs.observation_spaces[0], + action_space=self.envs.action_spaces[0], + hidden_size=ppo_cfg.hidden_size, + rnn_type=self.config.RL.DDPPO.rnn_type, + num_recurrent_layers=self.config.RL.DDPPO.num_recurrent_layers, + backbone=self.config.RL.DDPPO.backbone, + goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID, + normalize_visual_inputs="rgb" + in self.envs.observation_spaces[0].spaces, + ) + self.actor_critic.to(self.device) + + if ( + self.config.RL.DDPPO.pretrained_encoder + or self.config.RL.DDPPO.pretrained + ): + pretrained_state = torch.load( + self.config.RL.DDPPO.pretrained_weights, map_location="cpu" + ) + + if self.config.RL.DDPPO.pretrained: + self.actor_critic.load_state_dict( + { + k[len("actor_critic.") :]: v + for k, v in pretrained_state["state_dict"].items() + } + ) + elif self.config.RL.DDPPO.pretrained_encoder: + prefix = "actor_critic.net.visual_encoder." 
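+                # Keep only the checkpoint entries belonging to the visual
+                # encoder and strip the prefix above so the remaining keys
+                # match the encoder's own state_dict.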
+ self.actor_critic.net.visual_encoder.load_state_dict( + { + k[len(prefix) :]: v + for k, v in pretrained_state["state_dict"].items() + if k.startswith(prefix) + } + ) + + if not self.config.RL.DDPPO.train_encoder: + self._static_encoder = True + for param in self.actor_critic.net.visual_encoder.parameters(): + param.requires_grad_(False) + + if self.config.RL.DDPPO.reset_critic: + nn.init.orthogonal_(self.actor_critic.critic.fc.weight) + nn.init.constant_(self.actor_critic.critic.fc.bias, 0) + + self.agent = DDPPO( + actor_critic=self.actor_critic, + clip_param=ppo_cfg.clip_param, + ppo_epoch=ppo_cfg.ppo_epoch, + num_mini_batch=ppo_cfg.num_mini_batch, + value_loss_coef=ppo_cfg.value_loss_coef, + entropy_coef=ppo_cfg.entropy_coef, + lr=ppo_cfg.lr, + eps=ppo_cfg.eps, + max_grad_norm=ppo_cfg.max_grad_norm, + use_normalized_advantage=ppo_cfg.use_normalized_advantage, + ) + + def train(self) -> None: + r"""Main method for DD-PPO. + + Returns: + None + """ + self.local_rank, tcp_store = init_distrib_slurm( + self.config.RL.DDPPO.distrib_backend + ) + add_signal_handlers() + + # Stores the number of workers that have finished their rollout + num_rollouts_done_store = distrib.PrefixStore( + "rollout_tracker", tcp_store + ) + num_rollouts_done_store.set("num_done", "0") + + self.world_rank = distrib.get_rank() + self.world_size = distrib.get_world_size() + + random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) + np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank) + + self.config.defrost() + self.config.TORCH_GPU_ID = self.local_rank + self.config.SIMULATOR_GPU_ID = self.local_rank + self.config.freeze() + + if torch.cuda.is_available(): + self.device = torch.device("cuda", self.local_rank) + torch.cuda.set_device(self.device) + else: + self.device = torch.device("cpu") + + self.envs = construct_envs( + self.config, get_env_class(self.config.ENV_NAME) + ) + + ppo_cfg = self.config.RL.PPO + if ( + not os.path.isdir(self.config.CHECKPOINT_FOLDER) + and self.world_rank == 0 + ): + os.makedirs(self.config.CHECKPOINT_FOLDER) + + self._setup_actor_critic_agent(ppo_cfg) + self.agent.init_distributed(find_unused_params=True) + + if self.world_rank == 0: + logger.info( + "agent number of trainable parameters: {}".format( + sum( + param.numel() + for param in self.agent.parameters() + if param.requires_grad + ) + ) + ) + + observations = self.envs.reset() + batch = batch_obs(observations) + + obs_space = self.envs.observation_spaces[0] + if self._static_encoder: + self._encoder = self.actor_critic.net.visual_encoder + obs_space = SpaceDict( + { + "visual_features": spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=self._encoder.output_shape, + dtype=np.float32, + ), + **obs_space.spaces, + } + ) + with torch.no_grad(): + batch["visual_features"] = self._encoder(batch) + + rollouts = RolloutStorage( + ppo_cfg.num_steps, + self.envs.num_envs, + obs_space, + self.envs.action_spaces[0], + ppo_cfg.hidden_size, + num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, + ) + rollouts.to(self.device) + + for sensor in rollouts.observations: + rollouts.observations[sensor][0].copy_(batch[sensor]) + + # batch and observations may contain shared PyTorch CUDA + # tensors. We must explicitly clear them here otherwise + # they will be kept in memory for the entire duration of training! 
+ batch = None + observations = None + + episode_rewards = torch.zeros( + self.envs.num_envs, 1, device=self.device + ) + episode_counts = torch.zeros(self.envs.num_envs, 1, device=self.device) + current_episode_reward = torch.zeros( + self.envs.num_envs, 1, device=self.device + ) + window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size) + window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size) + + t_start = time.time() + env_time = 0 + pth_time = 0 + count_steps = 0 + count_checkpoints = 0 + start_update = 0 + prev_time = 0 + + lr_scheduler = LambdaLR( + optimizer=self.agent.optimizer, + lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES), + ) + + interrupted_state = load_interrupted_state() + if interrupted_state is not None: + self.agent.load_state_dict(interrupted_state["state_dict"]) + self.agent.optimizer.load_state_dict( + interrupted_state["optim_state"] + ) + lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"]) + + requeue_stats = interrupted_state["requeue_stats"] + env_time = requeue_stats["env_time"] + pth_time = requeue_stats["pth_time"] + count_steps = requeue_stats["count_steps"] + count_checkpoints = requeue_stats["count_checkpoints"] + start_update = requeue_stats["start_update"] + prev_time = requeue_stats["prev_time"] + + with ( + TensorboardWriter( + self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs + ) + if self.world_rank == 0 + else contextlib.suppress() + ) as writer: + for update in range(start_update, self.config.NUM_UPDATES): + if ppo_cfg.use_linear_lr_decay: + lr_scheduler.step() + + if ppo_cfg.use_linear_clip_decay: + self.agent.clip_param = ppo_cfg.clip_param * linear_decay( + update, self.config.NUM_UPDATES + ) + + if EXIT.is_set(): + self.envs.close() + + if REQUEUE.is_set() and self.world_rank == 0: + requeue_stats = dict( + env_time=env_time, + pth_time=pth_time, + count_steps=count_steps, + count_checkpoints=count_checkpoints, + start_update=update, + prev_time=(time.time() - t_start) + prev_time, + ) + save_interrupted_state( + dict( + state_dict=self.agent.state_dict(), + optim_state=self.agent.optimizer.state_dict(), + lr_sched_state=lr_scheduler.state_dict(), + config=self.config, + requeue_stats=requeue_stats, + ) + ) + + requeue_job() + return + + count_steps_delta = 0 + self.agent.eval() + for step in range(ppo_cfg.num_steps): + + ( + delta_pth_time, + delta_env_time, + delta_steps, + ) = self._collect_rollout_step( + rollouts, + current_episode_reward, + episode_rewards, + episode_counts, + ) + pth_time += delta_pth_time + env_time += delta_env_time + count_steps_delta += delta_steps + + # This is where the preemption of workers happens. If a + # worker detects it will be a straggler, it preempts itself! 
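+                    # Only cut the rollout short once this worker has taken at
+                    # least SHORT_ROLLOUT_THRESHOLD of its steps and more than
+                    # sync_frac of all workers have already reported finishing
+                    # their rollouts via num_rollouts_done_store.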
+ if ( + step + >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD + ) and int(num_rollouts_done_store.get("num_done")) > ( + self.config.RL.DDPPO.sync_frac * self.world_size + ): + break + + num_rollouts_done_store.add("num_done", 1) + + self.agent.train() + if self._static_encoder: + self._encoder.eval() + + ( + delta_pth_time, + value_loss, + action_loss, + dist_entropy, + ) = self._update_agent(ppo_cfg, rollouts) + pth_time += delta_pth_time + + stats = torch.stack([episode_rewards, episode_counts], 0) + distrib.all_reduce(stats) + + window_episode_reward.append(stats[0].clone()) + window_episode_counts.append(stats[1].clone()) + + stats = torch.tensor( + [value_loss, action_loss, count_steps_delta], + device=self.device, + ) + distrib.all_reduce(stats) + count_steps += stats[2].item() + + if self.world_rank == 0: + num_rollouts_done_store.set("num_done", "0") + + losses = [ + stats[0].item() / self.world_size, + stats[1].item() / self.world_size, + ] + stats = zip( + ["count", "reward"], + [window_episode_counts, window_episode_reward], + ) + deltas = { + k: ( + (v[-1] - v[0]).sum().item() + if len(v) > 1 + else v[0].sum().item() + ) + for k, v in stats + } + deltas["count"] = max(deltas["count"], 1.0) + + writer.add_scalar( + "reward", + deltas["reward"] / deltas["count"], + count_steps, + ) + + writer.add_scalars( + "losses", + {k: l for l, k in zip(losses, ["value", "policy"])}, + count_steps, + ) + + # log stats + if update > 0 and update % self.config.LOG_INTERVAL == 0: + logger.info( + "update: {}\tfps: {:.3f}\t".format( + update, + count_steps + / ((time.time() - t_start) + prev_time), + ) + ) + + logger.info( + "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t" + "frames: {}".format( + update, env_time, pth_time, count_steps + ) + ) + + window_rewards = ( + window_episode_reward[-1] + - window_episode_reward[0] + ).sum() + window_counts = ( + window_episode_counts[-1] + - window_episode_counts[0] + ).sum() + + if window_counts > 0: + logger.info( + "Average window size {} reward: {:3f}".format( + len(window_episode_reward), + (window_rewards / window_counts).item(), + ) + ) + else: + logger.info("No episodes finish in current window") + + # checkpoint model + if update % self.config.CHECKPOINT_INTERVAL == 0: + self.save_checkpoint( + f"ckpt.{count_checkpoints}.pth", + dict(step=count_steps), + ) + count_checkpoints += 1 + + self.envs.close() diff --git a/habitat_baselines/rl/ddppo/multi_node_slurm.sh b/habitat_baselines/rl/ddppo/multi_node_slurm.sh new file mode 100644 index 0000000000000000000000000000000000000000..e2622e025a69bcd5813388db45bbeaa392deb5bf --- /dev/null +++ b/habitat_baselines/rl/ddppo/multi_node_slurm.sh @@ -0,0 +1,22 @@ +#!/bin/bash +#SBATCH --job-name=ddppo +#SBATCH --output=logs.ddppo.out +#SBATCH --error=logs.ddppo.err +#SBATCH --gres gpu:1 +#SBATCH --nodes 1 +#SBATCH --cpus-per-task 10 +#SBATCH --ntasks-per-node 1 +#SBATCH --mem=60GB +#SBATCH --time=12:00 +#SBATCH --signal=USR1@600 +#SBATCH --partition=dev + +export GLOG_minloglevel=2 +export MAGNUM_LOG=quiet + +export MASTER_ADDR=$(srun --ntasks=1 hostname 2>&1 | tail -n1) + +set -x +srun python -u -m habitat_baselines.run \ + --exp-config habitat_baselines/config/pointnav/ddppo_pointnav.yaml \ + --run-type train diff --git a/habitat_baselines/rl/ddppo/policy/__init__.py b/habitat_baselines/rl/ddppo/policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..48d7c9e4397f6645d17b582e10f291a6d9075289 --- /dev/null +++ b/habitat_baselines/rl/ddppo/policy/__init__.py @@ -0,0 
+1,8 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from .resnet_policy import PointNavResNetPolicy diff --git a/habitat_baselines/rl/ddppo/policy/resnet.py b/habitat_baselines/rl/ddppo/policy/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc894b6149c6ead63298ffac69e5f6ed83a722a --- /dev/null +++ b/habitat_baselines/rl/ddppo/policy/resnet.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1, groups=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=groups, + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + resneXt = False + + def __init__( + self, + inplanes, + planes, + ngroups, + stride=1, + downsample=None, + cardinality=1, + ): + super(BasicBlock, self).__init__() + self.convs = nn.Sequential( + conv3x3(inplanes, planes, stride, groups=cardinality), + nn.GroupNorm(ngroups, planes), + nn.ReLU(True), + conv3x3(planes, planes, groups=cardinality), + nn.GroupNorm(ngroups, planes), + ) + self.downsample = downsample + self.relu = nn.ReLU(True) + + def forward(self, x): + residual = x + + out = self.convs(x) + + if self.downsample is not None: + residual = self.downsample(x) + + return self.relu(out + residual) + + +def _build_bottleneck_branch( + inplanes, planes, ngroups, stride, expansion, groups=1 +): + return nn.Sequential( + conv1x1(inplanes, planes), + nn.GroupNorm(ngroups, planes), + nn.ReLU(True), + conv3x3(planes, planes, stride, groups=groups), + nn.GroupNorm(ngroups, planes), + nn.ReLU(True), + conv1x1(planes, planes * expansion), + nn.GroupNorm(ngroups, planes * expansion), + ) + + +class SE(nn.Module): + def __init__(self, planes, r=16): + super().__init__() + self.squeeze = nn.AdaptiveAvgPool2d(1) + self.excite = nn.Sequential( + nn.Linear(planes, int(planes / r)), + nn.ReLU(True), + nn.Linear(int(planes / r), planes), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + x = self.squeeze(x) + x = x.view(b, c) + x = self.excite(x) + + return x.view(b, c, 1, 1) + + +def _build_se_branch(planes, r=16): + return SE(planes, r) + + +class Bottleneck(nn.Module): + expansion = 4 + resneXt = False + + def __init__( + self, + inplanes, + planes, + ngroups, + stride=1, + downsample=None, + cardinality=1, + ): + super().__init__() + self.convs = _build_bottleneck_branch( + inplanes, + planes, + ngroups, + stride, + self.expansion, + groups=cardinality, + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def _impl(self, x): + identity = x + + out = self.convs(x) + + if self.downsample is not None: + identity = self.downsample(x) + + return self.relu(out + identity) + + def forward(self, x): + return self._impl(x) + + +class SEBottleneck(Bottleneck): + def __init__( + self, + inplanes, + planes, + ngroups, + stride=1, + downsample=None, + cardinality=1, + ): + super().__init__( + inplanes, planes, ngroups, stride, downsample, 
cardinality + ) + + self.se = _build_se_branch(planes * self.expansion) + + def _impl(self, x): + identity = x + + out = self.convs(x) + out = self.se(out) * out + + if self.downsample is not None: + identity = self.downsample(x) + + return self.relu(out + identity) + + +class SEResNeXtBottleneck(SEBottleneck): + expansion = 2 + resneXt = True + + +class ResNeXtBottleneck(Bottleneck): + expansion = 2 + resneXt = True + + +class ResNet(nn.Module): + def __init__( + self, in_channels, base_planes, ngroups, block, layers, cardinality=1 + ): + super(ResNet, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, + base_planes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + ), + nn.GroupNorm(ngroups, base_planes), + nn.ReLU(True), + ) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.cardinality = cardinality + + self.inplanes = base_planes + if block.resneXt: + base_planes *= 2 + + self.layer1 = self._make_layer(block, ngroups, base_planes, layers[0]) + self.layer2 = self._make_layer( + block, ngroups, base_planes * 2, layers[1], stride=2 + ) + self.layer3 = self._make_layer( + block, ngroups, base_planes * 2 * 2, layers[2], stride=2 + ) + self.layer4 = self._make_layer( + block, ngroups, base_planes * 2 * 2 * 2, layers[3], stride=2 + ) + + self.final_channels = self.inplanes + self.final_spatial_compress = 1.0 / (2 ** 5) + + def _make_layer(self, block, ngroups, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.GroupNorm(ngroups, planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + ngroups, + stride, + downsample, + cardinality=self.cardinality, + ) + ) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, ngroups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + return x + + +def resnet18(in_channels, base_planes, ngroups): + model = ResNet(in_channels, base_planes, ngroups, BasicBlock, [2, 2, 2, 2]) + + return model + + +def resnet50(in_channels, base_planes, ngroups): + model = ResNet(in_channels, base_planes, ngroups, Bottleneck, [3, 4, 6, 3]) + + return model + + +def resneXt50(in_channels, base_planes, ngroups): + model = ResNet( + in_channels, + base_planes, + ngroups, + ResNeXtBottleneck, + [3, 4, 6, 3], + cardinality=int(base_planes / 2), + ) + + return model + + +def se_resnet50(in_channels, base_planes, ngroups): + model = ResNet( + in_channels, base_planes, ngroups, SEBottleneck, [3, 4, 6, 3] + ) + + return model + + +def se_resneXt50(in_channels, base_planes, ngroups): + model = ResNet( + in_channels, + base_planes, + ngroups, + SEResNeXtBottleneck, + [3, 4, 6, 3], + cardinality=int(base_planes / 2), + ) + + return model + + +def se_resneXt101(in_channels, base_planes, ngroups): + model = ResNet( + in_channels, + base_planes, + ngroups, + SEResNeXtBottleneck, + [3, 4, 23, 3], + cardinality=int(base_planes / 2), + ) + + return model diff --git a/habitat_baselines/rl/ddppo/policy/resnet_policy.py b/habitat_baselines/rl/ddppo/policy/resnet_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..06db51c42a54ce87e0fa412fc868ca8e3ac8f30d --- /dev/null +++ b/habitat_baselines/rl/ddppo/policy/resnet_policy.py 
@@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from habitat_baselines.common.utils import CategoricalNet, Flatten +from habitat_baselines.rl.ddppo.policy import resnet +from habitat_baselines.rl.ddppo.policy.running_mean_and_var import ( + RunningMeanAndVar, +) +from habitat_baselines.rl.models.rnn_state_encoder import RNNStateEncoder +from habitat_baselines.rl.ppo import Net, Policy + + +class PointNavResNetPolicy(Policy): + def __init__( + self, + observation_space, + action_space, + goal_sensor_uuid="pointgoal_with_gps_compass", + hidden_size=512, + num_recurrent_layers=2, + rnn_type="LSTM", + resnet_baseplanes=32, + backbone="resnet50", + normalize_visual_inputs=False, + ): + super().__init__( + PointNavResNetNet( + observation_space=observation_space, + action_space=action_space, + goal_sensor_uuid=goal_sensor_uuid, + hidden_size=hidden_size, + num_recurrent_layers=num_recurrent_layers, + rnn_type=rnn_type, + backbone=backbone, + resnet_baseplanes=resnet_baseplanes, + normalize_visual_inputs=normalize_visual_inputs, + ), + action_space.n, + ) + + +class ResNetEncoder(nn.Module): + def __init__( + self, + observation_space, + baseplanes=32, + ngroups=32, + spatial_size=128, + make_backbone=None, + normalize_visual_inputs=False, + ): + super().__init__() + + if "rgb" in observation_space.spaces: + self._n_input_rgb = observation_space.spaces["rgb"].shape[2] + spatial_size = observation_space.spaces["rgb"].shape[0] // 2 + else: + self._n_input_rgb = 0 + + if "depth" in observation_space.spaces: + self._n_input_depth = observation_space.spaces["depth"].shape[2] + spatial_size = observation_space.spaces["depth"].shape[0] // 2 + else: + self._n_input_depth = 0 + + if normalize_visual_inputs: + self.running_mean_and_var = RunningMeanAndVar( + self._n_input_depth + self._n_input_rgb + ) + else: + self.running_mean_and_var = nn.Sequential() + + if not self.is_blind: + input_channels = self._n_input_depth + self._n_input_rgb + self.backbone = make_backbone(input_channels, baseplanes, ngroups) + + final_spatial = int( + spatial_size * self.backbone.final_spatial_compress + ) + after_compression_flat_size = 2048 + num_compression_channels = int( + round(after_compression_flat_size / (final_spatial ** 2)) + ) + self.compression = nn.Sequential( + nn.Conv2d( + self.backbone.final_channels, + num_compression_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.GroupNorm(1, num_compression_channels), + nn.ReLU(True), + ) + + self.output_shape = ( + num_compression_channels, + final_spatial, + final_spatial, + ) + + @property + def is_blind(self): + return self._n_input_rgb + self._n_input_depth == 0 + + def layer_init(self): + for layer in self.modules(): + if isinstance(layer, (nn.Conv2d, nn.Linear)): + nn.init.kaiming_normal_( + layer.weight, nn.init.calculate_gain("relu") + ) + if layer.bias is not None: + nn.init.constant_(layer.bias, val=0) + + def forward(self, observations): + if self.is_blind: + return None + + cnn_input = [] + if self._n_input_rgb > 0: + rgb_observations = observations["rgb"] + # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] + rgb_observations = rgb_observations.permute(0, 3, 1, 2) + rgb_observations = rgb_observations / 255.0 # normalize RGB + cnn_input.append(rgb_observations) + + 
if self._n_input_depth > 0: + depth_observations = observations["depth"] + + # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] + depth_observations = depth_observations.permute(0, 3, 1, 2) + + cnn_input.append(depth_observations) + + x = torch.cat(cnn_input, dim=1) + x = F.avg_pool2d(x, 2) + + x = self.running_mean_and_var(x) + x = self.backbone(x) + x = self.compression(x) + return x + + +class PointNavResNetNet(Net): + """Network which passes the input image through CNN and concatenates + goal vector with CNN's output and passes that through RNN. + """ + + def __init__( + self, + observation_space, + action_space, + goal_sensor_uuid, + hidden_size, + num_recurrent_layers, + rnn_type, + backbone, + resnet_baseplanes, + normalize_visual_inputs, + ): + super().__init__() + self.goal_sensor_uuid = goal_sensor_uuid + + self.prev_action_embedding = nn.Embedding(action_space.n + 1, 32) + self._n_prev_action = 32 + + self._n_input_goal = ( + observation_space.spaces[self.goal_sensor_uuid].shape[0] + 1 + ) + self.tgt_embeding = nn.Linear(self._n_input_goal, 32) + self._n_input_goal = 32 + + self._hidden_size = hidden_size + + rnn_input_size = self._n_input_goal + self._n_prev_action + self.visual_encoder = ResNetEncoder( + observation_space, + baseplanes=resnet_baseplanes, + ngroups=resnet_baseplanes // 2, + make_backbone=getattr(resnet, backbone), + normalize_visual_inputs=normalize_visual_inputs, + ) + + if not self.visual_encoder.is_blind: + self.visual_fc = nn.Sequential( + Flatten(), + nn.Linear( + np.prod(self.visual_encoder.output_shape), hidden_size + ), + nn.ReLU(True), + ) + + self.state_encoder = RNNStateEncoder( + (0 if self.is_blind else self._hidden_size) + rnn_input_size, + self._hidden_size, + rnn_type=rnn_type, + num_layers=num_recurrent_layers, + ) + + self.train() + + @property + def output_size(self): + return self._hidden_size + + @property + def is_blind(self): + return self.visual_encoder.is_blind + + @property + def num_recurrent_layers(self): + return self.state_encoder.num_recurrent_layers + + def get_tgt_encoding(self, observations): + goal_observations = observations[self.goal_sensor_uuid] + goal_observations = torch.stack( + [ + goal_observations[:, 0], + torch.cos(-goal_observations[:, 1]), + torch.sin(-goal_observations[:, 1]), + ], + -1, + ) + + return self.tgt_embeding(goal_observations) + + def forward(self, observations, rnn_hidden_states, prev_actions, masks): + x = [] + if not self.is_blind: + if "visual_features" in observations: + visual_feats = observations["visual_features"] + else: + visual_feats = self.visual_encoder(observations) + + visual_feats = self.visual_fc(visual_feats) + x.append(visual_feats) + + tgt_encoding = self.get_tgt_encoding(observations) + prev_actions = self.prev_action_embedding( + ((prev_actions.float() + 1) * masks).long().squeeze(-1) + ) + + x += [tgt_encoding, prev_actions] + + x = torch.cat(x, dim=1) + x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks) + + return x, rnn_hidden_states diff --git a/habitat_baselines/rl/ddppo/policy/running_mean_and_var.py b/habitat_baselines/rl/ddppo/policy/running_mean_and_var.py new file mode 100644 index 0000000000000000000000000000000000000000..f03533be25376547ed215ce27b52fdf927678c3b --- /dev/null +++ b/habitat_baselines/rl/ddppo/policy/running_mean_and_var.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. 
+# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as distrib +import torch.nn as nn +import torch.nn.functional as F + + +class RunningMeanAndVar(nn.Module): + def __init__(self, n_channels): + super().__init__() + self.register_buffer("_mean", torch.zeros(1, n_channels, 1, 1)) + self.register_buffer("_var", torch.zeros(1, n_channels, 1, 1)) + self.register_buffer("_count", torch.zeros(())) + + self._distributed = distrib.is_initialized() + + def forward(self, x): + if self.training: + new_mean = F.adaptive_avg_pool2d(x, 1).sum(0, keepdim=True) + new_count = torch.full_like(self._count, x.size(0)) + + if self._distributed: + distrib.all_reduce(new_mean) + distrib.all_reduce(new_count) + + new_mean /= new_count + + new_var = F.adaptive_avg_pool2d((x - new_mean).pow(2), 1).sum( + 0, keepdim=True + ) + + if self._distributed: + distrib.all_reduce(new_var) + + # No - 1 on all the variance as the number of pixels + # seen over training is simply absurd, so it doesn't matter + new_var /= new_count + + m_a = self._var * (self._count) + m_b = new_var * (new_count) + M2 = ( + m_a + + m_b + + (new_mean - self._mean).pow(2) + * self._count + * new_count + / (self._count + new_count) + ) + + self._var = M2 / (self._count + new_count) + self._mean = (self._count * self._mean + new_count * new_mean) / ( + self._count + new_count + ) + + self._count += new_count + + stdev = torch.sqrt( + torch.max(self._var, torch.full_like(self._var, 1e-2)) + ) + return (x - self._mean) / stdev diff --git a/habitat_baselines/rl/ddppo/requirements.txt b/habitat_baselines/rl/ddppo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c9b5857317384d8830b1ff11a285d0a3253e9a0 --- /dev/null +++ b/habitat_baselines/rl/ddppo/requirements.txt @@ -0,0 +1 @@ +ifcfg diff --git a/habitat_baselines/rl/ddppo/single_node.sh b/habitat_baselines/rl/ddppo/single_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..7aba0afc87d70015c6595791a3be7337f653a2e6 --- /dev/null +++ b/habitat_baselines/rl/ddppo/single_node.sh @@ -0,0 +1,12 @@ +#/bin/bash + +export GLOG_minloglevel=2 +export MAGNUM_LOG=quiet + +set -x +python -u -m torch.distributed.launch \ + --use_env \ + --nproc_per_node 1 \ + habitat_baselines/run.py \ + --exp-config habitat_baselines/config/pointnav/ddppo_pointnav.yaml \ + --run-type train diff --git a/habitat_baselines/rl/ppo/__init__.py b/habitat_baselines/rl/ppo/__init__.py index febc3fd73942a8a403503ff26bda8a7a4f30d260..641d742d4a8f65f07efda88affb4254a025ce6d2 100644 --- a/habitat_baselines/rl/ppo/__init__.py +++ b/habitat_baselines/rl/ppo/__init__.py @@ -7,4 +7,4 @@ from habitat_baselines.rl.ppo.policy import Net, PointNavBaselinePolicy, Policy from habitat_baselines.rl.ppo.ppo import PPO -__all__ = ["PPO", "Policy", "Net", "PointNavBaselinePolicy"] +__all__ = ["PPO", "Policy", "RolloutStorage", "Net", "PointNavBaselinePolicy"] diff --git a/habitat_baselines/rl/ppo/ppo.py b/habitat_baselines/rl/ppo/ppo.py index 85bfd9e621d62375ee58155068283c9914d95bb4..c5f74266a4e5a6b526d7a5d6b2099e2b7d382a68 100644 --- a/habitat_baselines/rl/ppo/ppo.py +++ b/habitat_baselines/rl/ppo/ppo.py @@ -41,7 +41,11 @@ class PPO(nn.Module): self.max_grad_norm = max_grad_norm self.use_clipped_value_loss = use_clipped_value_loss - self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) + self.optimizer = optim.Adam( + list(filter(lambda p: p.requires_grad, 
actor_critic.parameters())), + lr=lr, + eps=eps, + ) self.device = next(actor_critic.parameters()).device self.use_normalized_advantage = use_normalized_advantage diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py index c9bf606de11033e5a8f4bd0516240d9fb57ca2b3..76e6d3b0ecfdffe52ccac48d36a29076b12a794b 100644 --- a/habitat_baselines/rl/ppo/ppo_trainer.py +++ b/habitat_baselines/rl/ppo/ppo_trainer.py @@ -7,7 +7,7 @@ import os import time from collections import deque -from typing import Dict, List +from typing import Dict, List, Optional import numpy as np import torch @@ -44,6 +44,9 @@ class PPOTrainer(BaseRLTrainer): if config is not None: logger.info(f"config: {config}") + self._static_encoder = False + self._encoder = None + def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: r"""Sets up actor critic and agent for PPO. @@ -73,9 +76,12 @@ class PPOTrainer(BaseRLTrainer): lr=ppo_cfg.lr, eps=ppo_cfg.eps, max_grad_norm=ppo_cfg.max_grad_norm, + use_normalized_advantage=ppo_cfg.use_normalized_advantage, ) - def save_checkpoint(self, file_name: str) -> None: + def save_checkpoint( + self, file_name: str, extra_state: Optional[Dict] = None + ) -> None: r"""Save checkpoint with specified name. Args: @@ -88,6 +94,9 @@ class PPOTrainer(BaseRLTrainer): "state_dict": self.agent.state_dict(), "config": self.config, } + if extra_state is not None: + checkpoint["extra_state"] = extra_state + torch.save( checkpoint, os.path.join(self.config.CHECKPOINT_FOLDER, file_name) ) @@ -141,11 +150,15 @@ class PPOTrainer(BaseRLTrainer): t_update_stats = time.time() batch = batch_obs(observations) - rewards = torch.tensor(rewards, dtype=torch.float) + rewards = torch.tensor( + rewards, dtype=torch.float, device=episode_rewards.device + ) rewards = rewards.unsqueeze(1) masks = torch.tensor( - [[0.0] if done else [1.0] for done in dones], dtype=torch.float + [[0.0] if done else [1.0] for done in dones], + dtype=torch.float, + device=episode_rewards.device, ) current_episode_reward += rewards @@ -153,6 +166,10 @@ class PPOTrainer(BaseRLTrainer): episode_counts += 1 - masks current_episode_reward *= masks + if self._static_encoder: + with torch.no_grad(): + batch["visual_features"] = self._encoder(batch) + rollouts.insert( batch, recurrent_hidden_states, @@ -171,13 +188,13 @@ class PPOTrainer(BaseRLTrainer): t_update_model = time.time() with torch.no_grad(): last_observation = { - k: v[-1] for k, v in rollouts.observations.items() + k: v[rollouts.step] for k, v in rollouts.observations.items() } next_value = self.actor_critic.get_value( last_observation, - rollouts.recurrent_hidden_states[-1], - rollouts.prev_actions[-1], - rollouts.masks[-1], + rollouts.recurrent_hidden_states[rollouts.step], + rollouts.prev_actions[rollouts.step], + rollouts.masks[rollouts.step], ).detach() rollouts.compute_returns( @@ -356,7 +373,9 @@ class PPOTrainer(BaseRLTrainer): # checkpoint model if update % self.config.CHECKPOINT_INTERVAL == 0: - self.save_checkpoint(f"ckpt.{count_checkpoints}.pth") + self.save_checkpoint( + f"ckpt.{count_checkpoints}.pth", dict(step=count_steps) + ) count_checkpoints += 1 self.envs.close() @@ -442,6 +461,7 @@ class PPOTrainer(BaseRLTrainer): if len(self.config.VIDEO_OPTION) > 0: os.makedirs(self.config.VIDEO_DIR, exist_ok=True) + self.actor_critic.eval() while ( len(stats_episodes) < self.config.TEST_EPISODE_COUNT and self.envs.num_envs > 0 @@ -565,20 +585,20 @@ class PPOTrainer(BaseRLTrainer): f"Average episode {self.metric_uuid}: 
{episode_metric_mean:.6f}" ) + step_id = checkpoint_index + if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]: + step_id = ckpt_dict["extra_state"]["step"] + writer.add_scalars( - "eval_reward", - {"average reward": episode_reward_mean}, - checkpoint_index, + "eval_reward", {"average reward": episode_reward_mean}, step_id ) writer.add_scalars( f"eval_{self.metric_uuid}", {f"average {self.metric_uuid}": episode_metric_mean}, - checkpoint_index, + step_id, ) writer.add_scalars( - "eval_success", - {"average success": episode_success_mean}, - checkpoint_index, + "eval_success", {"average success": episode_success_mean}, step_id ) self.envs.close() diff --git a/habitat_baselines/rl/requirements.txt b/habitat_baselines/rl/requirements.txt index 3c0a2cfc40e26fec7a6809da834938864cfb8442..c79d56b5e269e082b4ee3a8573819024b2100340 100644 --- a/habitat_baselines/rl/requirements.txt +++ b/habitat_baselines/rl/requirements.txt @@ -1,5 +1,5 @@ moviepy>=1.0.1 -torch==1.1.0 +torch>=1.3.1 # full tensorflow required for tensorboard video support tensorflow==1.13.1 tb-nightly diff --git a/test/test_baseline_trainers.py b/test/test_baseline_trainers.py index 4e1bcfbe3ce0f80b36207fac178188f87349bdc5..b86b958c7c2fe5555fd2f0bdd131ea0d94d1d890 100644 --- a/test/test_baseline_trainers.py +++ b/test/test_baseline_trainers.py @@ -12,6 +12,7 @@ import pytest try: import torch + import torch.distributed from habitat_baselines.run import run_exp from habitat_baselines.common.base_trainer import BaseRLTrainer @@ -49,6 +50,10 @@ def test_trainers(test_cfg_path, mode, gpu2gpu): ["TASK_CONFIG.SIMULATOR.HABITAT_SIM_V0.GPU_GPU", str(gpu2gpu)], ) + # Deinit processes group + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + @pytest.mark.skipif( not baseline_installed, reason="baseline sub-module not installed"