From 31318f81db05100099cfd308438d5930c3fb6cd2 Mon Sep 17 00:00:00 2001
From: JasonJiazhiZhang <21229070+JasonJiazhiZhang@users.noreply.github.com>
Date: Mon, 22 Jul 2019 16:00:12 -0700
Subject: [PATCH] Refactor baselines: add generic trainer and common utils
 (#153)

* refactor and add generic trainer class

* fix to pass tests

* change BaseModel to BaseTrainer

* fix tensorboard import causing CI failure

* modify circle-ci test script accordingly

* doc, typing and other changes

* rename BASELINE to TRAINER

* merge from upstream master

* Update Habitat-API to allow for no rendering sensors (#139)

Update Habitat-API to allow for no rendering sensors

* Added installation requirements step for sim installation in CI setup. (#159)

Added installation requirements step for sim installation in CI setup

* move RolloutStorage to utils

* add environments.py

* make ckpt dir if not exist

* small fixes according to comments

* changes according to comments

* update readme

* fix old config compatibility

* fix missed isort lint
---
 .circleci/config.yml                          |   2 +-
 configs/baselines/ppo.yaml                    |   2 +-
 habitat_baselines/README.md                   |  58 +-
 habitat_baselines/__init__.py                 |   5 +
 habitat_baselines/agents/ppo_agents.py        |   2 +-
 habitat_baselines/agents/slam_agents.py       |  18 +-
 habitat_baselines/common/base_trainer.py      |  52 ++
 habitat_baselines/common/baseline_registry.py |  64 ++
 habitat_baselines/common/env_utils.py         | 104 ++++
 habitat_baselines/common/environments.py      |  83 +++
 .../utils.py => common/rollout_storage.py}    | 259 +-------
 .../{ => common}/tensorboard_utils.py         |  52 +-
 habitat_baselines/common/utils.py             | 166 +++++
 habitat_baselines/config/default.py           |  99 ++-
 habitat_baselines/config/pointnav/ppo.yaml    |  39 ++
 .../config/pointnav/ppo_train_test.yaml       |  12 +
 habitat_baselines/evaluate_ppo.py             | 356 -----------
 habitat_baselines/rl/ppo/__init__.py          |   3 +-
 habitat_baselines/rl/ppo/policy.py            |   2 +-
 habitat_baselines/rl/ppo/ppo_trainer.py       | 575 ++++++++++++++++++
 habitat_baselines/run.py                      |  50 ++
 habitat_baselines/train_ppo.py                | 397 ------------
 22 files changed, 1302 insertions(+), 1098 deletions(-)
 create mode 100644 habitat_baselines/common/base_trainer.py
 create mode 100644 habitat_baselines/common/baseline_registry.py
 create mode 100644 habitat_baselines/common/env_utils.py
 create mode 100644 habitat_baselines/common/environments.py
 rename habitat_baselines/{rl/ppo/utils.py => common/rollout_storage.py} (50%)
 rename habitat_baselines/{ => common}/tensorboard_utils.py (80%)
 create mode 100644 habitat_baselines/common/utils.py
 create mode 100644 habitat_baselines/config/pointnav/ppo.yaml
 create mode 100644 habitat_baselines/config/pointnav/ppo_train_test.yaml
 delete mode 100644 habitat_baselines/evaluate_ppo.py
 create mode 100644 habitat_baselines/rl/ppo/ppo_trainer.py
 create mode 100644 habitat_baselines/run.py
 delete mode 100644 habitat_baselines/train_ppo.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 93d21ebd7..488f76040 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -196,7 +196,7 @@ jobs:
               . activate habitat; cd habitat-api
               python setup.py test
               python setup.py develop --all
-              python -u habitat_baselines/train_ppo.py --log-file "train.log" --checkpoint-folder "data/checkpoints" --sim-gpu-id 0 --pth-gpu-id 0 --num-processes 1 --num-mini-batch 1 --num-updates 10
+              python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo_train_test.yaml --run-type train
 
 
 workflows:
diff --git a/configs/baselines/ppo.yaml b/configs/baselines/ppo.yaml
index 2b6390fe3..839df239f 100644
--- a/configs/baselines/ppo.yaml
+++ b/configs/baselines/ppo.yaml
@@ -1,4 +1,4 @@
-BASELINE:
+TRAINER:
   RL:
     SUCCESS_REWARD: 10.0
     SLACK_REWARD: -0.01
diff --git a/habitat_baselines/README.md b/habitat_baselines/README.md
index 9a29cf4d9..014d1d605 100644
--- a/habitat_baselines/README.md
+++ b/habitat_baselines/README.md
@@ -25,46 +25,18 @@ For training on sample data please follow steps in the repository README. You sh
 
 **train**:
 ```bash
-python -u habitat_baselines/train_ppo.py \
-    --use-gae \
-    --sim-gpu-id 0 \
-    --pth-gpu-id 0 \
-    --lr 2.5e-4 \
-    --clip-param 0.1 \
-    --value-loss-coef 0.5 \
-    --num-processes 4 \
-    --num-steps 128 \
-    --num-mini-batch 4 \
-    --num-updates 100000 \
-    --use-linear-lr-decay \
-    --use-linear-clip-decay \
-    --entropy-coef 0.01 \
-    --log-file "train.log" \
-    --log-interval 5 \
-    --checkpoint-folder "data/checkpoints" \
-    --checkpoint-interval 50 \
-    --task-config "configs/tasks/pointnav.yaml" \
-
-
+python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo.yaml --run-type train
 ```
 
 **test**:
 ```bash
-python -u habitat_baselines/evaluate_ppo.py \
-    --model-path "/path/to/checkpoint" \
-    --sim-gpu-id 0 \
-    --pth-gpu-id 0 \
-    --num-processes 4 \
-    --count-test-episodes 100 \
-    --task-config "configs/tasks/pointnav.yaml" \
-
-
+python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo.yaml --run-type eval
 ```
 
 We also provide trained RGB, RGBD, Blind PPO models. 
 To use them download pre-trained pytorch models from [link](https://dl.fbaipublicfiles.com/habitat/data/baselines/v1/habitat_baselines_v1.zip) and unzip and specify model path [here](agents/ppo_agents.py#L132).
 
-Set argument `--task-config` to `tasks/pointnav_mp3d.yaml` for training on [MatterPort3D point goal navigation dataset](/README.md#task-datasets).
+To train on the [MatterPort3D point goal navigation dataset](/README.md#task-datasets), change the `task_config` field in `habitat_baselines/config/pointnav/ppo.yaml` to `tasks/pointnav_mp3d.yaml`.
 
 ### Classic
 
@@ -74,26 +46,16 @@ Set argument `--task-config` to `tasks/pointnav_mp3d.yaml` for training on [Matt
 "Benchmarking Classic and Learned Navigation in Complex 3D Environments".
 ### Additional Utilities
 
-**single-episode training**: 
-Algorithms can be trained with a single-episode option. This option can be used as a sanity check since good algorithms should overfit one episode relatively fast. To enable this option, add `DATASET.NUM_EPISODE_SAMPLE 1` *at the end* of the training command, or include the single-episode yaml file in `--task-config` like this:
-```
-   --task-config "configs/tasks/pointnav.yaml,configs/datasets/single_episode.yaml"
-```
+**Episode iterator options**:
+Coming very soon.
 
-**tensorboard and video generation support**
+**Tensorboard and video generation support**
 
-Enable tensorboard logging by adding `--tensorboard-dir logdir` when running `train_ppo.py` or `evaluate_ppo.py`
+Enable tensorboard logging by setting the `tensorboard_dir` field in `habitat_baselines/config/pointnav/ppo.yaml`.
 
-Enable video generation for `evaluate_ppo.py` using `--video-option`: specifying `tensorboard`(for displaying on tensorboard) or `disk` (for saving videos on disk), for example:
-```
-python -u habitat_baselines/evaluate_ppo.py   
-...
---count-test-episodes 2 \
---video-option tensorboard \
---tensorboard-dir tb_eval \
---model-path data/checkpoints/ckpt.xx.pth
-```
-The above command should generate navigation episode recordings and display them on tensorboard like this:
+Enable video generation for `eval` mode by setting `video_option` to `tensorboard,disk` (to display videos on tensorboard and to save them to disk, respectively).
+
+Generated navigation episode recordings should look like this on tensorboard:
 <p align="center">
   <img src="../res/img/tensorboard_video_demo.gif"  height="500">
 </p>
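The same tensorboard and video settings can also be overridden without editing `ppo.yaml`, through the `opts` list merged in by the reworked `get_config` in `habitat_baselines/config/default.py`. A minimal sketch, assuming `get_config` keeps a `(config_paths, opts)` signature and yacs-style key/value pairs (field names are the ones defined in `default.py`):

```python
from habitat_baselines.config.default import get_config

# Hypothetical programmatic equivalent of editing tensorboard_dir / video_option
# in ppo.yaml; the override keys mirror habitat_baselines/config/default.py.
config = get_config(
    "habitat_baselines/config/pointnav/ppo.yaml",
    [
        "TRAINER.RL.PPO.tensorboard_dir", "logdir",
        "TRAINER.RL.PPO.video_option", "tensorboard,disk",
    ],
)
```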
diff --git a/habitat_baselines/__init__.py b/habitat_baselines/__init__.py
index 240697e32..4fd2c3c3c 100644
--- a/habitat_baselines/__init__.py
+++ b/habitat_baselines/__init__.py
@@ -3,3 +3,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
+from habitat_baselines.common.base_trainer import BaseRLTrainer, BaseTrainer
+from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer, RolloutStorage
+
+__all__ = ["BaseTrainer", "BaseRLTrainer", "PPOTrainer", "RolloutStorage"]
diff --git a/habitat_baselines/agents/ppo_agents.py b/habitat_baselines/agents/ppo_agents.py
index 49e4dd6f8..919298ffb 100644
--- a/habitat_baselines/agents/ppo_agents.py
+++ b/habitat_baselines/agents/ppo_agents.py
@@ -16,8 +16,8 @@ import habitat
 from habitat.config import Config
 from habitat.config.default import get_config
 from habitat.core.agent import Agent
+from habitat_baselines.common.utils import batch_obs
 from habitat_baselines.rl.ppo import Policy
-from habitat_baselines.rl.ppo.utils import batch_obs
 
 
 def get_default_config():
diff --git a/habitat_baselines/agents/slam_agents.py b/habitat_baselines/agents/slam_agents.py
index 9d7b291a5..59e2b73d3 100644
--- a/habitat_baselines/agents/slam_agents.py
+++ b/habitat_baselines/agents/slam_agents.py
@@ -70,16 +70,16 @@ def make_good_config_for_orbslam2(config):
     config.SIMULATOR.RGB_SENSOR.HEIGHT = 256
     config.SIMULATOR.DEPTH_SENSOR.WIDTH = 256
     config.SIMULATOR.DEPTH_SENSOR.HEIGHT = 256
-    config.BASELINE.ORBSLAM2.CAMERA_HEIGHT = config.SIMULATOR.DEPTH_SENSOR.POSITION[
+    config.TRAINER.ORBSLAM2.CAMERA_HEIGHT = config.SIMULATOR.DEPTH_SENSOR.POSITION[
         1
     ]
-    config.BASELINE.ORBSLAM2.H_OBSTACLE_MIN = (
-        0.3 * config.BASELINE.ORBSLAM2.CAMERA_HEIGHT
+    config.TRAINER.ORBSLAM2.H_OBSTACLE_MIN = (
+        0.3 * config.TRAINER.ORBSLAM2.CAMERA_HEIGHT
     )
-    config.BASELINE.ORBSLAM2.H_OBSTACLE_MAX = (
-        1.0 * config.BASELINE.ORBSLAM2.CAMERA_HEIGHT
+    config.TRAINER.ORBSLAM2.H_OBSTACLE_MAX = (
+        1.0 * config.TRAINER.ORBSLAM2.CAMERA_HEIGHT
     )
-    config.BASELINE.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
+    config.TRAINER.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
         config.SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
     )
     return
@@ -607,11 +607,11 @@ def main():
     make_good_config_for_orbslam2(config)
 
     if args.agent_type == "blind":
-        agent = BlindAgent(config.BASELINE.ORBSLAM2)
+        agent = BlindAgent(config.TRAINER.ORBSLAM2)
     elif args.agent_type == "orbslam2-rgbd":
-        agent = ORBSLAM2Agent(config.BASELINE.ORBSLAM2)
+        agent = ORBSLAM2Agent(config.TRAINER.ORBSLAM2)
     elif args.agent_type == "orbslam2-rgb-monod":
-        agent = ORBSLAM2MonodepthAgent(config.BASELINE.ORBSLAM2)
+        agent = ORBSLAM2MonodepthAgent(config.TRAINER.ORBSLAM2)
     else:
         raise ValueError(args.agent_type, "is unknown type of agent")
     benchmark = habitat.Benchmark(args.task_config)
diff --git a/habitat_baselines/common/base_trainer.py b/habitat_baselines/common/base_trainer.py
new file mode 100644
index 000000000..cc8ffed86
--- /dev/null
+++ b/habitat_baselines/common/base_trainer.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import ClassVar, Dict, List
+
+
+class BaseTrainer:
+    """
+    Most generic trainer class: a base template for more specific trainer
+    classes such as RL, SLAM, or imitation-learning trainers. Includes only
+    the most basic functionality.
+    """
+
+    supported_tasks: ClassVar[List[str]]
+
+    def train(self) -> None:
+        raise NotImplementedError
+
+    def eval(self) -> None:
+        raise NotImplementedError
+
+    def save_checkpoint(self, file_name) -> None:
+        raise NotImplementedError
+
+    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
+        raise NotImplementedError
+
+
+class BaseRLTrainer(BaseTrainer):
+    """
+    Base trainer class for RL-based trainers. Future RL-specific
+    methods should be hosted here.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+    def train(self) -> None:
+        raise NotImplementedError
+
+    def eval(self) -> None:
+        raise NotImplementedError
+
+    def save_checkpoint(self, file_name) -> None:
+        raise NotImplementedError
+
+    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
+        raise NotImplementedError
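To show how the template is meant to be filled in, here is a minimal, hypothetical subclass; only the `BaseRLTrainer` interface comes from this file, the class name and body are illustrative:

```python
from typing import Dict

from habitat_baselines.common.base_trainer import BaseRLTrainer


class MyRLTrainer(BaseRLTrainer):  # hypothetical example, not part of this patch
    def train(self) -> None:
        # a concrete trainer would construct envs, rollouts, and an optimizer here
        print("training with config:", self.config)

    def eval(self) -> None:
        print("evaluating")

    def save_checkpoint(self, file_name) -> None:
        pass  # e.g. torch.save({...}, file_name)

    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
        return {}  # e.g. torch.load(checkpoint_path, *args, **kwargs)
```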
diff --git a/habitat_baselines/common/baseline_registry.py b/habitat_baselines/common/baseline_registry.py
new file mode 100644
index 000000000..fe54971f9
--- /dev/null
+++ b/habitat_baselines/common/baseline_registry.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+r"""BaselineRegistry is extended from habitat.Registry to provide
+registration for trainer and environments, while keeping Registry
+in habitat core intact.
+
+Import the baseline registry object using
+
+``from habitat_baselines.common.baseline_registry import baseline_registry``
+
+Various decorators for registry different kind of classes with unique keys
+
+- Register a environment: ``@registry.register_env``
+- Register a trainer: ``@registry.register_trainer``
+"""
+
+from typing import Optional
+
+from habitat.core.registry import Registry
+
+
+class BaselineRegistry(Registry):
+    @classmethod
+    def register_trainer(cls, to_register=None, *, name: Optional[str] = None):
+        r"""Register a RL training algorithm to registry with key 'name'.
+
+        Args:
+            name: Key with which the trainer will be registered.
+                If None will use the name of the class.
+
+        """
+        from habitat_baselines.common.base_trainer import BaseTrainer
+
+        return cls._register_impl(
+            "trainer", to_register, name, assert_type=BaseTrainer
+        )
+
+    @classmethod
+    def get_trainer(cls, name):
+        return cls._get_impl("trainer", name)
+
+    @classmethod
+    def register_env(cls, to_register=None, *, name: Optional[str] = None):
+        r"""Register a environment to registry with key 'name'
+            currently only support subclass of RLEnv.
+        Args:
+            name: Key with which the env will be registered.
+                If None will use the name of the class.
+
+        """
+        from habitat import RLEnv
+
+        return cls._register_impl("env", to_register, name, assert_type=RLEnv)
+
+    @classmethod
+    def get_env(cls, name):
+        return cls._get_impl("env", name)
+
+
+baseline_registry = BaselineRegistry()
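Usage of the new registry, as a hedged sketch (the trainer and env names below are made up; `register_trainer`, `register_env`, `get_trainer`, and `get_env` are the calls defined above, and the decorator form mirrors the `@baseline_registry.register_env(name=...)` usage in `environments.py`):

```python
import habitat
from habitat_baselines.common.base_trainer import BaseTrainer
from habitat_baselines.common.baseline_registry import baseline_registry


@baseline_registry.register_trainer(name="my_trainer")  # hypothetical name
class MyTrainer(BaseTrainer):
    def train(self) -> None:
        pass


@baseline_registry.register_env(name="MyNavEnv")  # hypothetical name
class MyNavEnv(habitat.RLEnv):
    pass  # a real env would implement get_reward, get_done, etc.


# lookup by key, as done by get_trainer() in common/utils.py and construct_envs()
trainer_cls = baseline_registry.get_trainer("my_trainer")
env_cls = baseline_registry.get_env("MyNavEnv")
```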
diff --git a/habitat_baselines/common/env_utils.py b/habitat_baselines/common/env_utils.py
new file mode 100644
index 000000000..89d3d4c2a
--- /dev/null
+++ b/habitat_baselines/common/env_utils.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+from typing import Type
+
+import habitat
+from habitat import Config, Env, VectorEnv, make_dataset
+
+
+def make_env_fn(
+    task_config: Config, rl_env_config: Config, env_class: Type, rank: int
+) -> Env:
+    r"""
+    Creates an env of type env_class with specified config and rank.
+    This is to be passed in as an argument when creating VectorEnv.
+    Args:
+        task_config: task config file for creating env.
+        rl_env_config: RL env config for creating env.
+        env_class: class type of the env to be created.
+        rank: rank of env to be created (for seeding).
+
+    Returns:
+        env object created according to specification.
+    """
+    dataset = make_dataset(
+        task_config.DATASET.TYPE, config=task_config.DATASET
+    )
+    env = env_class(
+        config_env=task_config, config_baseline=rl_env_config, dataset=dataset
+    )
+    env.seed(rank)
+    return env
+
+
+def construct_envs(config: Config, env_class: Type) -> VectorEnv:
+    r"""
+    Create VectorEnv object with specified config and env class type.
+    To allow better performance, the dataset is split into smaller subsets,
+    one for each individual env, grouped by scenes.
+    Args:
+        config: configs that contain num_processes as well as information
+            necessary to create individual environments.
+        env_class: class type of the envs to be created.
+
+    Returns:
+        VectorEnv object created according to specification.
+    """
+    trainer_config = config.TRAINER.RL.PPO
+    rl_env_config = config.TRAINER.RL
+    task_config = config.TASK_CONFIG  # excluding trainer-specific configs
+    env_configs, rl_env_configs = [], []
+    env_classes = [env_class for _ in range(trainer_config.num_processes)]
+    dataset = make_dataset(task_config.DATASET.TYPE)
+    scenes = dataset.get_scenes_to_load(task_config.DATASET)
+
+    if len(scenes) > 0:
+        random.shuffle(scenes)
+
+        assert len(scenes) >= trainer_config.num_processes, (
+            "reduce the number of processes as there "
+            "aren't enough number of scenes"
+        )
+
+    scene_splits = [[] for _ in range(trainer_config.num_processes)]
+    for idx, scene in enumerate(scenes):
+        scene_splits[idx % len(scene_splits)].append(scene)
+
+    assert sum(map(len, scene_splits)) == len(scenes)
+
+    for i in range(trainer_config.num_processes):
+
+        env_config = task_config.clone()
+        env_config.defrost()
+        if len(scenes) > 0:
+            env_config.DATASET.CONTENT_SCENES = scene_splits[i]
+
+        env_config.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = (
+            trainer_config.sim_gpu_id
+        )
+
+        agent_sensors = trainer_config.sensors.strip().split(",")
+        env_config.SIMULATOR.AGENT_0.SENSORS = agent_sensors
+        env_config.freeze()
+        env_configs.append(env_config)
+        rl_env_configs.append(rl_env_config)
+
+    envs = habitat.VectorEnv(
+        make_env_fn=make_env_fn,
+        env_fn_args=tuple(
+            tuple(
+                zip(
+                    env_configs,
+                    rl_env_configs,
+                    env_classes,
+                    range(trainer_config.num_processes),
+                )
+            )
+        ),
+    )
+    return envs
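A short sketch of how `construct_envs` is meant to be driven end to end. It assumes the pointnav dataset referenced in the repo README has been downloaded, and that `environments.py` has been imported so its registration decorator has run:

```python
# importing environments runs its @baseline_registry.register_env decorator
import habitat_baselines.common.environments  # noqa: F401
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.env_utils import construct_envs
from habitat_baselines.config.default import get_config

config = get_config("habitat_baselines/config/pointnav/ppo.yaml")
env_class = baseline_registry.get_env("NavRLEnv")

envs = construct_envs(config, env_class)  # VectorEnv with num_processes workers
observations = envs.reset()               # one observation dict per env
envs.close()
```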
diff --git a/habitat_baselines/common/environments.py b/habitat_baselines/common/environments.py
new file mode 100644
index 000000000..e542ae695
--- /dev/null
+++ b/habitat_baselines/common/environments.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+r"""
+This file hosts task-specific or trainer-specific environments for trainers.
+All environments here should be a (direct or indirect) subclass of the Env
+class in habitat. Customized environments should be registered using
+``@baseline_registry.register_env(name="myEnv")`` for reusability.
+"""
+
+import habitat
+from habitat import SimulatorActions
+from habitat_baselines.common.baseline_registry import baseline_registry
+
+
+@baseline_registry.register_env(name="NavRLEnv")
+class NavRLEnv(habitat.RLEnv):
+    def __init__(self, config_env, config_baseline, dataset):
+        self._config_env = config_env.TASK
+        self._config_baseline = config_baseline
+        self._previous_target_distance = None
+        self._previous_action = None
+        self._episode_distance_covered = None
+        super().__init__(config_env, dataset)
+
+    def reset(self):
+        self._previous_action = None
+
+        observations = super().reset()
+
+        self._previous_target_distance = self.habitat_env.current_episode.info[
+            "geodesic_distance"
+        ]
+        return observations
+
+    def step(self, action):
+        self._previous_action = action
+        return super().step(action)
+
+    def get_reward_range(self):
+        return (
+            self._config_baseline.SLACK_REWARD - 1.0,
+            self._config_baseline.SUCCESS_REWARD + 1.0,
+        )
+
+    def get_reward(self, observations):
+        reward = self._config_baseline.SLACK_REWARD
+
+        current_target_distance = self._distance_target()
+        reward += self._previous_target_distance - current_target_distance
+        self._previous_target_distance = current_target_distance
+
+        if self._episode_success():
+            reward += self._config_baseline.SUCCESS_REWARD
+
+        return reward
+
+    def _distance_target(self):
+        current_position = self._env.sim.get_agent_state().position.tolist()
+        target_position = self._env.current_episode.goals[0].position
+        distance = self._env.sim.geodesic_distance(
+            current_position, target_position
+        )
+        return distance
+
+    def _episode_success(self):
+        if (
+            self._previous_action == SimulatorActions.STOP
+            and self._distance_target() < self._config_env.SUCCESS_DISTANCE
+        ):
+            return True
+        return False
+
+    def get_done(self, observations):
+        done = False
+        if self._env.episode_over or self._episode_success():
+            done = True
+        return done
+
+    def get_info(self, observations):
+        return self.habitat_env.get_metrics()
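`NavRLEnv`'s reward is the slack penalty plus the decrease in geodesic distance to the goal, plus a success bonus. As an illustrative sketch of reusing it (the subclass name and reward tweak are made up), a variant can be registered under its own key and override only `get_reward`:

```python
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.environments import NavRLEnv


@baseline_registry.register_env(name="NavRLEnvNoSlack")  # hypothetical variant
class NavRLEnvNoSlack(NavRLEnv):
    def get_reward(self, observations):
        # keep the distance-delta shaping and success bonus, drop the slack penalty
        current_target_distance = self._distance_target()
        reward = self._previous_target_distance - current_target_distance
        self._previous_target_distance = current_target_distance
        if self._episode_success():
            reward += self._config_baseline.SUCCESS_REWARD
        return reward
```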
diff --git a/habitat_baselines/rl/ppo/utils.py b/habitat_baselines/common/rollout_storage.py
similarity index 50%
rename from habitat_baselines/rl/ppo/utils.py
rename to habitat_baselines/common/rollout_storage.py
index 8cd0e8709..908e50b63 100644
--- a/habitat_baselines/rl/ppo/utils.py
+++ b/habitat_baselines/common/rollout_storage.py
@@ -4,63 +4,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import argparse
 from collections import defaultdict
 
-import numpy as np
 import torch
-import torch.nn as nn
 
 
-class Flatten(nn.Module):
-    def forward(self, x):
-        return x.view(x.size(0), -1)
-
-
-class CustomFixedCategorical(torch.distributions.Categorical):
-    def sample(self, sample_shape=torch.Size()):
-        return super().sample(sample_shape).unsqueeze(-1)
-
-    def log_probs(self, actions):
-        return (
-            super()
-            .log_prob(actions.squeeze(-1))
-            .view(actions.size(0), -1)
-            .sum(-1)
-            .unsqueeze(-1)
-        )
-
-    def mode(self):
-        return self.probs.argmax(dim=-1, keepdim=True)
-
-
-class CategoricalNet(nn.Module):
-    def __init__(self, num_inputs, num_outputs):
-        super().__init__()
-
-        self.linear = nn.Linear(num_inputs, num_outputs)
-
-        nn.init.orthogonal_(self.linear.weight, gain=0.01)
-        nn.init.constant_(self.linear.bias, 0)
-
-    def forward(self, x):
-        x = self.linear(x)
-        return CustomFixedCategorical(logits=x)
-
-
-def _flatten_helper(t, n, tensor):
-    return tensor.view(t * n, *tensor.size()[2:])
-
-
-def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
-    r"""Decreases the learning rate linearly
+class RolloutStorage:
+    r"""
+    Class for storing rollout information for RL trainers
     """
-    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr
-
 
-class RolloutStorage:
     def __init__(
         self,
         num_steps,
@@ -75,7 +28,7 @@ class RolloutStorage:
             self.observations[sensor] = torch.zeros(
                 num_steps + 1,
                 num_envs,
-                *observation_space.spaces[sensor].shape
+                *observation_space.spaces[sensor].shape,
             )
 
         self.recurrent_hidden_states = torch.zeros(
@@ -168,9 +121,9 @@ class RolloutStorage:
     def recurrent_generator(self, advantages, num_mini_batch):
         num_processes = self.rewards.size(1)
         assert num_processes >= num_mini_batch, (
-            "PPO requires the number of processes ({}) "
+            "Trainer requires the number of processes ({}) "
             "to be greater than or equal to the number of "
-            "PPO mini batches ({}).".format(num_processes, num_mini_batch)
+            "trainer mini batches ({}).".format(num_processes, num_mini_batch)
         )
         num_envs_per_batch = num_processes // num_mini_batch
         perm = torch.randperm(num_processes)
@@ -231,18 +184,18 @@ class RolloutStorage:
 
             # Flatten the (T, N, ...) tensors to (T * N, ...)
             for sensor in observations_batch:
-                observations_batch[sensor] = _flatten_helper(
+                observations_batch[sensor] = self._flatten_helper(
                     T, N, observations_batch[sensor]
                 )
 
-            actions_batch = _flatten_helper(T, N, actions_batch)
-            value_preds_batch = _flatten_helper(T, N, value_preds_batch)
-            return_batch = _flatten_helper(T, N, return_batch)
-            masks_batch = _flatten_helper(T, N, masks_batch)
-            old_action_log_probs_batch = _flatten_helper(
+            actions_batch = self._flatten_helper(T, N, actions_batch)
+            value_preds_batch = self._flatten_helper(T, N, value_preds_batch)
+            return_batch = self._flatten_helper(T, N, return_batch)
+            masks_batch = self._flatten_helper(T, N, masks_batch)
+            old_action_log_probs_batch = self._flatten_helper(
                 T, N, old_action_log_probs_batch
             )
-            adv_targ = _flatten_helper(T, N, adv_targ)
+            adv_targ = self._flatten_helper(T, N, adv_targ)
 
             yield (
                 observations_batch,
@@ -255,176 +208,16 @@ class RolloutStorage:
                 adv_targ,
             )
 
-
-def batch_obs(observations):
-    batch = defaultdict(list)
-
-    for obs in observations:
-        for sensor in obs:
-            batch[sensor].append(obs[sensor])
-
-    for sensor in batch:
-        batch[sensor] = torch.tensor(
-            np.array(batch[sensor]), dtype=torch.float
-        )
-    return batch
-
-
-def ppo_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--clip-param",
-        type=float,
-        default=0.2,
-        help="ppo clip parameter (default: 0.2)",
-    )
-    parser.add_argument(
-        "--ppo-epoch",
-        type=int,
-        default=4,
-        help="number of ppo epochs (default: 4)",
-    )
-    parser.add_argument(
-        "--num-mini-batch",
-        type=int,
-        default=32,
-        help="number of batches for ppo (default: 32)",
-    )
-    parser.add_argument(
-        "--value-loss-coef",
-        type=float,
-        default=0.5,
-        help="value loss coefficient (default: 0.5)",
-    )
-    parser.add_argument(
-        "--entropy-coef",
-        type=float,
-        default=0.01,
-        help="entropy term coefficient (default: 0.01)",
-    )
-    parser.add_argument(
-        "--lr", type=float, default=7e-4, help="learning rate (default: 7e-4)"
-    )
-    parser.add_argument(
-        "--eps",
-        type=float,
-        default=1e-5,
-        help="RMSprop optimizer epsilon (default: 1e-5)",
-    )
-    parser.add_argument(
-        "--max-grad-norm",
-        type=float,
-        default=0.5,
-        help="max norm of gradients (default: 0.5)",
-    )
-    parser.add_argument(
-        "--num-steps",
-        type=int,
-        default=5,
-        help="number of forward steps in A2C (default: 5)",
-    )
-    parser.add_argument("--hidden-size", type=int, default=512)
-    parser.add_argument(
-        "--num-processes",
-        type=int,
-        default=16,
-        help="number of training processes " "to use (default: 16)",
-    )
-    parser.add_argument(
-        "--use-gae",
-        action="store_true",
-        default=False,
-        help="use generalized advantage estimation",
-    )
-    parser.add_argument(
-        "--use-linear-lr-decay",
-        action="store_true",
-        default=False,
-        help="use a linear schedule on the learning rate",
-    )
-    parser.add_argument(
-        "--use-linear-clip-decay",
-        action="store_true",
-        default=False,
-        help="use a linear schedule on the " "ppo clipping parameter",
-    )
-    parser.add_argument(
-        "--gamma",
-        type=float,
-        default=0.99,
-        help="discount factor for rewards (default: 0.99)",
-    )
-    parser.add_argument(
-        "--tau", type=float, default=0.95, help="gae parameter (default: 0.95)"
-    )
-    parser.add_argument(
-        "--log-file", type=str, required=True, help="path for log file"
-    )
-    parser.add_argument(
-        "--reward-window-size",
-        type=int,
-        default=50,
-        help="logging window for rewards",
-    )
-    parser.add_argument(
-        "--log-interval",
-        type=int,
-        default=1,
-        help="number of updates after which metrics are logged",
-    )
-    parser.add_argument(
-        "--checkpoint-interval",
-        type=int,
-        default=50,
-        help="number of updates after which models are checkpointed",
-    )
-    parser.add_argument(
-        "--checkpoint-folder",
-        type=str,
-        required=True,
-        help="folder for storing checkpoints",
-    )
-    parser.add_argument(
-        "--sim-gpu-id",
-        type=int,
-        required=True,
-        help="gpu id on which scenes are loaded",
-    )
-    parser.add_argument(
-        "--pth-gpu-id",
-        type=int,
-        required=True,
-        help="gpu id on which pytorch runs",
-    )
-    parser.add_argument(
-        "--num-updates",
-        type=int,
-        default=10000,
-        help="number of PPO updates to run",
-    )
-    parser.add_argument(
-        "--sensors",
-        type=str,
-        default="RGB_SENSOR,DEPTH_SENSOR",
-        help="comma separated string containing different sensors to use,"
-        "currently 'RGB_SENSOR' and 'DEPTH_SENSOR' are supported",
-    )
-    parser.add_argument(
-        "--task-config",
-        type=str,
-        default="configs/tasks/pointnav.yaml",
-        help="path to config yaml containing information about task",
-    )
-    parser.add_argument("--seed", type=int, default=100)
-    parser.add_argument(
-        "opts",
-        default=None,
-        nargs=argparse.REMAINDER,
-        help="Modify config options from command line",
-    )
-    parser.add_argument(
-        "--tensorboard-dir",
-        type=str,
-        help="path to tensorboard logging directory",
-    )
-    return parser
+    @staticmethod
+    def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Given a tensor of size (t, n, ...), flatten it to size (t*n, ...).
+        Args:
+            t: first dimension of tensor.
+            n: second dimension of tensor.
+            tensor: target tensor to be flattened.
+
+        Returns:
+            flattened tensor of size (t*n, ...)
+        """
+        return tensor.view(t * n, *tensor.size()[2:])
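The flattening convention used by the recurrent generator can be sanity-checked in isolation (pure torch, no simulator needed):

```python
import torch

from habitat_baselines.common.rollout_storage import RolloutStorage

t, n = 4, 3                      # e.g. 4 rollout steps, 3 parallel envs
x = torch.randn(t, n, 8)         # (T, N, feature_dim)
flat = RolloutStorage._flatten_helper(t, n, x)
assert flat.shape == (t * n, 8)  # flattened to (T * N, feature_dim)
```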
diff --git a/habitat_baselines/tensorboard_utils.py b/habitat_baselines/common/tensorboard_utils.py
similarity index 80%
rename from habitat_baselines/tensorboard_utils.py
rename to habitat_baselines/common/tensorboard_utils.py
index 7449dd64b..ec5799de4 100644
--- a/habitat_baselines/tensorboard_utils.py
+++ b/habitat_baselines/common/tensorboard_utils.py
@@ -1,8 +1,37 @@
-from typing import Optional, Union
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
 
 import numpy as np
 import torch
-from torch.utils.tensorboard import SummaryWriter
+
+
+# TODO Add checks to replace DummyWriter
+class DummyWriter:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def close(self):
+        pass
+
+    def __getattr__(self, item):
+        return lambda *args, **kwargs: None
+
+
+try:
+    from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+    SummaryWriter = DummyWriter
 
 
 class TensorboardWriter(SummaryWriter):
@@ -33,23 +62,6 @@ class TensorboardWriter(SummaryWriter):
         self.add_video(video_name, video_tensor, fps=fps, global_step=step_idx)
 
 
-class DummyWriter:
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    def close(self):
-        pass
-
-    def __getattr__(self, item):
-        return lambda *args, **kwargs: None
-
-
 def get_tensorboard_writer(
     log_dir: str, *args, **kwargs
 ) -> Union[DummyWriter, TensorboardWriter]:
@@ -62,7 +74,7 @@ def get_tensorboard_writer(
         **kwargs: additional keyword args.
 
     Returns:
-        Either the created tensorboard writer or a dummy writer.
+        either the created tensorboard writer or a dummy writer.
     """
     if log_dir:
         return TensorboardWriter(log_dir, *args, **kwargs)
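With the import fallback above, callers can treat the writer as a context manager regardless of whether tensorboard is installed; per the return annotation, an empty `log_dir` yields a `DummyWriter` whose methods are all no-ops. A minimal sketch under that assumption:

```python
from habitat_baselines.common.tensorboard_utils import get_tensorboard_writer

# empty log_dir -> dummy writer (calls become no-ops); a real path -> TensorboardWriter
with get_tensorboard_writer("", purge_step=0, flush_secs=30) as writer:
    writer.add_scalar("train/reward", 0.0, 0)
```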
diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py
new file mode 100644
index 000000000..54fb6a38f
--- /dev/null
+++ b/habitat_baselines/common/utils.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import glob
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from habitat import Config
+from habitat.utils.visualizations.utils import images_to_video
+from habitat_baselines import BaseTrainer
+from habitat_baselines.common.baseline_registry import baseline_registry
+from habitat_baselines.common.tensorboard_utils import (
+    DummyWriter,
+    TensorboardWriter,
+)
+
+# TODO distribute utilities in this file to separate files
+
+
+def get_trainer(trainer_name: str, trainer_cfg: Config) -> BaseTrainer:
+    r"""
+    Create specific trainer instance according to name.
+    Args:
+        trainer_name: name of the registered trainer.
+        trainer_cfg: config for the trainer.
+
+    Returns:
+        an instance of the specified trainer.
+    """
+    trainer = baseline_registry.get_trainer(trainer_name)
+    assert trainer is not None, f"{trainer_name} is not supported"
+    return trainer(trainer_cfg)
+
+
+class Flatten(nn.Module):
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+class CustomFixedCategorical(torch.distributions.Categorical):
+    def sample(self, sample_shape=torch.Size()):
+        return super().sample(sample_shape).unsqueeze(-1)
+
+    def log_probs(self, actions):
+        return (
+            super()
+            .log_prob(actions.squeeze(-1))
+            .view(actions.size(0), -1)
+            .sum(-1)
+            .unsqueeze(-1)
+        )
+
+    def mode(self):
+        return self.probs.argmax(dim=-1, keepdim=True)
+
+
+class CategoricalNet(nn.Module):
+    def __init__(self, num_inputs, num_outputs):
+        super().__init__()
+
+        self.linear = nn.Linear(num_inputs, num_outputs)
+
+        nn.init.orthogonal_(self.linear.weight, gain=0.01)
+        nn.init.constant_(self.linear.bias, 0)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return CustomFixedCategorical(logits=x)
+
+
+# TODO make this an LRScheduler class
+def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
+    r"""Decreases the learning rate linearly
+    """
+    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def batch_obs(observations: List[Dict]) -> Dict:
+    r"""
+    Transpose a batch of observation dicts to a dict of batched
+    observations.
+    Args:
+        observations: list of dicts of observations.
+
+    Returns:
+        transposed dict of batched observations (as torch tensors).
+    """
+    batch = defaultdict(list)
+
+    for obs in observations:
+        for sensor in obs:
+            batch[sensor].append(obs[sensor])
+
+    for sensor in batch:
+        batch[sensor] = torch.tensor(
+            np.array(batch[sensor]), dtype=torch.float
+        )
+    return batch
+
+
+def poll_checkpoint_folder(
+    checkpoint_folder: str, previous_ckpt_ind: int
+) -> Optional[str]:
+    r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder
+    (sorted by time of last modification).
+
+    Args:
+        checkpoint_folder: directory to look for checkpoints.
+        previous_ckpt_ind: index of checkpoint last returned.
+
+    Returns:
+        return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found
+        else return None.
+    """
+    assert os.path.isdir(checkpoint_folder), "invalid checkpoint folder path"
+    models_paths = list(
+        filter(os.path.isfile, glob.glob(checkpoint_folder + "/*"))
+    )
+    models_paths.sort(key=os.path.getmtime)
+    ind = previous_ckpt_ind + 1
+    if ind < len(models_paths):
+        return models_paths[ind]
+    return None
+
+
+def generate_video(
+    config: Config,
+    images: List[np.ndarray],
+    episode_id: int,
+    checkpoint_idx: int,
+    spl: float,
+    tb_writer: Union[DummyWriter, TensorboardWriter],
+    fps: int = 10,
+) -> None:
+    r"""
+    Generate video according to specified information.
+    Args:
+        config: config object that contains video_option and video_dir.
+        images: list of images to be converted to video.
+        episode_id: episode id for video naming.
+        checkpoint_idx: checkpoint index for video naming.
+        spl: SPL for this episode for video naming.
+        tb_writer: tensorboard writer object for uploading video.
+        fps: fps for the generated video.
+    Returns:
+        None
+    """
+    if config.video_option and len(images) > 0:
+        video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}"
+        if "disk" in config.video_option:
+            images_to_video(images, config.video_dir, video_name)
+        if "tensorboard" in config.video_option:
+            tb_writer.add_video_from_np_images(
+                f"episode{episode_id}", checkpoint_idx, images, fps=fps
+            )
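Two of the relocated helpers can be exercised on their own: `batch_obs` transposes per-env observation dicts into stacked float tensors, and `poll_checkpoint_folder` reproduces the checkpoint-polling loop that used to live in `evaluate_ppo.py`. A small sketch (the `data/checkpoints` directory must already exist):

```python
import time

from habitat_baselines.common.utils import batch_obs, poll_checkpoint_folder

# batch_obs: list of per-env observation dicts -> dict of stacked float tensors
batch = batch_obs([{"pointgoal": [1.0, 0.5]}, {"pointgoal": [0.2, 0.3]}])
assert batch["pointgoal"].shape == (2, 2)

# poll_checkpoint_folder: block until the (previous_ckpt_ind + 1)-th checkpoint appears
prev_ckpt_ind = -1
current_ckpt = None
while current_ckpt is None:
    current_ckpt = poll_checkpoint_folder("data/checkpoints", prev_ckpt_ind)
    time.sleep(2)  # sleep before polling again, as in the old evaluate_ppo.py
```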
diff --git a/habitat_baselines/config/default.py b/habitat_baselines/config/default.py
index c7506e9e1..1de2f7c15 100644
--- a/habitat_baselines/config/default.py
+++ b/habitat_baselines/config/default.py
@@ -8,7 +8,7 @@ from typing import List, Optional, Union
 
 import numpy as np
 
-from habitat import get_config
+from habitat import get_config as get_task_config
 from habitat.config import Config as CN
 
 DEFAULT_CONFIG_DIR = "configs/"
@@ -17,49 +17,88 @@ CONFIG_FILE_SEPARATOR = ","
 # Config definition
 # -----------------------------------------------------------------------------
 _C = CN()
-_C.SEED = 100
+_C.BASE_TASK_CONFIG_PATH = "configs/tasks/pointnav.yaml"
+_C.TASK_CONFIG = CN()  # task_config will be stored as a config node
+_C.CMD_TRAILING_OPTS = ""  # store command-line options
 # -----------------------------------------------------------------------------
-# BASELINE
+# TRAINER ALGORITHMS
 # -----------------------------------------------------------------------------
-_C.BASELINE = CN()
+_C.TRAINER = CN()
+_C.TRAINER.TRAINER_NAME = "ppo"
 # -----------------------------------------------------------------------------
 # REINFORCEMENT LEARNING (RL)
 # -----------------------------------------------------------------------------
-_C.BASELINE.RL = CN()
-_C.BASELINE.RL.SUCCESS_REWARD = 10.0
-_C.BASELINE.RL.SLACK_REWARD = -0.01
+_C.TRAINER.RL = CN()
+_C.TRAINER.RL.SUCCESS_REWARD = 10.0
+_C.TRAINER.RL.SLACK_REWARD = -0.01
+# -----------------------------------------------------------------------------
+# PROXIMAL POLICY OPTIMIZATION (PPO)
+# -----------------------------------------------------------------------------
+# TODO move general options out of PPO
+_C.TRAINER.RL.PPO = CN()
+_C.TRAINER.RL.PPO.clip_param = 0.2
+_C.TRAINER.RL.PPO.ppo_epoch = 4
+_C.TRAINER.RL.PPO.num_mini_batch = 16
+_C.TRAINER.RL.PPO.value_loss_coef = 0.5
+_C.TRAINER.RL.PPO.entropy_coef = 0.01
+_C.TRAINER.RL.PPO.lr = 7e-4
+_C.TRAINER.RL.PPO.eps = 1e-5
+_C.TRAINER.RL.PPO.max_grad_norm = 0.5
+_C.TRAINER.RL.PPO.num_steps = 5
+_C.TRAINER.RL.PPO.hidden_size = 512
+_C.TRAINER.RL.PPO.num_processes = 16
+_C.TRAINER.RL.PPO.use_gae = True
+_C.TRAINER.RL.PPO.use_linear_lr_decay = False
+_C.TRAINER.RL.PPO.use_linear_clip_decay = False
+_C.TRAINER.RL.PPO.gamma = 0.99
+_C.TRAINER.RL.PPO.tau = 0.95
+_C.TRAINER.RL.PPO.log_file = "train.log"
+_C.TRAINER.RL.PPO.reward_window_size = 50
+_C.TRAINER.RL.PPO.log_interval = 50
+_C.TRAINER.RL.PPO.checkpoint_interval = 50
+_C.TRAINER.RL.PPO.checkpoint_folder = "data/checkpoints"
+_C.TRAINER.RL.PPO.sim_gpu_id = 0
+_C.TRAINER.RL.PPO.pth_gpu_id = 0
+_C.TRAINER.RL.PPO.num_updates = 10000
+_C.TRAINER.RL.PPO.sensors = "RGB_SENSOR,DEPTH_SENSOR"
+_C.TRAINER.RL.PPO.task_config = "configs/tasks/pointnav.yaml"
+_C.TRAINER.RL.PPO.tensorboard_dir = "tb"
+_C.TRAINER.RL.PPO.count_test_episodes = 2
+_C.TRAINER.RL.PPO.video_option = "disk,tensorboard"
+_C.TRAINER.RL.PPO.video_dir = "video_dir"
+_C.TRAINER.RL.PPO.eval_ckpt_path_or_dir = "data/checkpoints"
 # -----------------------------------------------------------------------------
 # ORBSLAM2 BASELINE
 # -----------------------------------------------------------------------------
-_C.BASELINE.ORBSLAM2 = CN()
-_C.BASELINE.ORBSLAM2.SLAM_VOCAB_PATH = (
+_C.TRAINER.ORBSLAM2 = CN()
+_C.TRAINER.ORBSLAM2.SLAM_VOCAB_PATH = (
     "habitat_baselines/slambased/data/ORBvoc.txt"
 )
-_C.BASELINE.ORBSLAM2.SLAM_SETTINGS_PATH = (
+_C.TRAINER.ORBSLAM2.SLAM_SETTINGS_PATH = (
     "habitat_baselines/slambased/data/mp3d3_small1k.yaml"
 )
-_C.BASELINE.ORBSLAM2.MAP_CELL_SIZE = 0.1
-_C.BASELINE.ORBSLAM2.MAP_SIZE = 40
-_C.BASELINE.ORBSLAM2.CAMERA_HEIGHT = get_config().SIMULATOR.DEPTH_SENSOR.POSITION[
+_C.TRAINER.ORBSLAM2.MAP_CELL_SIZE = 0.1
+_C.TRAINER.ORBSLAM2.MAP_SIZE = 40
+_C.TRAINER.ORBSLAM2.CAMERA_HEIGHT = get_task_config().SIMULATOR.DEPTH_SENSOR.POSITION[
     1
 ]
-_C.BASELINE.ORBSLAM2.BETA = 100
-_C.BASELINE.ORBSLAM2.H_OBSTACLE_MIN = 0.3 * _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT
-_C.BASELINE.ORBSLAM2.H_OBSTACLE_MAX = 1.0 * _C.BASELINE.ORBSLAM2.CAMERA_HEIGHT
-_C.BASELINE.ORBSLAM2.D_OBSTACLE_MIN = 0.1
-_C.BASELINE.ORBSLAM2.D_OBSTACLE_MAX = 4.0
-_C.BASELINE.ORBSLAM2.PREPROCESS_MAP = True
-_C.BASELINE.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
-    get_config().SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
+_C.TRAINER.ORBSLAM2.BETA = 100
+_C.TRAINER.ORBSLAM2.H_OBSTACLE_MIN = 0.3 * _C.TRAINER.ORBSLAM2.CAMERA_HEIGHT
+_C.TRAINER.ORBSLAM2.H_OBSTACLE_MAX = 1.0 * _C.TRAINER.ORBSLAM2.CAMERA_HEIGHT
+_C.TRAINER.ORBSLAM2.D_OBSTACLE_MIN = 0.1
+_C.TRAINER.ORBSLAM2.D_OBSTACLE_MAX = 4.0
+_C.TRAINER.ORBSLAM2.PREPROCESS_MAP = True
+_C.TRAINER.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
+    get_task_config().SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
 )
-_C.BASELINE.ORBSLAM2.ANGLE_TH = float(np.deg2rad(15))
-_C.BASELINE.ORBSLAM2.DIST_REACHED_TH = 0.15
-_C.BASELINE.ORBSLAM2.NEXT_WAYPOINT_TH = 0.5
-_C.BASELINE.ORBSLAM2.NUM_ACTIONS = 3
-_C.BASELINE.ORBSLAM2.DIST_TO_STOP = 0.05
-_C.BASELINE.ORBSLAM2.PLANNER_MAX_STEPS = 500
-_C.BASELINE.ORBSLAM2.DEPTH_DENORM = (
-    get_config().SIMULATOR.DEPTH_SENSOR.MAX_DEPTH
+_C.TRAINER.ORBSLAM2.ANGLE_TH = float(np.deg2rad(15))
+_C.TRAINER.ORBSLAM2.DIST_REACHED_TH = 0.15
+_C.TRAINER.ORBSLAM2.NEXT_WAYPOINT_TH = 0.5
+_C.TRAINER.ORBSLAM2.NUM_ACTIONS = 3
+_C.TRAINER.ORBSLAM2.DIST_TO_STOP = 0.05
+_C.TRAINER.ORBSLAM2.PLANNER_MAX_STEPS = 500
+_C.TRAINER.ORBSLAM2.DEPTH_DENORM = (
+    get_task_config().SIMULATOR.DEPTH_SENSOR.MAX_DEPTH
 )
 
 
@@ -87,6 +126,8 @@ def get_config(
         for config_path in config_paths:
             config.merge_from_file(config_path)
 
+    config.TASK_CONFIG = get_task_config(config.BASE_TASK_CONFIG_PATH)
+    config.CMD_TRAILING_OPTS = opts
     if opts:
         config.merge_from_list(opts)
 
diff --git a/habitat_baselines/config/pointnav/ppo.yaml b/habitat_baselines/config/pointnav/ppo.yaml
new file mode 100644
index 000000000..305a2eed1
--- /dev/null
+++ b/habitat_baselines/config/pointnav/ppo.yaml
@@ -0,0 +1,39 @@
+BASE_TASK_CONFIG_PATH: "configs/tasks/pointnav.yaml"
+TRAINER:
+  TRAINER_NAME: "ppo"
+  RL:
+    PPO:
+      # general options
+      tensorboard_dir: "tb"
+      num_processes: 1
+      log_interval: 10
+      pth_gpu_id: 0
+      sim_gpu_id: 0
+      checkpoint_interval: 50
+      checkpoint_folder: "data/checkpoints"
+      task_config: "configs/tasks/pointnav.yaml"
+      num_updates: 10000
+      # eval specific:
+      count_test_episodes: 2
+      video_option: "disk,tensorboard"
+      video_dir: "video_dir"
+      eval_ckpt_path_or_dir: "data/checkpoints"
+
+      # ppo params
+      clip_param: 0.1
+      ppo_epoch: 4
+      num_mini_batch: 1
+      value_loss_coef: 0.5
+      entropy_coef: 0.01
+      lr: 2.5e-4
+      eps: 1e-5
+      max_grad_norm: 0.5
+      num_steps: 128
+      hidden_size: 512
+      use_gae: True
+      gamma: 0.99
+      tau: 0.95
+      use_linear_clip_decay: True
+      use_linear_lr_decay: True
+      reward_window_size: 50
+      sensors: "RGB_SENSOR,DEPTH_SENSOR"
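This yaml is loaded through the reworked `get_config` above, which nests the task config under `TASK_CONFIG` and records trailing command-line options in `CMD_TRAILING_OPTS` before merging them. A sketch of the resulting layout, assuming `config_paths` and `opts` remain the first two parameters:

```python
from habitat_baselines.config.default import get_config

config = get_config(
    "habitat_baselines/config/pointnav/ppo.yaml",
    ["TRAINER.RL.PPO.num_updates", "100"],  # yacs-style key/value overrides
)
print(config.TRAINER.TRAINER_NAME)        # "ppo"
print(config.TRAINER.RL.PPO.num_updates)  # 100, merged from the opts list
print(config.TASK_CONFIG.SIMULATOR.TYPE)  # loaded via BASE_TASK_CONFIG_PATH
```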
diff --git a/habitat_baselines/config/pointnav/ppo_train_test.yaml b/habitat_baselines/config/pointnav/ppo_train_test.yaml
new file mode 100644
index 000000000..cb303ba72
--- /dev/null
+++ b/habitat_baselines/config/pointnav/ppo_train_test.yaml
@@ -0,0 +1,12 @@
+TRAINER:
+  TRAINER_NAME: "ppo"
+  RL:
+    PPO:
+      num_processes: 1
+      pth_gpu_id: 0
+      sim_gpu_id: 0
+      checkpoint_interval: 50
+      checkpoint_folder: "data/checkpoints"
+      task_config: "configs/tasks/pointnav.yaml"
+      num_updates: 10
+      num_mini_batch: 1
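These configs are consumed by the new `run.py` entry point; its flow is roughly the sketch below, based on `get_trainer` from `common/utils.py`. Whether `run.py` passes the full experiment config or only a sub-node to the trainer is not visible in this excerpt; the full config is assumed here:

```python
from habitat_baselines.common.utils import get_trainer
from habitat_baselines.config.default import get_config

config = get_config("habitat_baselines/config/pointnav/ppo_train_test.yaml")
# assumes the PPO trainer is registered under the name set in TRAINER_NAME
trainer = get_trainer(config.TRAINER.TRAINER_NAME, config)

run_type = "train"  # run.py selects train/eval via its --run-type flag
if run_type == "train":
    trainer.train()
else:
    trainer.eval()
```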
diff --git a/habitat_baselines/evaluate_ppo.py b/habitat_baselines/evaluate_ppo.py
deleted file mode 100644
index f467b1956..000000000
--- a/habitat_baselines/evaluate_ppo.py
+++ /dev/null
@@ -1,356 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) Facebook, Inc. and its affiliates.
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-import glob
-import os
-import time
-from typing import Optional
-
-import torch
-
-import habitat
-from config.default import get_config as cfg_baseline
-from habitat import logger
-from habitat.config.default import get_config
-from habitat.utils.visualizations.utils import (
-    images_to_video,
-    observations_to_image,
-)
-from rl.ppo import PPO, Policy
-from rl.ppo.utils import batch_obs
-from tensorboard_utils import get_tensorboard_writer
-from train_ppo import make_env_fn
-
-
-def poll_checkpoint_folder(
-    checkpoint_folder: str, previous_ckpt_ind: int
-) -> Optional[str]:
-    r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder
-    (sorted by time of last modification).
-
-    Args:
-        checkpoint_folder: directory to look for checkpoints.
-        previous_ckpt_ind: index of checkpoint last returned.
-
-    Returns:
-        return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found
-        else return None.
-    """
-    assert os.path.isdir(checkpoint_folder), "invalid checkpoint folder path"
-    models_paths = list(
-        filter(os.path.isfile, glob.glob(checkpoint_folder + "/*"))
-    )
-    models_paths.sort(key=os.path.getmtime)
-    ind = previous_ckpt_ind + 1
-    if ind < len(models_paths):
-        return models_paths[ind]
-    return None
-
-
-def generate_video(
-    args, images, episode_id, checkpoint_idx, spl, tb_writer, fps=10
-) -> None:
-    r"""Generate video according to specified information.
-
-    Args:
-        args: contains args.video_option and args.video_dir.
-        images: list of images to be converted to video.
-        episode_id: episode id for video naming.
-        checkpoint_idx: checkpoint index for video naming.
-        spl: SPL for this episode for video naming.
-        tb_writer: tensorboard writer object for uploading video
-        fps: fps for generated video
-
-    Returns:
-        None
-    """
-    if args.video_option and len(images) > 0:
-        video_name = f"episode{episode_id}_ckpt{checkpoint_idx}_spl{spl:.2f}"
-        if "disk" in args.video_option:
-            images_to_video(images, args.video_dir, video_name)
-        if "tensorboard" in args.video_option:
-            tb_writer.add_video_from_np_images(
-                f"episode{episode_id}", checkpoint_idx, images, fps=fps
-            )
-
-
-def eval_checkpoint(checkpoint_path, args, writer, cur_ckpt_idx=0):
-    env_configs = []
-    baseline_configs = []
-    device = torch.device("cuda", args.pth_gpu_id)
-
-    for _ in range(args.num_processes):
-        config_env = get_config(config_paths=args.task_config)
-        config_env.defrost()
-        config_env.DATASET.SPLIT = "val"
-
-        agent_sensors = args.sensors.strip().split(",")
-        for sensor in agent_sensors:
-            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
-        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
-        if args.video_option:
-            config_env.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
-            config_env.TASK.MEASUREMENTS.append("COLLISIONS")
-        config_env.freeze()
-        env_configs.append(config_env)
-
-        config_baseline = cfg_baseline()
-        baseline_configs.append(config_baseline)
-
-    assert len(baseline_configs) > 0, "empty list of datasets"
-
-    envs = habitat.VectorEnv(
-        make_env_fn=make_env_fn,
-        env_fn_args=tuple(
-            tuple(
-                zip(env_configs, baseline_configs, range(args.num_processes))
-            )
-        ),
-    )
-
-    ckpt = torch.load(checkpoint_path, map_location=device)
-
-    actor_critic = Policy(
-        observation_space=envs.observation_spaces[0],
-        action_space=envs.action_spaces[0],
-        hidden_size=512,
-        goal_sensor_uuid=env_configs[0].TASK.GOAL_SENSOR_UUID,
-    )
-    actor_critic.to(device)
-
-    ppo = PPO(
-        actor_critic=actor_critic,
-        clip_param=0.1,
-        ppo_epoch=4,
-        num_mini_batch=32,
-        value_loss_coef=0.5,
-        entropy_coef=0.01,
-        lr=2.5e-4,
-        eps=1e-5,
-        max_grad_norm=0.5,
-    )
-
-    ppo.load_state_dict(ckpt["state_dict"])
-
-    actor_critic = ppo.actor_critic
-
-    observations = envs.reset()
-    batch = batch_obs(observations)
-    for sensor in batch:
-        batch[sensor] = batch[sensor].to(device)
-
-    current_episode_reward = torch.zeros(envs.num_envs, 1, device=device)
-
-    test_recurrent_hidden_states = torch.zeros(
-        args.num_processes, args.hidden_size, device=device
-    )
-    not_done_masks = torch.zeros(args.num_processes, 1, device=device)
-    stats_episodes = dict()  # dict of dicts that stores stats per episode
-
-    rgb_frames = None
-    if args.video_option:
-        rgb_frames = [[]] * args.num_processes
-        os.makedirs(args.video_dir, exist_ok=True)
-
-    while len(stats_episodes) < args.count_test_episodes and envs.num_envs > 0:
-        current_episodes = envs.current_episodes()
-
-        with torch.no_grad():
-            _, actions, _, test_recurrent_hidden_states = actor_critic.act(
-                batch,
-                test_recurrent_hidden_states,
-                not_done_masks,
-                deterministic=False,
-            )
-
-        outputs = envs.step([a[0].item() for a in actions])
-
-        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
-        batch = batch_obs(observations)
-        for sensor in batch:
-            batch[sensor] = batch[sensor].to(device)
-
-        not_done_masks = torch.tensor(
-            [[0.0] if done else [1.0] for done in dones],
-            dtype=torch.float,
-            device=device,
-        )
-
-        rewards = torch.tensor(
-            rewards, dtype=torch.float, device=device
-        ).unsqueeze(1)
-        current_episode_reward += rewards
-        next_episodes = envs.current_episodes()
-        envs_to_pause = []
-        n_envs = envs.num_envs
-        for i in range(n_envs):
-            if (
-                next_episodes[i].scene_id,
-                next_episodes[i].episode_id,
-            ) in stats_episodes:
-                envs_to_pause.append(i)
-
-            # episode ended
-            if not_done_masks[i].item() == 0:
-                episode_stats = dict()
-                episode_stats["spl"] = infos[i]["spl"]
-                episode_stats["success"] = int(infos[i]["spl"] > 0)
-                episode_stats["reward"] = current_episode_reward[i].item()
-                current_episode_reward[i] = 0
-                # use scene_id + episode_id as unique id for storing stats
-                stats_episodes[
-                    (
-                        current_episodes[i].scene_id,
-                        current_episodes[i].episode_id,
-                    )
-                ] = episode_stats
-                if args.video_option:
-                    generate_video(
-                        args,
-                        rgb_frames[i],
-                        current_episodes[i].episode_id,
-                        cur_ckpt_idx,
-                        infos[i]["spl"],
-                        writer,
-                    )
-                    rgb_frames[i] = []
-
-            # episode continues
-            elif args.video_option:
-                frame = observations_to_image(observations[i], infos[i])
-                rgb_frames[i].append(frame)
-
-        # pausing envs with no new episode
-        if len(envs_to_pause) > 0:
-            state_index = list(range(envs.num_envs))
-            for idx in reversed(envs_to_pause):
-                state_index.pop(idx)
-                envs.pause_at(idx)
-
-            # indexing along the batch dimensions
-            test_recurrent_hidden_states = test_recurrent_hidden_states[
-                state_index
-            ]
-            not_done_masks = not_done_masks[state_index]
-            current_episode_reward = current_episode_reward[state_index]
-
-            for k, v in batch.items():
-                batch[k] = v[state_index]
-
-            if args.video_option:
-                rgb_frames = [rgb_frames[i] for i in state_index]
-
-    aggregated_stats = dict()
-    for stat_key in next(iter(stats_episodes.values())).keys():
-        aggregated_stats[stat_key] = sum(
-            [v[stat_key] for v in stats_episodes.values()]
-        )
-    num_episodes = len(stats_episodes)
-
-    episode_reward_mean = aggregated_stats["reward"] / num_episodes
-    episode_spl_mean = aggregated_stats["spl"] / num_episodes
-    episode_success_mean = aggregated_stats["success"] / num_episodes
-
-    logger.info("Average episode reward: {:.6f}".format(episode_reward_mean))
-    logger.info("Average episode success: {:.6f}".format(episode_success_mean))
-    logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))
-
-    writer.add_scalars(
-        "eval_reward", {"average reward": episode_reward_mean}, cur_ckpt_idx
-    )
-    writer.add_scalars(
-        "eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx
-    )
-    writer.add_scalars(
-        "eval_success", {"average success": episode_success_mean}, cur_ckpt_idx
-    )
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-path", type=str)
-    parser.add_argument("--tracking-model-dir", type=str)
-    parser.add_argument("--sim-gpu-id", type=int, required=True)
-    parser.add_argument("--pth-gpu-id", type=int, required=True)
-    parser.add_argument("--num-processes", type=int, required=True)
-    parser.add_argument("--hidden-size", type=int, default=512)
-    parser.add_argument("--count-test-episodes", type=int, default=100)
-    parser.add_argument(
-        "--sensors",
-        type=str,
-        default="RGB_SENSOR,DEPTH_SENSOR",
-        help="comma separated string containing different"
-        "sensors to use, currently 'RGB_SENSOR' and"
-        "'DEPTH_SENSOR' are supported",
-    )
-    parser.add_argument(
-        "--task-config",
-        type=str,
-        default="configs/tasks/pointnav.yaml",
-        help="path to config yaml containing information about task",
-    )
-    parser.add_argument(
-        "--video-option",
-        type=str,
-        default="",
-        choices=["tensorboard", "disk"],
-        nargs="*",
-        help="Options for video output, leave empty for no video. "
-        "Videos can be saved to disk, uploaded to tensorboard, or both.",
-    )
-    parser.add_argument(
-        "--video-dir", type=str, help="directory for storing videos"
-    )
-    parser.add_argument(
-        "--tensorboard-dir",
-        type=str,
-        help="directory for storing tensorboard statistics",
-    )
-
-    args = parser.parse_args()
-
-    assert (args.model_path is not None) != (
-        args.tracking_model_dir is not None
-    ), "Must specify a single model or a directory of models, but not both"
-    if "tensorboard" in args.video_option:
-        assert (
-            args.tensorboard_dir is not None
-        ), "Must specify a tensorboard directory for video display"
-    if "disk" in args.video_option:
-        assert (
-            args.video_dir is not None
-        ), "Must specify a directory for storing videos on disk"
-
-    with get_tensorboard_writer(
-        args.tensorboard_dir, purge_step=0, flush_secs=30
-    ) as writer:
-        if args.model_path is not None:
-            # evaluate singe checkpoint
-            eval_checkpoint(args.model_path, args, writer)
-        else:
-            # evaluate multiple checkpoints in order
-            prev_ckpt_ind = -1
-            while True:
-                current_ckpt = None
-                while current_ckpt is None:
-                    current_ckpt = poll_checkpoint_folder(
-                        args.tracking_model_dir, prev_ckpt_ind
-                    )
-                    time.sleep(2)  # sleep for 2 seconds before polling again
-                logger.warning(
-                    "=============current_ckpt: {}=============".format(
-                        current_ckpt
-                    )
-                )
-                prev_ckpt_ind += 1
-                eval_checkpoint(
-                    current_ckpt, args, writer, cur_ckpt_idx=prev_ckpt_ind
-                )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/habitat_baselines/rl/ppo/__init__.py b/habitat_baselines/rl/ppo/__init__.py
index 128cc360c..9c00af215 100644
--- a/habitat_baselines/rl/ppo/__init__.py
+++ b/habitat_baselines/rl/ppo/__init__.py
@@ -6,6 +6,5 @@
 
 from habitat_baselines.rl.ppo.policy import Policy
 from habitat_baselines.rl.ppo.ppo import PPO
-from habitat_baselines.rl.ppo.utils import RolloutStorage
 
-__all__ = ["PPO", "Policy", "RolloutStorage"]
+__all__ = ["PPO", "Policy"]
diff --git a/habitat_baselines/rl/ppo/policy.py b/habitat_baselines/rl/ppo/policy.py
index eed1d8b39..0be21e5af 100644
--- a/habitat_baselines/rl/ppo/policy.py
+++ b/habitat_baselines/rl/ppo/policy.py
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from habitat_baselines.rl.ppo.utils import CategoricalNet, Flatten
+from habitat_baselines.common.utils import CategoricalNet, Flatten
 
 
 class Policy(nn.Module):
diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py
new file mode 100644
index 000000000..31d4b4782
--- /dev/null
+++ b/habitat_baselines/rl/ppo/ppo_trainer.py
@@ -0,0 +1,575 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+from collections import deque
+from typing import Dict, List
+
+import numpy as np
+import torch
+
+from habitat import Config, logger
+from habitat.utils.visualizations.utils import observations_to_image
+from habitat_baselines.common.base_trainer import BaseRLTrainer
+from habitat_baselines.common.baseline_registry import baseline_registry
+from habitat_baselines.common.env_utils import construct_envs
+from habitat_baselines.common.environments import NavRLEnv
+from habitat_baselines.common.rollout_storage import RolloutStorage
+from habitat_baselines.common.tensorboard_utils import (
+    TensorboardWriter,
+    get_tensorboard_writer,
+)
+from habitat_baselines.common.utils import (
+    batch_obs,
+    generate_video,
+    poll_checkpoint_folder,
+    update_linear_schedule,
+)
+from habitat_baselines.rl.ppo import PPO, Policy
+
+
+@baseline_registry.register_trainer(name="ppo")
+class PPOTrainer(BaseRLTrainer):
+    r"""
+    Trainer class for PPO algorithm
+    Paper: https://arxiv.org/abs/1707.06347
+    """
+    supported_tasks = ["Nav-v0"]
+
+    def __init__(self, config=None):
+        super().__init__(config)
+        self.actor_critic = None
+        self.agent = None
+        self.envs = None
+        self.device = None
+        self.video_option = []
+        if config is not None:
+            logger.info(f"config: {config}")
+
+    def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
+        r"""
+        Sets up actor critic and agent for PPO
+        Args:
+            ppo_cfg: config node with relevant params
+
+        Returns:
+            None
+        """
+        logger.add_filehandler(ppo_cfg.log_file)
+
+        self.actor_critic = Policy(
+            observation_space=self.envs.observation_spaces[0],
+            action_space=self.envs.action_spaces[0],
+            hidden_size=ppo_cfg.hidden_size,
+            goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
+        )
+        self.actor_critic.to(self.device)
+
+        self.agent = PPO(
+            actor_critic=self.actor_critic,
+            clip_param=ppo_cfg.clip_param,
+            ppo_epoch=ppo_cfg.ppo_epoch,
+            num_mini_batch=ppo_cfg.num_mini_batch,
+            value_loss_coef=ppo_cfg.value_loss_coef,
+            entropy_coef=ppo_cfg.entropy_coef,
+            lr=ppo_cfg.lr,
+            eps=ppo_cfg.eps,
+            max_grad_norm=ppo_cfg.max_grad_norm,
+        )
+
+    def save_checkpoint(self, file_name: str) -> None:
+        r"""
+        Save checkpoint with specified name
+        Args:
+            file_name: file name for checkpoint
+
+        Returns:
+            None
+        """
+        checkpoint = {
+            "state_dict": self.agent.state_dict(),
+            "config": self.config,
+        }
+        torch.save(
+            checkpoint,
+            os.path.join(
+                self.config.TRAINER.RL.PPO.checkpoint_folder, file_name
+            ),
+        )
+
+    def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict:
+        r"""
+        Load checkpoint of specified path as a dict
+        Args:
+            checkpoint_path: path of target checkpoint
+            *args: additional positional args
+            **kwargs: additional keyword args
+
+        Returns:
+            dict containing checkpoint info
+        """
+        return torch.load(checkpoint_path, *args, **kwargs)
+
+    def train(self) -> None:
+        r"""
+        Main method for training PPO
+        Returns:
+            None
+        """
+        assert (
+            self.config is not None
+        ), "trainer is not properly initialized, need to specify config file"
+
+        self.envs = construct_envs(self.config, NavRLEnv)
+
+        ppo_cfg = self.config.TRAINER.RL.PPO
+        self.device = torch.device("cuda", ppo_cfg.pth_gpu_id)
+        if not os.path.isdir(ppo_cfg.checkpoint_folder):
+            os.makedirs(ppo_cfg.checkpoint_folder)
+        self._setup_actor_critic_agent(ppo_cfg)
+        logger.info(
+            "agent number of parameters: {}".format(
+                sum(param.numel() for param in self.agent.parameters())
+            )
+        )
+
+        observations = self.envs.reset()
+        batch = batch_obs(observations)
+
+        rollouts = RolloutStorage(
+            ppo_cfg.num_steps,
+            self.envs.num_envs,
+            self.envs.observation_spaces[0],
+            self.envs.action_spaces[0],
+            ppo_cfg.hidden_size,
+        )
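+        # seed step 0 of the rollout buffer with the initial observations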
+        for sensor in rollouts.observations:
+            rollouts.observations[sensor][0].copy_(batch[sensor])
+        rollouts.to(self.device)
+
+        episode_rewards = torch.zeros(self.envs.num_envs, 1)
+        episode_counts = torch.zeros(self.envs.num_envs, 1)
+        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
+        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
+        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
+
+        t_start = time.time()
+        env_time = 0
+        pth_time = 0
+        count_steps = 0
+        count_checkpoints = 0
+
+        with get_tensorboard_writer(
+            log_dir=ppo_cfg.tensorboard_dir,
+            purge_step=count_steps,
+            flush_secs=30,
+        ) as writer:
+            for update in range(ppo_cfg.num_updates):
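+                # optionally decay LR and clip param linearly over training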
+                if ppo_cfg.use_linear_lr_decay:
+                    update_linear_schedule(
+                        self.agent.optimizer,
+                        update,
+                        ppo_cfg.num_updates,
+                        ppo_cfg.lr,
+                    )
+
+                if ppo_cfg.use_linear_clip_decay:
+                    self.agent.clip_param = ppo_cfg.clip_param * (
+                        1 - update / ppo_cfg.num_updates
+                    )
+
+                for step in range(ppo_cfg.num_steps):
+                    t_sample_action = time.time()
+                    # sample actions
+                    with torch.no_grad():
+                        step_observation = {
+                            k: v[step]
+                            for k, v in rollouts.observations.items()
+                        }
+
+                        (
+                            values,
+                            actions,
+                            actions_log_probs,
+                            recurrent_hidden_states,
+                        ) = self.actor_critic.act(
+                            step_observation,
+                            rollouts.recurrent_hidden_states[step],
+                            rollouts.masks[step],
+                        )
+                    pth_time += time.time() - t_sample_action
+
+                    t_step_env = time.time()
+
+                    outputs = self.envs.step([a[0].item() for a in actions])
+                    observations, rewards, dones, infos = [
+                        list(x) for x in zip(*outputs)
+                    ]
+
+                    env_time += time.time() - t_step_env
+
+                    t_update_stats = time.time()
+                    batch = batch_obs(observations)
+                    rewards = torch.tensor(rewards, dtype=torch.float)
+                    rewards = rewards.unsqueeze(1)
+
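+                    # masks: 0.0 where an episode just ended, 1.0 otherwise;
+                    # (1 - masks) below credits and resets finished episodes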
+                    masks = torch.tensor(
+                        [[0.0] if done else [1.0] for done in dones],
+                        dtype=torch.float,
+                    )
+
+                    current_episode_reward += rewards
+                    episode_rewards += (1 - masks) * current_episode_reward
+                    episode_counts += 1 - masks
+                    current_episode_reward *= masks
+
+                    rollouts.insert(
+                        batch,
+                        recurrent_hidden_states,
+                        actions,
+                        actions_log_probs,
+                        values,
+                        rewards,
+                        masks,
+                    )
+
+                    count_steps += self.envs.num_envs
+                    pth_time += time.time() - t_update_stats
+
+                window_episode_reward.append(episode_rewards.clone())
+                window_episode_counts.append(episode_counts.clone())
+
+                t_update_model = time.time()
+                with torch.no_grad():
+                    last_observation = {
+                        k: v[-1] for k, v in rollouts.observations.items()
+                    }
+                    next_value = self.actor_critic.get_value(
+                        last_observation,
+                        rollouts.recurrent_hidden_states[-1],
+                        rollouts.masks[-1],
+                    ).detach()
+
+                rollouts.compute_returns(
+                    next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau
+                )
+
+                value_loss, action_loss, dist_entropy = self.agent.update(
+                    rollouts
+                )
+
+                rollouts.after_update()
+                pth_time += time.time() - t_update_model
+
+                losses = [value_loss, action_loss]
+                stats = zip(
+                    ["count", "reward"],
+                    [window_episode_counts, window_episode_reward],
+                )
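+                # window buffers hold running totals; last - first gives the
+                # reward / episode count accumulated over the current window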
+                deltas = {
+                    k: (
+                        (v[-1] - v[0]).sum().item()
+                        if len(v) > 1
+                        else v[0].sum().item()
+                    )
+                    for k, v in stats
+                }
+                deltas["count"] = max(deltas["count"], 1.0)
+
+                writer.add_scalar(
+                    "reward", deltas["reward"] / deltas["count"], count_steps
+                )
+
+                writer.add_scalars(
+                    "losses",
+                    {k: l for l, k in zip(losses, ["value", "policy"])},
+                    count_steps,
+                )
+
+                # log stats
+                if update > 0 and update % ppo_cfg.log_interval == 0:
+                    logger.info(
+                        "update: {}\tfps: {:.3f}\t".format(
+                            update, count_steps / (time.time() - t_start)
+                        )
+                    )
+
+                    logger.info(
+                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
+                        "frames: {}".format(
+                            update, env_time, pth_time, count_steps
+                        )
+                    )
+
+                    window_rewards = (
+                        window_episode_reward[-1] - window_episode_reward[0]
+                    ).sum()
+                    window_counts = (
+                        window_episode_counts[-1] - window_episode_counts[0]
+                    ).sum()
+
+                    if window_counts > 0:
+                        logger.info(
+                            "Average window size {} reward: {:.3f}".format(
+                                len(window_episode_reward),
+                                (window_rewards / window_counts).item(),
+                            )
+                        )
+                    else:
+                        logger.info("No episodes finished in current window")
+
+                # checkpoint model
+                if update % ppo_cfg.checkpoint_interval == 0:
+                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
+                    count_checkpoints += 1
+
+    def eval(self) -> None:
+        r"""
+    Main method for evaluating PPO
+        Returns:
+            None
+        """
+        ppo_cfg = self.config.TRAINER.RL.PPO
+        self.device = torch.device("cuda", ppo_cfg.pth_gpu_id)
+        # drop empty entries so an empty video_option disables video output
+        self.video_option = [
+            x for x in ppo_cfg.video_option.strip().split(",") if x
+        ]
+
+        if "tensorboard" in self.video_option:
+            assert (
+                ppo_cfg.tensorboard_dir is not None
+            ), "Must specify a tensorboard directory for video display"
+        if "disk" in self.video_option:
+            assert (
+                ppo_cfg.video_dir is not None
+            ), "Must specify a directory for storing videos on disk"
+
+        with get_tensorboard_writer(
+            ppo_cfg.tensorboard_dir, purge_step=0, flush_secs=30
+        ) as writer:
+            if os.path.isfile(ppo_cfg.eval_ckpt_path_or_dir):
+                # evaluate single checkpoint
+                self._eval_checkpoint(ppo_cfg.eval_ckpt_path_or_dir, writer)
+            else:
+                # evaluate multiple checkpoints in order
+                prev_ckpt_ind = -1
+                while True:
+                    current_ckpt = None
+                    while current_ckpt is None:
+                        current_ckpt = poll_checkpoint_folder(
+                            ppo_cfg.eval_ckpt_path_or_dir, prev_ckpt_ind
+                        )
+                        time.sleep(2)  # sleep for 2 secs before polling again
+                    logger.warning(
+                        "=============current_ckpt: {}=============".format(
+                            current_ckpt
+                        )
+                    )
+                    prev_ckpt_ind += 1
+                    self._eval_checkpoint(
+                        checkpoint_path=current_ckpt,
+                        writer=writer,
+                        cur_ckpt_idx=prev_ckpt_ind,
+                    )
+
+    def _eval_checkpoint(
+        self,
+        checkpoint_path: str,
+        writer: TensorboardWriter,
+        cur_ckpt_idx: int = 0,
+    ) -> None:
+        r"""
+        Evaluates a single checkpoint
+        Args:
+            checkpoint_path: path of checkpoint
+            writer: tensorboard writer object for logging to tensorboard
+            cur_ckpt_idx: index of the current checkpoint, used for logging
+
+        Returns:
+            None
+        """
+        ckpt_dict = self.load_checkpoint(
+            checkpoint_path, map_location=self.device
+        )
+
+        ckpt_config = ckpt_dict["config"]
+        config = self.config.clone()
+        ckpt_cmd_opts = ckpt_config.CMD_TRAILING_OPTS
+        eval_cmd_opts = config.CMD_TRAILING_OPTS
+
+        # config merge priority: eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
+        # merging the checkpoint config first keeps old checkpoints compatible
+        config.merge_from_other_cfg(ckpt_config)
+        config.merge_from_other_cfg(self.config)
+        config.merge_from_list(ckpt_cmd_opts)
+        config.merge_from_list(eval_cmd_opts)
+
+        ppo_cfg = config.TRAINER.RL.PPO
+        config.TASK_CONFIG.defrost()
+        config.TASK_CONFIG.DATASET.SPLIT = "val"
+        agent_sensors = ppo_cfg.sensors.strip().split(",")
+        config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = agent_sensors
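+        # the extra measurements below are only needed to render eval videos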
+        if self.video_option:
+            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
+            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
+        config.freeze()
+
+        logger.info(f"env config: {config}")
+        self.envs = construct_envs(config, NavRLEnv)
+        self._setup_actor_critic_agent(ppo_cfg)
+
+        self.agent.load_state_dict(ckpt_dict["state_dict"])
+        self.actor_critic = self.agent.actor_critic
+
+        observations = self.envs.reset()
+        batch = batch_obs(observations)
+        for sensor in batch:
+            batch[sensor] = batch[sensor].to(self.device)
+
+        current_episode_reward = torch.zeros(
+            self.envs.num_envs, 1, device=self.device
+        )
+
+        test_recurrent_hidden_states = torch.zeros(
+            ppo_cfg.num_processes, ppo_cfg.hidden_size, device=self.device
+        )
+        not_done_masks = torch.zeros(
+            ppo_cfg.num_processes, 1, device=self.device
+        )
+        stats_episodes = dict()  # dict of dicts that stores stats per episode
+
+        # one frame buffer per process; [[]] * n would alias a single list
+        rgb_frames = [
+            [] for _ in range(ppo_cfg.num_processes)
+        ]  # type: List[List[np.ndarray]]
+        if self.video_option:
+            os.makedirs(ppo_cfg.video_dir, exist_ok=True)
+
+        while (
+            len(stats_episodes) < ppo_cfg.count_test_episodes
+            and self.envs.num_envs > 0
+        ):
+            current_episodes = self.envs.current_episodes()
+
+            with torch.no_grad():
+                _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
+                    batch,
+                    test_recurrent_hidden_states,
+                    not_done_masks,
+                    deterministic=False,
+                )
+
+            outputs = self.envs.step([a[0].item() for a in actions])
+
+            observations, rewards, dones, infos = [
+                list(x) for x in zip(*outputs)
+            ]
+            batch = batch_obs(observations)
+            for sensor in batch:
+                batch[sensor] = batch[sensor].to(self.device)
+
+            not_done_masks = torch.tensor(
+                [[0.0] if done else [1.0] for done in dones],
+                dtype=torch.float,
+                device=self.device,
+            )
+
+            rewards = torch.tensor(
+                rewards, dtype=torch.float, device=self.device
+            ).unsqueeze(1)
+            current_episode_reward += rewards
+            next_episodes = self.envs.current_episodes()
+            envs_to_pause = []
+            n_envs = self.envs.num_envs
+            for i in range(n_envs):
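+                # pause envs that looped back to an already-evaluated episode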
+                if (
+                    next_episodes[i].scene_id,
+                    next_episodes[i].episode_id,
+                ) in stats_episodes:
+                    envs_to_pause.append(i)
+
+                # episode ended
+                if not_done_masks[i].item() == 0:
+                    episode_stats = dict()
+                    episode_stats["spl"] = infos[i]["spl"]
+                    episode_stats["success"] = int(infos[i]["spl"] > 0)
+                    episode_stats["reward"] = current_episode_reward[i].item()
+                    current_episode_reward[i] = 0
+                    # use scene_id + episode_id as unique id for storing stats
+                    stats_episodes[
+                        (
+                            current_episodes[i].scene_id,
+                            current_episodes[i].episode_id,
+                        )
+                    ] = episode_stats
+                    if self.video_option:
+                        generate_video(
+                            ppo_cfg,
+                            rgb_frames[i],
+                            current_episodes[i].episode_id,
+                            cur_ckpt_idx,
+                            infos[i]["spl"],
+                            writer,
+                        )
+                        rgb_frames[i] = []
+
+                # episode continues
+                elif self.video_option:
+                    frame = observations_to_image(observations[i], infos[i])
+                    rgb_frames[i].append(frame)
+
+            # pause envs with no new episode
+            if len(envs_to_pause) > 0:
+                state_index = list(range(self.envs.num_envs))
+                for idx in reversed(envs_to_pause):
+                    state_index.pop(idx)
+                    self.envs.pause_at(idx)
+
+                # indexing along the batch dimensions
+                test_recurrent_hidden_states = test_recurrent_hidden_states[
+                    state_index
+                ]
+                not_done_masks = not_done_masks[state_index]
+                current_episode_reward = current_episode_reward[state_index]
+
+                for k, v in batch.items():
+                    batch[k] = v[state_index]
+
+                if self.video_option:
+                    rgb_frames = [rgb_frames[i] for i in state_index]
+
+        aggregated_stats = dict()
+        for stat_key in next(iter(stats_episodes.values())).keys():
+            aggregated_stats[stat_key] = sum(
+                [v[stat_key] for v in stats_episodes.values()]
+            )
+        num_episodes = len(stats_episodes)
+
+        episode_reward_mean = aggregated_stats["reward"] / num_episodes
+        episode_spl_mean = aggregated_stats["spl"] / num_episodes
+        episode_success_mean = aggregated_stats["success"] / num_episodes
+
+        logger.info(
+            "Average episode reward: {:.6f}".format(episode_reward_mean)
+        )
+        logger.info(
+            "Average episode success: {:.6f}".format(episode_success_mean)
+        )
+        logger.info("Average episode SPL: {:.6f}".format(episode_spl_mean))
+
+        writer.add_scalars(
+            "eval_reward",
+            {"average reward": episode_reward_mean},
+            cur_ckpt_idx,
+        )
+        writer.add_scalars(
+            "eval_SPL", {"average SPL": episode_spl_mean}, cur_ckpt_idx
+        )
+        writer.add_scalars(
+            "eval_success",
+            {"average success": episode_success_mean},
+            cur_ckpt_idx,
+        )
diff --git a/habitat_baselines/run.py b/habitat_baselines/run.py
new file mode 100644
index 000000000..8d9ce029b
--- /dev/null
+++ b/habitat_baselines/run.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import random
+
+import numpy as np
+
+from habitat_baselines.common.utils import get_trainer
+from habitat_baselines.config.default import get_config
+
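+# Example invocation (config path from this PR's new pointnav config):
+#     python -u habitat_baselines/run.py \
+#         --exp-config habitat_baselines/config/pointnav/ppo.yaml \
+#         --run-type train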
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--run-type",
+        choices=["train", "eval"],
+        required=True,
+        help="run type of the experiment (train or eval)",
+    )
+    parser.add_argument(
+        "--exp-config",
+        type=str,
+        required=True,
+        help="path to config yaml containing info about experiment",
+    )
+    parser.add_argument(
+        "opts",
+        default=None,
+        nargs=argparse.REMAINDER,
+        help="Modify config options from command line",
+    )
+    args = parser.parse_args()
+    config = get_config(args.exp_config, args.opts)
+
+    random.seed(config.TASK_CONFIG.SEED)
+    np.random.seed(config.TASK_CONFIG.SEED)
+
+    trainer = get_trainer(config.TRAINER.TRAINER_NAME, config)
+    if args.run_type == "train":
+        trainer.train()
+    elif args.run_type == "eval":
+        trainer.eval()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/habitat_baselines/train_ppo.py b/habitat_baselines/train_ppo.py
deleted file mode 100644
index 733d78b66..000000000
--- a/habitat_baselines/train_ppo.py
+++ /dev/null
@@ -1,397 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) Facebook, Inc. and its affiliates.
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import random
-from collections import deque
-from time import time
-
-import numpy as np
-import torch
-
-import habitat
-from config.default import get_config as cfg_baseline
-from habitat import SimulatorActions, logger
-from habitat.config.default import get_config as cfg_env
-from habitat.datasets.registration import make_dataset
-from rl.ppo import PPO, Policy, RolloutStorage
-from rl.ppo.utils import batch_obs, ppo_args, update_linear_schedule
-from tensorboard_utils import get_tensorboard_writer
-
-
-class NavRLEnv(habitat.RLEnv):
-    def __init__(self, config_env, config_baseline, dataset):
-        self._config_env = config_env.TASK
-        self._config_baseline = config_baseline
-        self._previous_target_distance = None
-        self._previous_action = None
-        self._episode_distance_covered = None
-        super().__init__(config_env, dataset)
-
-    def reset(self):
-        self._previous_action = None
-
-        observations = super().reset()
-
-        self._previous_target_distance = self.habitat_env.current_episode.info[
-            "geodesic_distance"
-        ]
-        return observations
-
-    def step(self, action):
-        self._previous_action = action
-        return super().step(action)
-
-    def get_reward_range(self):
-        return (
-            self._config_baseline.BASELINE.RL.SLACK_REWARD - 1.0,
-            self._config_baseline.BASELINE.RL.SUCCESS_REWARD + 1.0,
-        )
-
-    def get_reward(self, observations):
-        reward = self._config_baseline.BASELINE.RL.SLACK_REWARD
-
-        current_target_distance = self._distance_target()
-        reward += self._previous_target_distance - current_target_distance
-        self._previous_target_distance = current_target_distance
-
-        if self._episode_success():
-            reward += self._config_baseline.BASELINE.RL.SUCCESS_REWARD
-
-        return reward
-
-    def _distance_target(self):
-        current_position = self._env.sim.get_agent_state().position.tolist()
-        target_position = self._env.current_episode.goals[0].position
-        distance = self._env.sim.geodesic_distance(
-            current_position, target_position
-        )
-        return distance
-
-    def _episode_success(self):
-        if (
-            self._previous_action == SimulatorActions.STOP
-            and self._distance_target() < self._config_env.SUCCESS_DISTANCE
-        ):
-            return True
-        return False
-
-    def get_done(self, observations):
-        done = False
-        if self._env.episode_over or self._episode_success():
-            done = True
-        return done
-
-    def get_info(self, observations):
-        return self.habitat_env.get_metrics()
-
-
-def make_env_fn(config_env, config_baseline, rank):
-    dataset = make_dataset(config_env.DATASET.TYPE, config=config_env.DATASET)
-    config_env.defrost()
-    config_env.SIMULATOR.SCENE = dataset.episodes[0].scene_id
-    config_env.freeze()
-    env = NavRLEnv(
-        config_env=config_env, config_baseline=config_baseline, dataset=dataset
-    )
-    env.seed(rank)
-    return env
-
-
-def construct_envs(args):
-    env_configs = []
-    baseline_configs = []
-
-    basic_config = cfg_env(config_paths=args.task_config, opts=args.opts)
-    dataset = make_dataset(basic_config.DATASET.TYPE)
-    scenes = dataset.get_scenes_to_load(basic_config.DATASET)
-
-    if len(scenes) > 0:
-        random.shuffle(scenes)
-
-        assert len(scenes) >= args.num_processes, (
-            "reduce the number of processes as there "
-            "aren't enough number of scenes"
-        )
-        scene_split_size = int(np.floor(len(scenes) / args.num_processes))
-
-    scene_splits = [[] for _ in range(args.num_processes)]
-    for j, s in enumerate(scenes):
-        scene_splits[j % len(scene_splits)].append(s)
-
-    assert sum(map(len, scene_splits)) == len(scenes)
-
-    for i in range(args.num_processes):
-        config_env = cfg_env(config_paths=args.task_config, opts=args.opts)
-        config_env.defrost()
-
-        if len(scenes) > 0:
-            config_env.DATASET.CONTENT_SCENES = scene_splits[i]
-
-        config_env.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = args.sim_gpu_id
-
-        agent_sensors = args.sensors.strip().split(",")
-        for sensor in agent_sensors:
-            assert sensor in ["RGB_SENSOR", "DEPTH_SENSOR"]
-        config_env.SIMULATOR.AGENT_0.SENSORS = agent_sensors
-        config_env.freeze()
-        env_configs.append(config_env)
-
-        config_baseline = cfg_baseline()
-        baseline_configs.append(config_baseline)
-
-        logger.info("config_env: {}".format(config_env))
-
-    envs = habitat.VectorEnv(
-        make_env_fn=make_env_fn,
-        env_fn_args=tuple(
-            tuple(
-                zip(env_configs, baseline_configs, range(args.num_processes))
-            )
-        ),
-    )
-
-    return envs
-
-
-def run_training():
-    parser = ppo_args()
-    args = parser.parse_args()
-
-    random.seed(args.seed)
-
-    device = torch.device("cuda", args.pth_gpu_id)
-
-    logger.add_filehandler(args.log_file)
-
-    if not os.path.isdir(args.checkpoint_folder):
-        os.makedirs(args.checkpoint_folder)
-
-    for p in sorted(list(vars(args))):
-        logger.info("{}: {}".format(p, getattr(args, p)))
-
-    envs = construct_envs(args)
-    task_cfg = cfg_env(config_paths=args.task_config)
-    actor_critic = Policy(
-        observation_space=envs.observation_spaces[0],
-        action_space=envs.action_spaces[0],
-        hidden_size=args.hidden_size,
-        goal_sensor_uuid=task_cfg.TASK.GOAL_SENSOR_UUID,
-    )
-    actor_critic.to(device)
-
-    agent = PPO(
-        actor_critic,
-        args.clip_param,
-        args.ppo_epoch,
-        args.num_mini_batch,
-        args.value_loss_coef,
-        args.entropy_coef,
-        lr=args.lr,
-        eps=args.eps,
-        max_grad_norm=args.max_grad_norm,
-    )
-
-    logger.info(
-        "agent number of parameters: {}".format(
-            sum(param.numel() for param in agent.parameters())
-        )
-    )
-
-    observations = envs.reset()
-
-    batch = batch_obs(observations)
-
-    rollouts = RolloutStorage(
-        args.num_steps,
-        envs.num_envs,
-        envs.observation_spaces[0],
-        envs.action_spaces[0],
-        args.hidden_size,
-    )
-    for sensor in rollouts.observations:
-        rollouts.observations[sensor][0].copy_(batch[sensor])
-    rollouts.to(device)
-
-    episode_rewards = torch.zeros(envs.num_envs, 1)
-    episode_counts = torch.zeros(envs.num_envs, 1)
-    current_episode_reward = torch.zeros(envs.num_envs, 1)
-    window_episode_reward = deque(maxlen=args.reward_window_size)
-    window_episode_counts = deque(maxlen=args.reward_window_size)
-
-    t_start = time()
-    env_time = 0
-    pth_time = 0
-    count_steps = 0
-    count_checkpoints = 0
-
-    with (
-        get_tensorboard_writer(
-            log_dir=args.tensorboard_dir, purge_step=count_steps, flush_secs=30
-        )
-    ) as writer:
-        for update in range(args.num_updates):
-            if args.use_linear_lr_decay:
-                update_linear_schedule(
-                    agent.optimizer, update, args.num_updates, args.lr
-                )
-
-            agent.clip_param = args.clip_param * (
-                1 - update / args.num_updates
-            )
-
-            for step in range(args.num_steps):
-                t_sample_action = time()
-                # sample actions
-                with torch.no_grad():
-                    step_observation = {
-                        k: v[step] for k, v in rollouts.observations.items()
-                    }
-
-                    (
-                        values,
-                        actions,
-                        actions_log_probs,
-                        recurrent_hidden_states,
-                    ) = actor_critic.act(
-                        step_observation,
-                        rollouts.recurrent_hidden_states[step],
-                        rollouts.masks[step],
-                    )
-                pth_time += time() - t_sample_action
-
-                t_step_env = time()
-
-                outputs = envs.step([a[0].item() for a in actions])
-                observations, rewards, dones, infos = [
-                    list(x) for x in zip(*outputs)
-                ]
-
-                env_time += time() - t_step_env
-
-                t_update_stats = time()
-                batch = batch_obs(observations)
-                rewards = torch.tensor(rewards, dtype=torch.float)
-                rewards = rewards.unsqueeze(1)
-
-                masks = torch.tensor(
-                    [[0.0] if done else [1.0] for done in dones],
-                    dtype=torch.float,
-                )
-
-                current_episode_reward += rewards
-                episode_rewards += (1 - masks) * current_episode_reward
-                episode_counts += 1 - masks
-                current_episode_reward *= masks
-
-                rollouts.insert(
-                    batch,
-                    recurrent_hidden_states,
-                    actions,
-                    actions_log_probs,
-                    values,
-                    rewards,
-                    masks,
-                )
-
-                count_steps += envs.num_envs
-                pth_time += time() - t_update_stats
-
-            window_episode_reward.append(episode_rewards.clone())
-            window_episode_counts.append(episode_counts.clone())
-
-            t_update_model = time()
-            with torch.no_grad():
-                last_observation = {
-                    k: v[-1] for k, v in rollouts.observations.items()
-                }
-                next_value = actor_critic.get_value(
-                    last_observation,
-                    rollouts.recurrent_hidden_states[-1],
-                    rollouts.masks[-1],
-                ).detach()
-
-            rollouts.compute_returns(
-                next_value, args.use_gae, args.gamma, args.tau
-            )
-
-            value_loss, action_loss, dist_entropy = agent.update(rollouts)
-
-            rollouts.after_update()
-            pth_time += time() - t_update_model
-
-            losses = [value_loss, action_loss]
-            stats = zip(
-                ["count", "reward"],
-                [window_episode_counts, window_episode_reward],
-            )
-            deltas = {
-                k: (
-                    (v[-1] - v[0]).sum().item()
-                    if len(v) > 1
-                    else v[0].sum().item()
-                )
-                for k, v in stats
-            }
-            deltas["count"] = max(deltas["count"], 1.0)
-
-            writer.add_scalar(
-                "reward", deltas["reward"] / deltas["count"], count_steps
-            )
-
-            writer.add_scalars(
-                "losses",
-                {k: l for l, k in zip(losses, ["value", "policy"])},
-                count_steps,
-            )
-
-            # log stats
-            if update > 0 and update % args.log_interval == 0:
-                logger.info(
-                    "update: {}\tfps: {:.3f}\t".format(
-                        update, count_steps / (time() - t_start)
-                    )
-                )
-
-                logger.info(
-                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
-                    "frames: {}".format(
-                        update, env_time, pth_time, count_steps
-                    )
-                )
-
-                window_rewards = (
-                    window_episode_reward[-1] - window_episode_reward[0]
-                ).sum()
-                window_counts = (
-                    window_episode_counts[-1] - window_episode_counts[0]
-                ).sum()
-
-                if window_counts > 0:
-                    logger.info(
-                        "Average window size {} reward: {:3f}".format(
-                            len(window_episode_reward),
-                            (window_rewards / window_counts).item(),
-                        )
-                    )
-                else:
-                    logger.info("No episodes finish in current window")
-
-            # checkpoint model
-            if update % args.checkpoint_interval == 0:
-                checkpoint = {"state_dict": agent.state_dict(), "args": args}
-                torch.save(
-                    checkpoint,
-                    os.path.join(
-                        args.checkpoint_folder,
-                        "ckpt.{}.pth".format(count_checkpoints),
-                    ),
-                )
-                count_checkpoints += 1
-
-
-if __name__ == "__main__":
-    run_training()
-- 
GitLab