Commit 31318f81 authored by JasonJiazhiZhang, committed by Abhishek Kadian

Refactor baselines: add generic trainer and common utils (#153)

* refactor and add generic trainer class

* fix to pass tests

* change BaseModel to BaseTrainer

* fix tensorboard import causing CI failure

* modify circle-ci test script accordingly

* doc, typing and other changes

* rename BASELINE to TRAINER

* merge from upstream master

* Update Habitat-API to allow for no rendering sensors (#139)

Update Habitat-API to allow for no rendering sensors

* Added installation requirements step for sim installation in CI setup. (#159)

Added installation requirements step for sim installation in CI setup

* move RolloutStorage to utils

* add environments.py

* make ckpt dir if not exist

* small fixes according to comments

* changes according to comments

* update readme

* fix old config compatibility

* fix missed isort lint
parent 7eb3da12
Showing changes with 1252 additions and 701 deletions
@@ -196,7 +196,7 @@ jobs:
       . activate habitat; cd habitat-api
       python setup.py test
       python setup.py develop --all
-      python -u habitat_baselines/train_ppo.py --log-file "train.log" --checkpoint-folder "data/checkpoints" --sim-gpu-id 0 --pth-gpu-id 0 --num-processes 1 --num-mini-batch 1 --num-updates 10
+      python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo_train_test.yaml --run-type train
 workflows:
-BASELINE:
+TRAINER:
   RL:
     SUCCESS_REWARD: 10.0
     SLACK_REWARD: -0.01
@@ -25,46 +25,18 @@ For training on sample data please follow steps in the repository README. You sh
 **train**:
 ```bash
-python -u habitat_baselines/train_ppo.py \
-    --use-gae \
-    --sim-gpu-id 0 \
-    --pth-gpu-id 0 \
-    --lr 2.5e-4 \
-    --clip-param 0.1 \
-    --value-loss-coef 0.5 \
-    --num-processes 4 \
-    --num-steps 128 \
-    --num-mini-batch 4 \
-    --num-updates 100000 \
-    --use-linear-lr-decay \
-    --use-linear-clip-decay \
-    --entropy-coef 0.01 \
-    --log-file "train.log" \
-    --log-interval 5 \
-    --checkpoint-folder "data/checkpoints" \
-    --checkpoint-interval 50 \
-    --task-config "configs/tasks/pointnav.yaml" \
+python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo.yaml --run-type train
 ```

 **test**:
 ```bash
-python -u habitat_baselines/evaluate_ppo.py \
-    --model-path "/path/to/checkpoint" \
-    --sim-gpu-id 0 \
-    --pth-gpu-id 0 \
-    --num-processes 4 \
-    --count-test-episodes 100 \
-    --task-config "configs/tasks/pointnav.yaml" \
+python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo.yaml --run-type eval
 ```

 We also provide trained RGB, RGBD, Blind PPO models.
 To use them download pre-trained pytorch models from [link](https://dl.fbaipublicfiles.com/habitat/data/baselines/v1/habitat_baselines_v1.zip) and unzip and specify model path [here](agents/ppo_agents.py#L132).

-Set argument `--task-config` to `tasks/pointnav_mp3d.yaml` for training on [MatterPort3D point goal navigation dataset](/README.md#task-datasets).
+Change the `task_config` field in `habitat_baselines/config/pointnav/ppo.yaml` to `tasks/pointnav_mp3d.yaml` for training on the [MatterPort3D point goal navigation dataset](/README.md#task-datasets).

 ### Classic

@@ -74,26 +46,16 @@ Set argument `--task-config` to `tasks/pointnav_mp3d.yaml` for training on [Matt
 "Benchmarking Classic and Learned Navigation in Complex 3D Environments".

 ### Additional Utilities

-**single-episode training**:
-Algorithms can be trained with a single-episode option. This option can be used as a sanity check since good algorithms should overfit one episode relatively fast. To enable this option, add `DATASET.NUM_EPISODE_SAMPLE 1` *at the end* of the training command, or include the single-episode yaml file in `--task-config` like this:
-```
---task-config "configs/tasks/pointnav.yaml,configs/datasets/single_episode.yaml"
-```
+**Episode iterator options**:
+Coming very soon.

-**tensorboard and video generation support**
+**Tensorboard and video generation support**

-Enable tensorboard logging by adding `--tensorboard-dir logdir` when running `train_ppo.py` or `evaluate_ppo.py`
+Enable tensorboard by changing the `tensorboard_dir` field in `habitat_baselines/config/pointnav/ppo.yaml`.

-Enable video generation for `evaluate_ppo.py` using `--video-option`: specifying `tensorboard` (for displaying on tensorboard) or `disk` (for saving videos on disk), for example:
-```
-python -u habitat_baselines/evaluate_ppo.py
-...
---count-test-episodes 2 \
---video-option tensorboard \
---tensorboard-dir tb_eval \
---model-path data/checkpoints/ckpt.xx.pth
-```
-The above command should generate navigation episode recordings and display them on tensorboard like this:
+Enable video generation for `eval` mode by changing `video_option`: `tensorboard,disk` (for displaying on tensorboard and for saving videos on disk, respectively).
+
+Generated navigation episode recordings should look like this on tensorboard:

 <p align="center">
   <img src="../res/img/tensorboard_video_demo.gif" height="500">
 </p>
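For readers tracing how these commands are wired up, below is a rough, hypothetical Python sketch of what the new `habitat_baselines/run.py` entry point presumably does with `--exp-config` and `--run-type`. The function name `run_experiment`, the argument handling, and the `habitat_baselines.config.default` module path are assumptions; `get_config`, `baseline_registry.get_trainer`, and the trainer's `train()`/`eval()` methods are introduced elsewhere in this commit.

```python
# Hypothetical sketch of the run.py flow (not the actual file contents of this commit).
import habitat_baselines  # noqa: F401  # importing the package registers PPOTrainer

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config  # assumed module path


def run_experiment(exp_config: str, run_type: str) -> None:
    # Merge the experiment yaml; the task config gets nested under config.TASK_CONFIG.
    config = get_config(exp_config)
    # Look up the trainer registered under TRAINER_NAME (e.g. "ppo" -> PPOTrainer).
    trainer_cls = baseline_registry.get_trainer(config.TRAINER.TRAINER_NAME)
    trainer = trainer_cls(config)
    if run_type == "train":
        trainer.train()
    else:
        trainer.eval()


if __name__ == "__main__":
    run_experiment("habitat_baselines/config/pointnav/ppo.yaml", run_type="train")
```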
@@ -3,3 +3,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
+from habitat_baselines.common.base_trainer import BaseRLTrainer, BaseTrainer
+from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer, RolloutStorage
+
+__all__ = ["BaseTrainer", "BaseRLTrainer", "PPOTrainer", "RolloutStorage"]
@@ -16,8 +16,8 @@ import habitat
 from habitat.config import Config
 from habitat.config.default import get_config
 from habitat.core.agent import Agent
+from habitat_baselines.common.utils import batch_obs
 from habitat_baselines.rl.ppo import Policy
-from habitat_baselines.rl.ppo.utils import batch_obs


 def get_default_config():
@@ -70,16 +70,16 @@ def make_good_config_for_orbslam2(config):
     config.SIMULATOR.RGB_SENSOR.HEIGHT = 256
     config.SIMULATOR.DEPTH_SENSOR.WIDTH = 256
     config.SIMULATOR.DEPTH_SENSOR.HEIGHT = 256
-    config.BASELINE.ORBSLAM2.CAMERA_HEIGHT = config.SIMULATOR.DEPTH_SENSOR.POSITION[
+    config.TRAINER.ORBSLAM2.CAMERA_HEIGHT = config.SIMULATOR.DEPTH_SENSOR.POSITION[
        1
    ]
-    config.BASELINE.ORBSLAM2.H_OBSTACLE_MIN = (
-        0.3 * config.BASELINE.ORBSLAM2.CAMERA_HEIGHT
+    config.TRAINER.ORBSLAM2.H_OBSTACLE_MIN = (
+        0.3 * config.TRAINER.ORBSLAM2.CAMERA_HEIGHT
     )
-    config.BASELINE.ORBSLAM2.H_OBSTACLE_MAX = (
-        1.0 * config.BASELINE.ORBSLAM2.CAMERA_HEIGHT
+    config.TRAINER.ORBSLAM2.H_OBSTACLE_MAX = (
+        1.0 * config.TRAINER.ORBSLAM2.CAMERA_HEIGHT
     )
-    config.BASELINE.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
+    config.TRAINER.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
         config.SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
     )
     return

@@ -607,11 +607,11 @@ def main():
     make_good_config_for_orbslam2(config)

     if args.agent_type == "blind":
-        agent = BlindAgent(config.BASELINE.ORBSLAM2)
+        agent = BlindAgent(config.TRAINER.ORBSLAM2)
     elif args.agent_type == "orbslam2-rgbd":
-        agent = ORBSLAM2Agent(config.BASELINE.ORBSLAM2)
+        agent = ORBSLAM2Agent(config.TRAINER.ORBSLAM2)
     elif args.agent_type == "orbslam2-rgb-monod":
-        agent = ORBSLAM2MonodepthAgent(config.BASELINE.ORBSLAM2)
+        agent = ORBSLAM2MonodepthAgent(config.TRAINER.ORBSLAM2)
     else:
         raise ValueError(args.agent_type, "is unknown type of agent")

     benchmark = habitat.Benchmark(args.task_config)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import ClassVar, Dict, List
class BaseTrainer:
    """
    Most generic trainer class that serves as a base template for more
    specific trainer classes like RL trainer, SLAM or imitation learner.
    Includes only the most basic functionality.
    """

    supported_tasks: ClassVar[List[str]]

    def train(self) -> None:
        raise NotImplementedError

    def eval(self) -> None:
        raise NotImplementedError

    def save_checkpoint(self, file_name) -> None:
        raise NotImplementedError

    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
        raise NotImplementedError


class BaseRLTrainer(BaseTrainer):
    """
    Base trainer class for RL-based trainers. Future RL-specific
    methods should be hosted here.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def train(self) -> None:
        raise NotImplementedError

    def eval(self) -> None:
        raise NotImplementedError

    def save_checkpoint(self, file_name) -> None:
        raise NotImplementedError

    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
        raise NotImplementedError
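To make the interface concrete, here is a toy trainer built on the base class above. It is purely illustrative and not part of this commit.

```python
# Illustrative only: a no-op trainer that satisfies the BaseTrainer interface.
from typing import ClassVar, Dict, List

from habitat_baselines.common.base_trainer import BaseTrainer


class NoOpTrainer(BaseTrainer):
    supported_tasks: ClassVar[List[str]] = ["Nav-v0"]

    def train(self) -> None:
        print("nothing to train")

    def eval(self) -> None:
        print("nothing to evaluate")

    def save_checkpoint(self, file_name) -> None:
        pass  # nothing to persist

    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
        return {}
```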
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""BaselineRegistry is extended from habitat.Registry to provide
registration for trainer and environments, while keeping Registry
in habitat core intact.
Import the baseline registry object using
``from habitat_baselines.common.baseline_registry import baseline_registry``
Various decorators for registry different kind of classes with unique keys
- Register a environment: ``@registry.register_env``
- Register a trainer: ``@registry.register_trainer``
"""
from typing import Optional
from habitat.core.registry import Registry
class BaselineRegistry(Registry):
    @classmethod
    def register_trainer(cls, to_register=None, *, name: Optional[str] = None):
        r"""Register an RL training algorithm to the registry with key 'name'.

        Args:
            name: Key with which the trainer will be registered.
                If None will use the name of the class.
        """
        from habitat_baselines.common.base_trainer import BaseTrainer

        return cls._register_impl(
            "trainer", to_register, name, assert_type=BaseTrainer
        )

    @classmethod
    def get_trainer(cls, name):
        return cls._get_impl("trainer", name)

    @classmethod
    def register_env(cls, to_register=None, *, name: Optional[str] = None):
        r"""Register an environment to the registry with key 'name';
        currently only subclasses of RLEnv are supported.

        Args:
            name: Key with which the env will be registered.
                If None will use the name of the class.
        """
        from habitat import RLEnv

        return cls._register_impl("env", to_register, name, assert_type=RLEnv)

    @classmethod
    def get_env(cls, name):
        return cls._get_impl("env", name)
baseline_registry = BaselineRegistry()
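A short usage sketch of the registry, matching the decorators described in the module docstring. `DummyTrainer` is a made-up example class, not part of this commit.

```python
# Illustrative only: register a trainer under a key and fetch it back.
from habitat_baselines.common.base_trainer import BaseRLTrainer
from habitat_baselines.common.baseline_registry import baseline_registry


@baseline_registry.register_trainer(name="dummy")
class DummyTrainer(BaseRLTrainer):
    def train(self) -> None:
        print("training with config:", type(self.config).__name__)

    def eval(self) -> None:
        pass


trainer_cls = baseline_registry.get_trainer("dummy")
assert trainer_cls is DummyTrainer

# Environments registered with @baseline_registry.register_env (for example
# NavRLEnv in environments.py) are looked up the same way, provided the module
# that registers them has been imported:
# nav_env_cls = baseline_registry.get_env("NavRLEnv")
```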
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Type
import habitat
from habitat import Config, Env, VectorEnv, make_dataset
def make_env_fn(
task_config: Config, rl_env_config: Config, env_class: Type, rank: int
) -> Env:
r"""
Creates an env of type env_class with specified config and rank.
This is to be passed in as an argument when creating VectorEnv.
Args:
task_config: task config file for creating env.
rl_env_config: RL env config for creating env.
env_class: class type of the env to be created.
rank: rank of env to be created (for seeding).
Returns:
env object created according to specification.
"""
dataset = make_dataset(
task_config.DATASET.TYPE, config=task_config.DATASET
)
env = env_class(
config_env=task_config, config_baseline=rl_env_config, dataset=dataset
)
env.seed(rank)
return env
def construct_envs(config: Config, env_class: Type) -> VectorEnv:
r"""
Create VectorEnv object with specified config and env class type.
For better performance, the dataset is split into smaller subsets, one for
each individual env, grouped by scenes.
Args:
config: configs that contain num_processes as well as information
necessary to create individual environments.
env_class: class type of the envs to be created.
Returns:
VectorEnv object created according to specification.
"""
trainer_config = config.TRAINER.RL.PPO
rl_env_config = config.TRAINER.RL
task_config = config.TASK_CONFIG # excluding trainer-specific configs
env_configs, rl_env_configs = [], []
env_classes = [env_class for _ in range(trainer_config.num_processes)]
dataset = make_dataset(task_config.DATASET.TYPE)
scenes = dataset.get_scenes_to_load(task_config.DATASET)
if len(scenes) > 0:
random.shuffle(scenes)
assert len(scenes) >= trainer_config.num_processes, (
"reduce the number of processes as there "
"aren't enough number of scenes"
)
scene_splits = [[] for _ in range(trainer_config.num_processes)]
for idx, scene in enumerate(scenes):
scene_splits[idx % len(scene_splits)].append(scene)
assert sum(map(len, scene_splits)) == len(scenes)
for i in range(trainer_config.num_processes):
env_config = task_config.clone()
env_config.defrost()
if len(scenes) > 0:
env_config.DATASET.CONTENT_SCENES = scene_splits[i]
env_config.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = (
trainer_config.sim_gpu_id
)
agent_sensors = trainer_config.sensors.strip().split(",")
env_config.SIMULATOR.AGENT_0.SENSORS = agent_sensors
env_config.freeze()
env_configs.append(env_config)
rl_env_configs.append(rl_env_config)
envs = habitat.VectorEnv(
make_env_fn=make_env_fn,
env_fn_args=tuple(
tuple(
zip(
env_configs,
rl_env_configs,
env_classes,
range(trainer_config.num_processes),
)
)
),
)
return envs
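As a usage sketch, this is roughly how the trainer drives `construct_envs` (compare `PPOTrainer.train` later in this commit). The config path is only an example value and `habitat_baselines.config.default` is the assumed location of the experiment-level `get_config`.

```python
# Sketch: build a vectorized set of NavRLEnv instances from an experiment config.
from habitat_baselines.common.env_utils import construct_envs
from habitat_baselines.common.environments import NavRLEnv  # also registers NavRLEnv
from habitat_baselines.config.default import get_config  # assumed module path

config = get_config("habitat_baselines/config/pointnav/ppo.yaml")
envs = construct_envs(config, NavRLEnv)  # one env per TRAINER.RL.PPO.num_processes
observations = envs.reset()  # list with one observation dict per env
envs.close()
```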
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
r"""
This file hosts task-specific or trainer-specific environments for trainers.
All environments here should be a (direct or indirect) subclass of the habitat
Env class. Customized environments should be registered using
``@baseline_registry.register_env(name="myEnv")`` for reusability.
"""
import habitat
from habitat import SimulatorActions
from habitat_baselines.common.baseline_registry import baseline_registry
@baseline_registry.register_env(name="NavRLEnv")
class NavRLEnv(habitat.RLEnv):
def __init__(self, config_env, config_baseline, dataset):
self._config_env = config_env.TASK
self._config_baseline = config_baseline
self._previous_target_distance = None
self._previous_action = None
self._episode_distance_covered = None
super().__init__(config_env, dataset)
def reset(self):
self._previous_action = None
observations = super().reset()
self._previous_target_distance = self.habitat_env.current_episode.info[
"geodesic_distance"
]
return observations
def step(self, action):
self._previous_action = action
return super().step(action)
def get_reward_range(self):
return (
self._config_baseline.SLACK_REWARD - 1.0,
self._config_baseline.SUCCESS_REWARD + 1.0,
)
def get_reward(self, observations):
reward = self._config_baseline.SLACK_REWARD
current_target_distance = self._distance_target()
reward += self._previous_target_distance - current_target_distance
self._previous_target_distance = current_target_distance
if self._episode_success():
reward += self._config_baseline.SUCCESS_REWARD
return reward
def _distance_target(self):
current_position = self._env.sim.get_agent_state().position.tolist()
target_position = self._env.current_episode.goals[0].position
distance = self._env.sim.geodesic_distance(
current_position, target_position
)
return distance
def _episode_success(self):
if (
self._previous_action == SimulatorActions.STOP
and self._distance_target() < self._config_env.SUCCESS_DISTANCE
):
return True
return False
def get_done(self, observations):
done = False
if self._env.episode_over or self._episode_success():
done = True
return done
def get_info(self, observations):
return self.habitat_env.get_metrics()
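Per the module docstring, additional environments would presumably be registered the same way. The class below is a toy example with made-up reward and done logic, not part of the commit.

```python
# Toy example of a registered environment; reward/done logic is illustrative only.
import habitat
from habitat_baselines.common.baseline_registry import baseline_registry


@baseline_registry.register_env(name="DummyRLEnv")
class DummyRLEnv(habitat.RLEnv):
    def get_reward_range(self):
        return (0.0, 1.0)

    def get_reward(self, observations):
        return 1.0 if self._env.episode_over else 0.0

    def get_done(self, observations):
        return self._env.episode_over

    def get_info(self, observations):
        return self.habitat_env.get_metrics()
```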
@@ -4,63 +4,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import argparse
 from collections import defaultdict

-import numpy as np
 import torch
-import torch.nn as nn

-class Flatten(nn.Module):
-    def forward(self, x):
-        return x.view(x.size(0), -1)

-class CustomFixedCategorical(torch.distributions.Categorical):
-    def sample(self, sample_shape=torch.Size()):
-        return super().sample(sample_shape).unsqueeze(-1)

-    def log_probs(self, actions):
-        return (
-            super()
-            .log_prob(actions.squeeze(-1))
-            .view(actions.size(0), -1)
-            .sum(-1)
-            .unsqueeze(-1)
-        )

-    def mode(self):
-        return self.probs.argmax(dim=-1, keepdim=True)

-class CategoricalNet(nn.Module):
-    def __init__(self, num_inputs, num_outputs):
-        super().__init__()
-        self.linear = nn.Linear(num_inputs, num_outputs)
-        nn.init.orthogonal_(self.linear.weight, gain=0.01)
-        nn.init.constant_(self.linear.bias, 0)

-    def forward(self, x):
-        x = self.linear(x)
-        return CustomFixedCategorical(logits=x)

-def _flatten_helper(t, n, tensor):
-    return tensor.view(t * n, *tensor.size()[2:])

-def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
-    r"""Decreases the learning rate linearly
-    """
-    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr

-class RolloutStorage:
+class RolloutStorage:
+    r"""
+    Class for storing rollout information for RL trainers
+    """
+
     def __init__(
         self,
         num_steps,
@@ -75,7 +28,7 @@ class RolloutStorage:
             self.observations[sensor] = torch.zeros(
                 num_steps + 1,
                 num_envs,
-                *observation_space.spaces[sensor].shape
+                *observation_space.spaces[sensor].shape,
             )

         self.recurrent_hidden_states = torch.zeros(
@@ -168,9 +121,9 @@ class RolloutStorage:
     def recurrent_generator(self, advantages, num_mini_batch):
         num_processes = self.rewards.size(1)
         assert num_processes >= num_mini_batch, (
-            "PPO requires the number of processes ({}) "
+            "Trainer requires the number of processes ({}) "
             "to be greater than or equal to the number of "
-            "PPO mini batches ({}).".format(num_processes, num_mini_batch)
+            "trainer mini batches ({}).".format(num_processes, num_mini_batch)
         )
         num_envs_per_batch = num_processes // num_mini_batch
         perm = torch.randperm(num_processes)
@@ -231,18 +184,18 @@ class RolloutStorage:
             # Flatten the (T, N, ...) tensors to (T * N, ...)
             for sensor in observations_batch:
-                observations_batch[sensor] = _flatten_helper(
+                observations_batch[sensor] = self._flatten_helper(
                     T, N, observations_batch[sensor]
                 )

-            actions_batch = _flatten_helper(T, N, actions_batch)
-            value_preds_batch = _flatten_helper(T, N, value_preds_batch)
-            return_batch = _flatten_helper(T, N, return_batch)
-            masks_batch = _flatten_helper(T, N, masks_batch)
-            old_action_log_probs_batch = _flatten_helper(
+            actions_batch = self._flatten_helper(T, N, actions_batch)
+            value_preds_batch = self._flatten_helper(T, N, value_preds_batch)
+            return_batch = self._flatten_helper(T, N, return_batch)
+            masks_batch = self._flatten_helper(T, N, masks_batch)
+            old_action_log_probs_batch = self._flatten_helper(
                 T, N, old_action_log_probs_batch
             )
-            adv_targ = _flatten_helper(T, N, adv_targ)
+            adv_targ = self._flatten_helper(T, N, adv_targ)

             yield (
                 observations_batch,
@@ -255,176 +208,16 @@ class RolloutStorage:
                 adv_targ,
             )

-def batch_obs(observations):
-    batch = defaultdict(list)
-
-    for obs in observations:
-        for sensor in obs:
-            batch[sensor].append(obs[sensor])
-
-    for sensor in batch:
-        batch[sensor] = torch.tensor(
-            np.array(batch[sensor]), dtype=torch.float
-        )
-
-    return batch
+    @staticmethod
+    def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Given a tensor of size (t, n, ..), flatten it to size (t*n, ...).
+
+        Args:
+            t: first dimension of tensor.
+            n: second dimension of tensor.
+            tensor: target tensor to be flattened.
+
+        Returns:
+            flattened tensor of size (t*n, ...)
+        """
+        return tensor.view(t * n, *tensor.size()[2:])
def ppo_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--clip-param",
type=float,
default=0.2,
help="ppo clip parameter (default: 0.2)",
)
parser.add_argument(
"--ppo-epoch",
type=int,
default=4,
help="number of ppo epochs (default: 4)",
)
parser.add_argument(
"--num-mini-batch",
type=int,
default=32,
help="number of batches for ppo (default: 32)",
)
parser.add_argument(
"--value-loss-coef",
type=float,
default=0.5,
help="value loss coefficient (default: 0.5)",
)
parser.add_argument(
"--entropy-coef",
type=float,
default=0.01,
help="entropy term coefficient (default: 0.01)",
)
parser.add_argument(
"--lr", type=float, default=7e-4, help="learning rate (default: 7e-4)"
)
parser.add_argument(
"--eps",
type=float,
default=1e-5,
help="RMSprop optimizer epsilon (default: 1e-5)",
)
parser.add_argument(
"--max-grad-norm",
type=float,
default=0.5,
help="max norm of gradients (default: 0.5)",
)
parser.add_argument(
"--num-steps",
type=int,
default=5,
help="number of forward steps in A2C (default: 5)",
)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument(
"--num-processes",
type=int,
default=16,
help="number of training processes " "to use (default: 16)",
)
parser.add_argument(
"--use-gae",
action="store_true",
default=False,
help="use generalized advantage estimation",
)
parser.add_argument(
"--use-linear-lr-decay",
action="store_true",
default=False,
help="use a linear schedule on the learning rate",
)
parser.add_argument(
"--use-linear-clip-decay",
action="store_true",
default=False,
help="use a linear schedule on the " "ppo clipping parameter",
)
parser.add_argument(
"--gamma",
type=float,
default=0.99,
help="discount factor for rewards (default: 0.99)",
)
parser.add_argument(
"--tau", type=float, default=0.95, help="gae parameter (default: 0.95)"
)
parser.add_argument(
"--log-file", type=str, required=True, help="path for log file"
)
parser.add_argument(
"--reward-window-size",
type=int,
default=50,
help="logging window for rewards",
)
parser.add_argument(
"--log-interval",
type=int,
default=1,
help="number of updates after which metrics are logged",
)
parser.add_argument(
"--checkpoint-interval",
type=int,
default=50,
help="number of updates after which models are checkpointed",
)
parser.add_argument(
"--checkpoint-folder",
type=str,
required=True,
help="folder for storing checkpoints",
)
parser.add_argument(
"--sim-gpu-id",
type=int,
required=True,
help="gpu id on which scenes are loaded",
)
parser.add_argument(
"--pth-gpu-id",
type=int,
required=True,
help="gpu id on which pytorch runs",
)
parser.add_argument(
"--num-updates",
type=int,
default=10000,
help="number of PPO updates to run",
)
parser.add_argument(
"--sensors",
type=str,
default="RGB_SENSOR,DEPTH_SENSOR",
help="comma separated string containing different sensors to use,"
"currently 'RGB_SENSOR' and 'DEPTH_SENSOR' are supported",
)
parser.add_argument(
"--task-config",
type=str,
default="configs/tasks/pointnav.yaml",
help="path to config yaml containing information about task",
)
parser.add_argument("--seed", type=int, default=100)
parser.add_argument(
"opts",
default=None,
nargs=argparse.REMAINDER,
help="Modify config options from command line",
)
parser.add_argument(
"--tensorboard-dir",
type=str,
help="path to tensorboard logging directory",
)
return parser
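Returning to the `RolloutStorage` class relocated in this commit (imported later as `habitat_baselines.common.rollout_storage`), a toy instantiation is sketched below. The observation and action spaces and all sizes are made up, and arguments are passed positionally to avoid assuming parameter names beyond those visible in the diff.

```python
# Toy RolloutStorage construction mirroring its use in PPOTrainer.train.
from gym import spaces

from habitat_baselines.common.rollout_storage import RolloutStorage

obs_space = spaces.Dict(
    {"pointgoal": spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype="float32")}
)
action_space = spaces.Discrete(4)

# num_steps=8, num_envs=2, observation space, action space, hidden size 512
rollouts = RolloutStorage(8, 2, obs_space, action_space, 512)
print(rollouts.observations["pointgoal"].shape)  # torch.Size([9, 2, 2])
```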
-from typing import Optional, Union
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union

 import numpy as np
 import torch
-from torch.utils.tensorboard import SummaryWriter

+# TODO Add checks to replace DummyWriter
+class DummyWriter:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def close(self):
+        pass
+
+    def __getattr__(self, item):
+        return lambda *args, **kwargs: None
+
+
+try:
+    from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+    SummaryWriter = DummyWriter

 class TensorboardWriter(SummaryWriter):

@@ -33,23 +62,6 @@ class TensorboardWriter(SummaryWriter):
         self.add_video(video_name, video_tensor, fps=fps, global_step=step_idx)

-class DummyWriter:
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    def close(self):
-        pass
-
-    def __getattr__(self, item):
-        return lambda *args, **kwargs: None

 def get_tensorboard_writer(
     log_dir: str, *args, **kwargs
 ) -> Union[DummyWriter, TensorboardWriter]:

@@ -62,7 +74,7 @@ def get_tensorboard_writer(
         **kwargs: additional keyword args.

     Returns:
-        Either the created tensorboard writer or a dummy writer.
+        either the created tensorboard writer or a dummy writer.
     """
     if log_dir:
         return TensorboardWriter(log_dir, *args, **kwargs)
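Usage sketch for the writer helpers above, as `PPOTrainer` uses them. The assumption (based on the truncated `get_tensorboard_writer` body) is that an empty `log_dir` falls back to `DummyWriter`, whose methods are no-ops.

```python
# Sketch: the same logging code works whether or not tensorboard logging is enabled.
from habitat_baselines.common.tensorboard_utils import get_tensorboard_writer

# With a real log dir this wraps torch.utils.tensorboard.SummaryWriter.
with get_tensorboard_writer(log_dir="tb", purge_step=0, flush_secs=30) as writer:
    writer.add_scalar("reward", 0.0, 0)

# With an empty log dir a DummyWriter is returned (assumed), so calls are no-ops.
with get_tensorboard_writer(log_dir="", purge_step=0, flush_secs=30) as writer:
    writer.add_scalar("reward", 0.0, 0)
```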
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
import os
from collections import defaultdict
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from habitat import Config
from habitat.utils.visualizations.utils import images_to_video
from habitat_baselines import BaseTrainer
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensorboard_utils import (
DummyWriter,
TensorboardWriter,
)
# TODO distribute utilities in this file to separate files
def get_trainer(trainer_name: str, trainer_cfg: Config) -> BaseTrainer:
r"""
Create specific trainer instance according to name.
Args:
trainer_name: name of the registered trainer.
trainer_cfg: config file for trainer.
Returns:
an instance of the specified trainer.
"""
trainer = baseline_registry.get_trainer(trainer_name)
assert trainer is not None, f"{trainer_name} is not supported"
return trainer(trainer_cfg)
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class CustomFixedCategorical(torch.distributions.Categorical):
def sample(self, sample_shape=torch.Size()):
return super().sample(sample_shape).unsqueeze(-1)
def log_probs(self, actions):
return (
super()
.log_prob(actions.squeeze(-1))
.view(actions.size(0), -1)
.sum(-1)
.unsqueeze(-1)
)
def mode(self):
return self.probs.argmax(dim=-1, keepdim=True)
class CategoricalNet(nn.Module):
def __init__(self, num_inputs, num_outputs):
super().__init__()
self.linear = nn.Linear(num_inputs, num_outputs)
nn.init.orthogonal_(self.linear.weight, gain=0.01)
nn.init.constant_(self.linear.bias, 0)
def forward(self, x):
x = self.linear(x)
return CustomFixedCategorical(logits=x)
# TODO make this a LRScheduler class
def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
r"""Decreases the learning rate linearly
"""
lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
for param_group in optimizer.param_groups:
param_group["lr"] = lr
def batch_obs(observations: List[Dict]) -> Dict:
r"""
Transpose a batch of observation dicts to a dict of batched
observations.
Args:
observations: list of dicts of observations.
Returns:
transposed dict of batched observations (as float tensors).
"""
batch = defaultdict(list)
for obs in observations:
for sensor in obs:
batch[sensor].append(obs[sensor])
for sensor in batch:
batch[sensor] = torch.tensor(
np.array(batch[sensor]), dtype=torch.float
)
return batch
def poll_checkpoint_folder(
checkpoint_folder: str, previous_ckpt_ind: int
) -> Optional[str]:
r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder
(sorted by time of last modification).
Args:
checkpoint_folder: directory to look for checkpoints.
previous_ckpt_ind: index of checkpoint last returned.
Returns:
return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found
else return None.
"""
assert os.path.isdir(checkpoint_folder), "invalid checkpoint folder path"
models_paths = list(
filter(os.path.isfile, glob.glob(checkpoint_folder + "/*"))
)
models_paths.sort(key=os.path.getmtime)
ind = previous_ckpt_ind + 1
if ind < len(models_paths):
return models_paths[ind]
return None
def generate_video(
config: Config,
images: List[np.ndarray],
episode_id: int,
checkpoint_idx: int,
spl: float,
tb_writer: Union[DummyWriter, TensorboardWriter],
fps: int = 10,
) -> None:
r"""
Generate video according to specified information.
Args:
config: config object that contains video_option and video_dir.
images: list of images to be converted to video.
episode_id: episode id for video naming.
checkpoint_idx: checkpoint index for video naming.
spl: SPL for this episode for video naming.
tb_writer: tensorboard writer object for uploading video
fps: fps for generated video
Returns:
None
"""
if config.video_option and len(images) > 0:
video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}"
if "disk" in config.video_option:
images_to_video(images, config.video_dir, video_name)
if "tensorboard" in config.video_option:
tb_writer.add_video_from_np_images(
f"episode{episode_id}", checkpoint_idx, images, fps=fps
)
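A small worked example of `batch_obs` from the utilities above; the sensor names and array shapes are made up.

```python
# batch_obs transposes a list of per-env observation dicts into a dict of
# float tensors batched along the first dimension.
import numpy as np

from habitat_baselines.common.utils import batch_obs

observations = [
    {"rgb": np.zeros((4, 4, 3), dtype=np.uint8), "pointgoal": np.array([1.0, 0.5])},
    {"rgb": np.ones((4, 4, 3), dtype=np.uint8), "pointgoal": np.array([0.2, 0.1])},
]

batch = batch_obs(observations)
print(batch["rgb"].shape)        # torch.Size([2, 4, 4, 3])
print(batch["pointgoal"].dtype)  # torch.float32
```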
@@ -8,7 +8,7 @@ from typing import List, Optional, Union
 import numpy as np

-from habitat import get_config
+from habitat import get_config as get_task_config
 from habitat.config import Config as CN

 DEFAULT_CONFIG_DIR = "configs/"

@@ -17,49 +17,88 @@ CONFIG_FILE_SEPARATOR = ","
 # Config definition
 # -----------------------------------------------------------------------------
 _C = CN()
-_C.SEED = 100
+_C.BASE_TASK_CONFIG_PATH = "configs/tasks/pointnav.yaml"
+_C.TASK_CONFIG = CN()  # task_config will be stored as a config node
+_C.CMD_TRAILING_OPTS = ""  # store command line options
 # -----------------------------------------------------------------------------
-# BASELINE
+# TRAINER ALGORITHMS
 # -----------------------------------------------------------------------------
-_C.BASELINE = CN()
+_C.TRAINER = CN()
+_C.TRAINER.TRAINER_NAME = "ppo"
 # -----------------------------------------------------------------------------
 # REINFORCEMENT LEARNING (RL)
 # -----------------------------------------------------------------------------
-_C.BASELINE.RL = CN()
-_C.BASELINE.RL.SUCCESS_REWARD = 10.0
-_C.BASELINE.RL.SLACK_REWARD = -0.01
+_C.TRAINER.RL = CN()
+_C.TRAINER.RL.SUCCESS_REWARD = 10.0
+_C.TRAINER.RL.SLACK_REWARD = -0.01
# -----------------------------------------------------------------------------
# PROXIMAL POLICY OPTIMIZATION (PPO)
# -----------------------------------------------------------------------------
# TODO move general options out of PPO
_C.TRAINER.RL.PPO = CN()
_C.TRAINER.RL.PPO.clip_param = 0.2
_C.TRAINER.RL.PPO.ppo_epoch = 4
_C.TRAINER.RL.PPO.num_mini_batch = 16
_C.TRAINER.RL.PPO.value_loss_coef = 0.5
_C.TRAINER.RL.PPO.entropy_coef = 0.01
_C.TRAINER.RL.PPO.lr = 7e-4
_C.TRAINER.RL.PPO.eps = 1e-5
_C.TRAINER.RL.PPO.max_grad_norm = 0.5
_C.TRAINER.RL.PPO.num_steps = 5
_C.TRAINER.RL.PPO.hidden_size = 512
_C.TRAINER.RL.PPO.num_processes = 16
_C.TRAINER.RL.PPO.use_gae = True
_C.TRAINER.RL.PPO.use_linear_lr_decay = False
_C.TRAINER.RL.PPO.use_linear_clip_decay = False
_C.TRAINER.RL.PPO.gamma = 0.99
_C.TRAINER.RL.PPO.tau = 0.95
_C.TRAINER.RL.PPO.log_file = "train.log"
_C.TRAINER.RL.PPO.reward_window_size = 50
_C.TRAINER.RL.PPO.log_interval = 50
_C.TRAINER.RL.PPO.checkpoint_interval = 50
_C.TRAINER.RL.PPO.checkpoint_folder = "data/checkpoints"
_C.TRAINER.RL.PPO.sim_gpu_id = 0
_C.TRAINER.RL.PPO.pth_gpu_id = 0
_C.TRAINER.RL.PPO.num_updates = 10000
_C.TRAINER.RL.PPO.sensors = "RGB_SENSOR,DEPTH_SENSOR"
_C.TRAINER.RL.PPO.task_config = "configs/tasks/pointnav.yaml"
_C.TRAINER.RL.PPO.tensorboard_dir = "tb"
_C.TRAINER.RL.PPO.count_test_episodes = 2
_C.TRAINER.RL.PPO.video_option = "disk,tensorboard"
_C.TRAINER.RL.PPO.video_dir = "video_dir"
_C.TRAINER.RL.PPO.eval_ckpt_path_or_dir = "data/checkpoints"
 # -----------------------------------------------------------------------------
 # ORBSLAM2 BASELINE
 # -----------------------------------------------------------------------------
-_C.BASELINE.ORBSLAM2 = CN()
-_C.BASELINE.ORBSLAM2.SLAM_VOCAB_PATH = (
+_C.TRAINER.ORBSLAM2 = CN()
+_C.TRAINER.ORBSLAM2.SLAM_VOCAB_PATH = (
     "habitat_baselines/slambased/data/ORBvoc.txt"
 )
-_C.BASELINE.ORBSLAM2.SLAM_SETTINGS_PATH = (
+_C.TRAINER.ORBSLAM2.SLAM_SETTINGS_PATH = (
     "habitat_baselines/slambased/data/mp3d3_small1k.yaml"
 )
-_C.BASELINE.ORBSLAM2.MAP_CELL_SIZE = 0.1
-_C.BASELINE.ORBSLAM2.MAP_SIZE = 40
-_C.BASELINE.ORBSLAM2.CAMERA_HEIGHT = get_config().SIMULATOR.DEPTH_SENSOR.POSITION[
+_C.TRAINER.ORBSLAM2.MAP_CELL_SIZE = 0.1
+_C.TRAINER.ORBSLAM2.MAP_SIZE = 40
+_C.TRAINER.ORBSLAM2.CAMERA_HEIGHT = get_task_config().SIMULATOR.DEPTH_SENSOR.POSITION[
     1
 ]
-_C.BASELINE.ORBSLAM2.BETA = 100
-_C.BASELINE.ORBSLAM2.H_OBSTACLE_MIN = 0.3 * _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT
-_C.BASELINE.ORBSLAM2.H_OBSTACLE_MAX = 1.0 * _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT
-_C.BASELINE.ORBSLAM2.D_OBSTACLE_MIN = 0.1
-_C.BASELINE.ORBSLAM2.D_OBSTACLE_MAX = 4.0
-_C.BASELINE.ORBSLAM2.PREPROCESS_MAP = True
-_C.BASELINE.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
-    get_config().SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
+_C.TRAINER.ORBSLAM2.BETA = 100
+_C.TRAINER.ORBSLAM2.H_OBSTACLE_MIN = 0.3 * _C.TRAINER.ORBSLAM2.CAMERA_HEIGHT
+_C.TRAINER.ORBSLAM2.H_OBSTACLE_MAX = 1.0 * _C.TRAINER.ORBSLAM2.CAMERA_HEIGHT
+_C.TRAINER.ORBSLAM2.D_OBSTACLE_MIN = 0.1
+_C.TRAINER.ORBSLAM2.D_OBSTACLE_MAX = 4.0
+_C.TRAINER.ORBSLAM2.PREPROCESS_MAP = True
+_C.TRAINER.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
+    get_task_config().SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
 )
-_C.BASELINE.ORBSLAM2.ANGLE_TH = float(np.deg2rad(15))
-_C.BASELINE.ORBSLAM2.DIST_REACHED_TH = 0.15
-_C.BASELINE.ORBSLAM2.NEXT_WAYPOINT_TH = 0.5
-_C.BASELINE.ORBSLAM2.NUM_ACTIONS = 3
-_C.BASELINE.ORBSLAM2.DIST_TO_STOP = 0.05
-_C.BASELINE.ORBSLAM2.PLANNER_MAX_STEPS = 500
-_C.BASELINE.ORBSLAM2.DEPTH_DENORM = (
-    get_config().SIMULATOR.DEPTH_SENSOR.MAX_DEPTH
+_C.TRAINER.ORBSLAM2.ANGLE_TH = float(np.deg2rad(15))
+_C.TRAINER.ORBSLAM2.DIST_REACHED_TH = 0.15
+_C.TRAINER.ORBSLAM2.NEXT_WAYPOINT_TH = 0.5
+_C.TRAINER.ORBSLAM2.NUM_ACTIONS = 3
+_C.TRAINER.ORBSLAM2.DIST_TO_STOP = 0.05
+_C.TRAINER.ORBSLAM2.PLANNER_MAX_STEPS = 500
+_C.TRAINER.ORBSLAM2.DEPTH_DENORM = (
+    get_task_config().SIMULATOR.DEPTH_SENSOR.MAX_DEPTH
 )
@@ -87,6 +126,8 @@ def get_config(
     for config_path in config_paths:
         config.merge_from_file(config_path)

+    config.TASK_CONFIG = get_task_config(config.BASE_TASK_CONFIG_PATH)
+    config.CMD_TRAILING_OPTS = opts
     if opts:
         config.merge_from_list(opts)
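To illustrate the new two-level config (experiment config plus nested task config), here is a hedged usage sketch. The module path and the exact `get_config` signature (`config_paths`, `opts`) are assumed from the surrounding diff.

```python
# Sketch of the new config flow: the experiment yaml selects the trainer, the
# task config referenced by BASE_TASK_CONFIG_PATH is nested under TASK_CONFIG,
# and trailing opts override experiment-level fields via merge_from_list.
from habitat_baselines.config.default import get_config  # assumed module path

config = get_config(
    config_paths="habitat_baselines/config/pointnav/ppo.yaml",
    opts=["TRAINER.RL.PPO.num_updates", "100"],
)

print(config.TRAINER.TRAINER_NAME)        # "ppo"
print(config.TRAINER.RL.PPO.num_updates)  # 100 (overridden from opts)
print(config.TASK_CONFIG.DATASET.SPLIT)   # task-level fields live under TASK_CONFIG
```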
BASE_TASK_CONFIG_PATH: "configs/tasks/pointnav.yaml"
TRAINER:
TRAINER_NAME: "ppo"
RL:
PPO:
# general options
tensorboard_dir: "tb"
num_processes: 1
log_interval: 10
pth_gpu_id: 0
sim_gpu_id: 0
checkpoint_interval: 50
checkpoint_folder: "data/checkpoints"
task_config: "configs/tasks/pointnav.yaml"
num_updates: 10000
# eval specific:
count_test_episodes: 2
video_option: "disk,tensorboard"
video_dir: "video_dir"
eval_ckpt_path_or_dir: "data/checkpoints"
# ppo params
clip_param: 0.1
ppo_epoch: 4
num_mini_batch: 1
value_loss_coef: 0.5
entropy_coef: 0.01
lr: 2.5e-4
eps: 1e-5
max_grad_norm: 0.5
num_steps: 128
hidden_size: 512
use_gae: True
gamma: 0.99
tau: 0.95
use_linear_clip_decay: True
use_linear_lr_decay: True
reward_window_size: 50
sensors: "RGB_SENSOR,DEPTH_SENSOR"
TRAINER:
TRAINER_NAME: "ppo"
RL:
PPO:
num_processes: 1
pth_gpu_id: 0
sim_gpu_id: 0
checkpoint_interval: 50
checkpoint_folder: "data/checkpoints"
task_config: "configs/tasks/pointnav.yaml"
num_updates: 10
num_mini_batch: 1
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
import os
import time
from typing import Optional
import torch
import habitat
from config.default import get_config as cfg_baseline
from habitat import logger
from habitat.config.default import get_config
from habitat.utils.visualizations.utils import (
images_to_video,
observations_to_image,
)
from rl.ppo import PPO, Policy
from rl.ppo.utils import batch_obs
from tensorboard_utils import get_tensorboard_writer
from train_ppo import make_env_fn
def poll_checkpoint_folder(
checkpoint_folder: str, previous_ckpt_ind: int
) -> Optional[str]:
r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder
(sorted by time of last modification).
Args:
checkpoint_folder: directory to look for checkpoints.
previous_ckpt_ind: index of checkpoint last returned.
Returns:
return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found
else return None.
"""
assert os.path.isdir(checkpoint_folder), "invalid checkpoint folder path"
models_paths = list(
filter(os.path.isfile, glob.glob(checkpoint_folder + "/*"))
)
models_paths.sort(key=os.path.getmtime)
ind = previous_ckpt_ind + 1
if ind < len(models_paths):
return models_paths[ind]
return None
def generate_video(
args, images, episode_id, checkpoint_idx, spl, tb_writer, fps=10
) -> None:
r"""Generate video according to specified information.
Args:
args: contains args.video_option and args.video_dir.
images: list of images to be converted to video.
episode_id: episode id for video naming.
checkpoint_idx: checkpoint index for video naming.
spl: SPL for this episode for video naming.
tb_writer: tensorboard writer object for uploading video
fps: fps for generated video
Returns:
None
"""
if args.video_option and len(images) > 0:
video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}"
if "disk" in args.video_option:
images_to_video(images, args.video_dir, video_name)
if "tensorboard" in args.video_option:
tb_writer.add_video_from_np_images(
f"episode{episode_id}", checkpoint_idx, images, fps=fps
)
def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
env_configs = []
baseline_configs = []
device = torch.device("cuda", args.pth_gpu_id)
for _ in range(args.num_processes):
config_env = get_config(config_paths=args.task_config)
config_env.defrost()
config_env.DATASET.SPLIT = "val"
agent_sensors = args.sensors.strip().split(",")
for sensor in agent_sensors:
assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
if args.video_option:
config_env.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
config_env.TASK.MEASUREMENTS.append("COLLISIONS")
config_env.freeze()
env_configs.append(config_env)
config_baseline = cfg_baseline()
baseline_configs.append(config_baseline)
assert len(baseline_configs) > 0, "empty list of datasets"
envs = habitat.VectorEnv(
make_env_fn=make_env_fn,
env_fn_args=tuple(
tuple(
zip(env_configs, baseline_configs, range(args.num_processes))
)
),
)
ckpt = torch.load(checkpoint_path, map_location=device)
actor_critic = Policy(
observation_space=envs.observation_spaces[0],
action_space=envs.action_spaces[0],
hidden_size=512,
goal_sensor_uuid=env_configs[0].TASK.GOAL_SENSOR_UUID,
)
actor_critic.to(device)
ppo = PPO(
actor_critic=actor_critic,
clip_param=0.1,
ppo_epoch=4,
num_mini_batch=32,
value_loss_coef=0.5,
entropy_coef=0.01,
lr=2.5e-4,
eps=1e-5,
max_grad_norm=0.5,
)
ppo.load_state_dict(ckpt["state_dict"])
actor_critic = ppo.actor_critic
observations = envs.reset()
batch = batch_obs(observations)
for sensor in batch:
batch[sensor] = batch[sensor].to(device)
current_episode_reward = torch.zeros(envs.num_envs, 1, device=device)
test_recurrent_hidden_states = torch.zeros(
args.num_processes, args.hidden_size, device=device
)
not_done_masks = torch.zeros(args.num_processes, 1, device=device)
stats_episodes = dict() # dict of dicts that stores stats per episode
rgb_frames = None
if args.video_option:
rgb_frames = [[]] * args.num_processes
os.makedirs(args.video_dir, exist_ok=True)
while len(stats_episodes) < args.count_test_episodes and envs.num_envs > 0:
current_episodes = envs.current_episodes()
with torch.no_grad():
_, actions, _, test_recurrent_hidden_states = actor_critic.act(
batch,
test_recurrent_hidden_states,
not_done_masks,
deterministic=False,
)
outputs = envs.step([a[0].item() for a in actions])
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
batch = batch_obs(observations)
for sensor in batch:
batch[sensor] = batch[sensor].to(device)
not_done_masks = torch.tensor(
[[0.0] if done else [1.0] for done in dones],
dtype=torch.float,
device=device,
)
rewards = torch.tensor(
rewards, dtype=torch.float, device=device
).unsqueeze(1)
current_episode_reward += rewards
next_episodes = envs.current_episodes()
envs_to_pause = []
n_envs = envs.num_envs
for i in range(n_envs):
if (
next_episodes[i].scene_id,
next_episodes[i].episode_id,
) in stats_episodes:
envs_to_pause.append(i)
# episode ended
if not_done_masks[i].item() == 0:
episode_stats = dict()
episode_stats["spl"] = infos[i]["spl"]
episode_stats["success"] = int(infos[i]["spl"] > 0)
episode_stats["reward"] = current_episode_reward[i].item()
current_episode_reward[i] = 0
# use scene_id + episode_id as unique id for storing stats
stats_episodes[
(
current_episodes[i].scene_id,
current_episodes[i].episode_id,
)
] = episode_stats
if args.video_option:
generate_video(
args,
rgb_frames[i],
current_episodes[i].episode_id,
cur_ckpt_idx,
infos[i]["spl"],
writer,
)
rgb_frames[i] = []
# episode continues
elif args.video_option:
frame = observations_to_image(observations[i], infos[i])
rgb_frames[i].append(frame)
# pausing envs with no new episode
if len(envs_to_pause) > 0:
state_index = list(range(envs.num_envs))
for idx in reversed(envs_to_pause):
state_index.pop(idx)
envs.pause_at(idx)
# indexing along the batch dimensions
test_recurrent_hidden_states = test_recurrent_hidden_states[
state_index
]
not_done_masks = not_done_masks[state_index]
current_episode_reward = current_episode_reward[state_index]
for k, v in batch.items():
batch[k] = v[state_index]
if args.video_option:
rgb_frames = [rgb_frames[i] for i in state_index]
aggregated_stats = dict()
for stat_key in next(iter(stats_episodes.values())).keys():
aggregated_stats[stat_key] = sum(
[v[stat_key] for v in stats_episodes.values()]
)
num_episodes = len(stats_episodes)
episode_reward_mean = aggregated_stats["reward"] / num_episodes
episode_spl_mean = aggregated_stats["spl"] / num_episodes
episode_success_mean = aggregated_stats["success"] / num_episodes
logger.info("Average episode reward: {:.6f}".format(episode_reward_mean))
logger.info("Average episode success: {:.6f}".format(episode_success_mean))
logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))
writer.add_scalars(
"eval_reward", {"average reward": episode_reward_mean}, cur_ckpt_idx
)
writer.add_scalars(
"eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx
)
writer.add_scalars(
"eval_success", {"average success": episode_success_mean}, cur_ckpt_idx
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str)
parser.add_argument("--tracking-model-dir", type=str)
parser.add_argument("--sim-gpu-id", type=int, required=True)
parser.add_argument("--pth-gpu-id", type=int, required=True)
parser.add_argument("--num-processes", type=int, required=True)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--count-test-episodes", type=int, default=100)
parser.add_argument(
"--sensors",
type=str,
default="RGB_SENSOR,DEPTH_SENSOR",
help="comma separated string containing different"
"sensors to use, currently 'RGB_SENSOR' and"
"'DEPTH_SENSOR' are supported",
)
parser.add_argument(
"--task-config",
type=str,
default="configs/tasks/pointnav.yaml",
help="path to config yaml containing information about task",
)
parser.add_argument(
"--video-option",
type=str,
default="",
choices=["tensorboard", "disk"],
nargs="*",
help="Options for video output, leave empty for no video. "
"Videos can be saved to disk, uploaded to tensorboard, or both.",
)
parser.add_argument(
"--video-dir", type=str, help="directory for storing videos"
)
parser.add_argument(
"--tensorboard-dir",
type=str,
help="directory for storing tensorboard statistics",
)
args = parser.parse_args()
assert (args.model_path is not None) != (
args.tracking_model_dir is not None
), "Must specify a single model or a directory of models, but not both"
if "tensorboard" in args.video_option:
assert (
args.tensorboard_dir is not None
), "Must specify a tensorboard directory for video display"
if "disk" in args.video_option:
assert (
args.video_dir is not None
), "Must specify a directory for storing videos on disk"
with get_tensorboard_writer(
args.tensorboard_dir, purge_step=0, flush_secs=30
) as writer:
if args.model_path is not None:
# evaluate single checkpoint
eval_checkpoint(args.model_path, args, writer)
else:
# evaluate multiple checkpoints in order
prev_ckpt_ind = -1
while True:
current_ckpt = None
while current_ckpt is None:
current_ckpt = poll_checkpoint_folder(
args.tracking_model_dir, prev_ckpt_ind
)
time.sleep(2) # sleep for 2 seconds before polling again
logger.warning(
"=============current_ckpt: {}=============".format(
current_ckpt
)
)
prev_ckpt_ind += 1
eval_checkpoint(
current_ckpt, args, writer, cur_ckpt_idx=prev_ckpt_ind
)
if __name__ == "__main__":
main()
@@ -6,6 +6,5 @@
 from habitat_baselines.rl.ppo.policy import Policy
 from habitat_baselines.rl.ppo.ppo import PPO
-from habitat_baselines.rl.ppo.utils import RolloutStorage

-__all__ = ["PPO", "Policy", "RolloutStorage"]
+__all__ = ["PPO", "Policy"]

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 import torch.nn as nn

-from habitat_baselines.rl.ppo.utils import CategoricalNet, Flatten
+from habitat_baselines.common.utils import CategoricalNet, Flatten


 class Policy(nn.Module):
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
from collections import deque
from typing import Dict, List
import numpy as np
import torch
from habitat import Config, logger
from habitat.utils.visualizations.utils import observations_to_image
from habitat_baselines.common.base_trainer import BaseRLTrainer
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.env_utils import construct_envs
from habitat_baselines.common.environments import NavRLEnv
from habitat_baselines.common.rollout_storage import RolloutStorage
from habitat_baselines.common.tensorboard_utils import (
TensorboardWriter,
get_tensorboard_writer,
)
from habitat_baselines.common.utils import (
batch_obs,
generate_video,
poll_checkpoint_folder,
update_linear_schedule,
)
from habitat_baselines.rl.ppo import PPO, Policy
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):
r"""
Trainer class for PPO algorithm
Paper: https://arxiv.org/abs/1707.06347
"""
supported_tasks = ["Nav-v0"]
def __init__(self, config=None):
super().__init__(config)
self.actor_critic = None
self.agent = None
self.envs = None
self.device = None
self.video_option = []
if config is not None:
logger.info(f"config: {config}")
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
r"""
Sets up actor critic and agent for PPO
Args:
ppo_cfg: config node with relevant params
Returns:
None
"""
logger.add_filehandler(ppo_cfg.log_file)
self.actor_critic = Policy(
observation_space=self.envs.observation_spaces[0],
action_space=self.envs.action_spaces[0],
hidden_size=512,
goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
)
self.actor_critic.to(self.device)
self.agent = PPO(
actor_critic=self.actor_critic,
clip_param=ppo_cfg.clip_param,
ppo_epoch=ppo_cfg.ppo_epoch,
num_mini_batch=ppo_cfg.num_mini_batch,
value_loss_coef=ppo_cfg.value_loss_coef,
entropy_coef=ppo_cfg.entropy_coef,
lr=ppo_cfg.lr,
eps=ppo_cfg.eps,
max_grad_norm=ppo_cfg.max_grad_norm,
)
def save_checkpoint(self, file_name: str) -> None:
r"""
Save checkpoint with specified name
Args:
file_name: file name for checkpoint
Returns:
None
"""
checkpoint = {
"state_dict": self.agent.state_dict(),
"config": self.config,
}
torch.save(
checkpoint,
os.path.join(
self.config.TRAINER.RL.PPO.checkpoint_folder, file_name
),
)
def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict:
r"""
Load checkpoint of specified path as a dict
Args:
checkpoint_path: path of target checkpoint
*args: additional positional args
**kwargs: additional keyword args
Returns:
dict containing checkpoint info
"""
return torch.load(checkpoint_path, map_location=self.device)
def train(self) -> None:
r"""
Main method for training PPO
Returns:
None
"""
assert (
self.config is not None
), "trainer is not properly initialized, need to specify config file"
self.envs = construct_envs(self.config, NavRLEnv)
ppo_cfg = self.config.TRAINER.RL.PPO
self.device = torch.device("cuda", ppo_cfg.pth_gpu_id)
if not os.path.isdir(ppo_cfg.checkpoint_folder):
os.makedirs(ppo_cfg.checkpoint_folder)
self._setup_actor_critic_agent(ppo_cfg)
logger.info(
"agent number of parameters: {}".format(
sum(param.numel() for param in self.agent.parameters())
)
)
observations = self.envs.reset()
batch = batch_obs(observations)
rollouts = RolloutStorage(
ppo_cfg.num_steps,
self.envs.num_envs,
self.envs.observation_spaces[0],
self.envs.action_spaces[0],
ppo_cfg.hidden_size,
)
for sensor in rollouts.observations:
rollouts.observations[sensor][0].copy_(batch[sensor])
rollouts.to(self.device)
episode_rewards = torch.zeros(self.envs.num_envs, 1)
episode_counts = torch.zeros(self.envs.num_envs, 1)
current_episode_reward = torch.zeros(self.envs.num_envs, 1)
window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
t_start = time.time()
env_time = 0
pth_time = 0
count_steps = 0
count_checkpoints = 0
with (
get_tensorboard_writer(
log_dir=ppo_cfg.tensorboard_dir,
purge_step=count_steps,
flush_secs=30,
)
) as writer:
for update in range(ppo_cfg.num_updates):
if ppo_cfg.use_linear_lr_decay:
update_linear_schedule(
self.agent.optimizer,
update,
ppo_cfg.num_updates,
ppo_cfg.lr,
)
if ppo_cfg.use_linear_clip_decay:
self.agent.clip_param = ppo_cfg.clip_param * (
1 - update / ppo_cfg.num_updates
)
for step in range(ppo_cfg.num_steps):
t_sample_action = time.time()
# sample actions
with torch.no_grad():
step_observation = {
k: v[step]
for k, v in rollouts.observations.items()
}
(
values,
actions,
actions_log_probs,
recurrent_hidden_states,
) = self.actor_critic.act(
step_observation,
rollouts.recurrent_hidden_states[step],
rollouts.masks[step],
)
pth_time += time.time() - t_sample_action
t_step_env = time.time()
outputs = self.envs.step([a[0].item() for a in actions])
observations, rewards, dones, infos = [
list(x) for x in zip(*outputs)
]
env_time += time.time() - t_step_env
t_update_stats = time.time()
batch = batch_obs(observations)
rewards = torch.tensor(rewards, dtype=torch.float)
rewards = rewards.unsqueeze(1)
masks = torch.tensor(
[[0.0] if done else [1.0] for done in dones],
dtype=torch.float,
)
current_episode_reward += rewards
episode_rewards += (1 - masks) * current_episode_reward
episode_counts += 1 - masks
current_episode_reward *= masks
rollouts.insert(
batch,
recurrent_hidden_states,
actions,
actions_log_probs,
values,
rewards,
masks,
)
count_steps += self.envs.num_envs
pth_time += time.time() - t_update_stats
window_episode_reward.append(episode_rewards.clone())
window_episode_counts.append(episode_counts.clone())
t_update_model = time.time()
with torch.no_grad():
last_observation = {
k: v[-1] for k, v in rollouts.observations.items()
}
next_value = self.actor_critic.get_value(
last_observation,
rollouts.recurrent_hidden_states[-1],
rollouts.masks[-1],
).detach()
rollouts.compute_returns(
next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau
)
value_loss, action_loss, dist_entropy = self.agent.update(
rollouts
)
rollouts.after_update()
pth_time += time.time() - t_update_model
losses = [value_loss, action_loss]
stats = zip(
["count", "reward"],
[window_episode_counts, window_episode_reward],
)
deltas = {
k: (
(v[-1] - v[0]).sum().item()
if len(v) > 1
else v[0].sum().item()
)
for k, v in stats
}
deltas["count"] = max(deltas["count"], 1.0)
writer.add_scalar(
"reward", deltas["reward"] / deltas["count"], count_steps
)
writer.add_scalars(
"losses",
{k: l for l, k in zip(losses, ["value", "policy"])},
count_steps,
)
# log stats
if update > 0 and update % ppo_cfg.log_interval == 0:
logger.info(
"update: {}\tfps: {:.3f}\t".format(
update, count_steps / (time.time() - t_start)
)
)
logger.info(
"update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
"frames: {}".format(
update, env_time, pth_time, count_steps
)
)
window_rewards = (
window_episode_reward[-1] - window_episode_reward[0]
).sum()
window_counts = (
window_episode_counts[-1] - window_episode_counts[0]
).sum()
if window_counts > 0:
logger.info(
"Average window size {} reward: {:3f}".format(
len(window_episode_reward),
(window_rewards / window_counts).item(),
)
)
else:
logger.info("No episodes finish in current window")
# checkpoint model
if update % ppo_cfg.checkpoint_interval == 0:
self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
count_checkpoints += 1
def eval(self) -> None:
r"""
Main method of evaluating PPO
Returns:
None
"""
ppo_cfg = self.config.TRAINER.RL.PPO
self.device = torch.device("cuda", ppo_cfg.pth_gpu_id)
self.video_option = ppo_cfg.video_option.strip().split(",")
if "tensorboard" in self.video_option:
assert (
ppo_cfg.tensorboard_dir is not None
), "Must specify a tensorboard directory for video display"
if "disk" in self.video_option:
assert (
ppo_cfg.video_dir is not None
), "Must specify a directory for storing videos on disk"
with get_tensorboard_writer(
ppo_cfg.tensorboard_dir, purge_step=0, flush_secs=30
) as writer:
if os.path.isfile(ppo_cfg.eval_ckpt_path_or_dir):
# evaluate single checkpoint
self._eval_checkpoint(ppo_cfg.eval_ckpt_path_or_dir, writer)
else:
# evaluate multiple checkpoints in order
prev_ckpt_ind = -1
while True:
current_ckpt = None
while current_ckpt is None:
current_ckpt = poll_checkpoint_folder(
ppo_cfg.eval_ckpt_path_or_dir, prev_ckpt_ind
)
time.sleep(2) # sleep for 2 secs before polling again
logger.warning(
"=============current_ckpt: {}=============".format(
current_ckpt
)
)
prev_ckpt_ind += 1
self._eval_checkpoint(
checkpoint_path=current_ckpt,
writer=writer,
cur_ckpt_idx=prev_ckpt_ind,
)
def _eval_checkpoint(
self,
checkpoint_path: str,
writer: TensorboardWriter,
cur_ckpt_idx: int = 0,
) -> None:
r"""
Evaluates a single checkpoint
Args:
checkpoint_path: path of checkpoint
writer: tensorboard writer object for logging to tensorboard
cur_ckpt_idx: index of cur checkpoint for logging
Returns:
None
"""
ckpt_dict = self.load_checkpoint(
checkpoint_path, map_location=self.device
)
ckpt_config = ckpt_dict["config"]
config = self.config.clone()
ckpt_cmd_opts = ckpt_config.CMD_TRAILING_OPTS
eval_cmd_opts = config.CMD_TRAILING_OPTS
# config merge priority: eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
# first line for old checkpoint compatibility
config.merge_from_other_cfg(ckpt_config)
config.merge_from_other_cfg(self.config)
config.merge_from_list(ckpt_cmd_opts)
config.merge_from_list(eval_cmd_opts)
ppo_cfg = config.TRAINER.RL.PPO
config.TASK_CONFIG.defrost()
config.TASK_CONFIG.DATASET.SPLIT = "val"
agent_sensors = ppo_cfg.sensors.strip().split(",")
config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = agent_sensors
if self.video_option:
config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
config.freeze()
logger.info(f"env config: {config}")
self.envs = construct_envs(config, NavRLEnv)
self._setup_actor_critic_agent(ppo_cfg)
self.agent.load_state_dict(ckpt_dict["state_dict"])
self.actor_critic = self.agent.actor_critic
observations = self.envs.reset()
batch = batch_obs(observations)
for sensor in batch:
batch[sensor] = batch[sensor].to(self.device)
current_episode_reward = torch.zeros(
self.envs.num_envs, 1, device=self.device
)
test_recurrent_hidden_states = torch.zeros(
ppo_cfg.num_processes, ppo_cfg.hidden_size, device=self.device
)
not_done_masks = torch.zeros(
ppo_cfg.num_processes, 1, device=self.device
)
stats_episodes = dict() # dict of dicts that stores stats per episode
rgb_frames = [
[]
] * ppo_cfg.num_processes # type: List[List[np.ndarray]]
if self.video_option:
os.makedirs(ppo_cfg.video_dir, exist_ok=True)
while (
len(stats_episodes) < ppo_cfg.count_test_episodes
and self.envs.num_envs > 0
):
current_episodes = self.envs.current_episodes()
with torch.no_grad():
_, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
batch,
test_recurrent_hidden_states,
not_done_masks,
deterministic=False,
)
outputs = self.envs.step([a[0].item() for a in actions])
observations, rewards, dones, infos = [
list(x) for x in zip(*outputs)
]
batch = batch_obs(observations)
for sensor in batch:
batch[sensor] = batch[sensor].to(self.device)
not_done_masks = torch.tensor(
[[0.0] if done else [1.0] for done in dones],
dtype=torch.float,
device=self.device,
)
rewards = torch.tensor(
rewards, dtype=torch.float, device=self.device
).unsqueeze(1)
current_episode_reward += rewards
next_episodes = self.envs.current_episodes()
envs_to_pause = []
n_envs = self.envs.num_envs
for i in range(n_envs):
if (
next_episodes[i].scene_id,
next_episodes[i].episode_id,
) in stats_episodes:
envs_to_pause.append(i)
# episode ended
if not_done_masks[i].item() == 0:
episode_stats = dict()
episode_stats["spl"] = infos[i]["spl"]
episode_stats["success"] = int(infos[i]["spl"] > 0)
episode_stats["reward"] = current_episode_reward[i].item()
current_episode_reward[i] = 0
# use scene_id + episode_id as unique id for storing stats
stats_episodes[
(
current_episodes[i].scene_id,
current_episodes[i].episode_id,
)
] = episode_stats
if self.video_option:
generate_video(
ppo_cfg,
rgb_frames[i],
current_episodes[i].episode_id,
cur_ckpt_idx,
infos[i]["spl"],
writer,
)
rgb_frames[i] = []
# episode continues
elif self.video_option:
frame = observations_to_image(observations[i], infos[i])
rgb_frames[i].append(frame)
# pausing self.envs with no new episode
if len(envs_to_pause) > 0:
state_index = list(range(self.envs.num_envs))
for idx in reversed(envs_to_pause):
state_index.pop(idx)
self.envs.pause_at(idx)
# indexing along the batch dimensions
test_recurrent_hidden_states = test_recurrent_hidden_states[
state_index
]
not_done_masks = not_done_masks[state_index]
current_episode_reward = current_episode_reward[state_index]
for k, v in batch.items():
batch[k] = v[state_index]
if self.video_option:
rgb_frames = [rgb_frames[i] for i in state_index]
aggregated_stats = dict()
for stat_key in next(iter(stats_episodes.values())).keys():
aggregated_stats[stat_key] = sum(
[v[stat_key] for v in stats_episodes.values()]
)
num_episodes = len(stats_episodes)
episode_reward_mean = aggregated_stats["reward"] / num_episodes
episode_spl_mean = aggregated_stats["spl"] / num_episodes
episode_success_mean = aggregated_stats["success"] / num_episodes
logger.info(
"Average episode reward: {:.6f}".format(episode_reward_mean)
)
logger.info(
"Average episode success: {:.6f}".format(episode_success_mean)
)
logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))
writer.add_scalars(
"eval_reward",
{"average reward": episode_reward_mean},
cur_ckpt_idx,
)
writer.add_scalars(
"eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx
)
writer.add_scalars(
"eval_success",
{"average success": episode_success_mean},
cur_ckpt_idx,
)
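Finally, a hedged sketch of driving the new trainer programmatically and of the checkpoint layout produced by `save_checkpoint`. The config module path and the checkpoint filename are example assumptions.

```python
# Sketch: run evaluation through the trainer API and inspect a saved checkpoint.
import torch

from habitat_baselines.config.default import get_config  # assumed module path
from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer

config = get_config("habitat_baselines/config/pointnav/ppo.yaml")
trainer = PPOTrainer(config)
trainer.eval()  # evaluates eval_ckpt_path_or_dir; keeps polling if it is a directory

# save_checkpoint stores the agent state dict together with the config used:
ckpt = torch.load("data/checkpoints/ckpt.0.pth", map_location="cpu")
print(sorted(ckpt.keys()))  # ['config', 'state_dict']
```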