From d3d6a5eb2fa8d1a0878e647e4f965834898af6cc Mon Sep 17 00:00:00 2001
From: JasonJiazhiZhang <21229070+JasonJiazhiZhang@users.noreply.github.com>
Date: Thu, 25 Jul 2019 00:51:51 -0700
Subject: [PATCH] Update baseline test import and fix docstring (#170)

Update baseline test import and fix docstring
---
 .circleci/config.yml                          |  2 +-
 habitat_baselines/common/base_trainer.py      |  6 ++--
 habitat_baselines/common/baseline_registry.py |  1 +
 habitat_baselines/common/env_utils.py         |  8 ++---
 habitat_baselines/common/rollout_storage.py   |  8 ++---
 habitat_baselines/common/tensorboard_utils.py |  7 +----
 habitat_baselines/common/utils.py             | 15 +++++-----
 habitat_baselines/rl/ppo/ppo_trainer.py       | 29 +++++++++----------
 test/test_baseline_agents.py                  | 18 +++++++-----
 9 files changed, 45 insertions(+), 49 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 488f76040..6bd40c6fd 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -194,8 +194,8 @@ jobs:
           command: |
               export PATH=$HOME/miniconda/bin:$PATH
               . activate habitat; cd habitat-api
-              python setup.py test
               python setup.py develop --all
+              python setup.py test
               python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo_train_test.yaml --run-type train
 
 
diff --git a/habitat_baselines/common/base_trainer.py b/habitat_baselines/common/base_trainer.py
index cc8ffed86..9bbf129f6 100644
--- a/habitat_baselines/common/base_trainer.py
+++ b/habitat_baselines/common/base_trainer.py
@@ -8,8 +8,7 @@ from typing import ClassVar, Dict, List
 
 
 class BaseTrainer:
-    """
-    Most generic trainer class that serves as a base template for more
+    r"""Generic trainer class that serves as a base template for more
     specific trainer classes like RL trainer, SLAM or imitation learner.
     Includes only the most basic functionality.
     """
@@ -30,8 +29,7 @@ class BaseTrainer:
 
 
 class BaseRLTrainer(BaseTrainer):
-    """
-    Base trainer class for RL based trainers. Future RL-specific
+    r"""Base trainer class for RL trainers. Future RL-specific
     methods should be hosted here.
     """
 
diff --git a/habitat_baselines/common/baseline_registry.py b/habitat_baselines/common/baseline_registry.py
index fe54971f9..d49c787ad 100644
--- a/habitat_baselines/common/baseline_registry.py
+++ b/habitat_baselines/common/baseline_registry.py
@@ -47,6 +47,7 @@ class BaselineRegistry(Registry):
     def register_env(cls, to_register=None, *, name: Optional[str] = None):
         r"""Register a environment to registry with key 'name'
             currently only support subclass of RLEnv.
+
         Args:
             name: Key with which the env will be registered.
                 If None will use the name of the class.
diff --git a/habitat_baselines/common/env_utils.py b/habitat_baselines/common/env_utils.py
index 89d3d4c2a..b656a1abb 100644
--- a/habitat_baselines/common/env_utils.py
+++ b/habitat_baselines/common/env_utils.py
@@ -14,9 +14,9 @@ from habitat import Config, Env, VectorEnv, make_dataset
 def make_env_fn(
     task_config: Config, rl_env_config: Config, env_class: Type, rank: int
 ) -> Env:
-    r"""
-    Creates an env of type env_class with specified config and rank.
+    r"""Creates an env of type env_class with specified config and rank.
     This is to be passed in as an argument when creating VectorEnv.
+
     Args:
         task_config: task config file for creating env.
         rl_env_config: RL env config for creating env.
@@ -37,10 +37,10 @@ def make_env_fn(
 
 
 def construct_envs(config: Config, env_class: Type) -> VectorEnv:
-    r"""
-    Create VectorEnv object with specified config and env class type.
+    r"""Create VectorEnv object with specified config and env class type.
     To allow better performance, dataset are split into small ones for
     each individual env, grouped by scenes.
+
     Args:
         config: configs that contain num_processes as well as information
         necessary to create individual environments.
diff --git a/habitat_baselines/common/rollout_storage.py b/habitat_baselines/common/rollout_storage.py
index 908e50b63..3f9b5dd26 100644
--- a/habitat_baselines/common/rollout_storage.py
+++ b/habitat_baselines/common/rollout_storage.py
@@ -10,8 +10,8 @@ import torch
 
 
 class RolloutStorage:
-    r"""
-    Class for storing rollout information for RL trainers
+    r"""Class for storing rollout information for RL trainers.
+
     """
 
     def __init__(
@@ -210,8 +210,8 @@ class RolloutStorage:
 
     @staticmethod
     def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor:
-        r"""
-        Given a tensor of size (t, n, ..), flatten it to size (t*n, ...).
+        r"""Given a tensor of size (t, n, ..), flatten it to size (t*n, ...).
+
         Args:
             t: first dimension of tensor.
             n: second dimension of tensor.
diff --git a/habitat_baselines/common/tensorboard_utils.py b/habitat_baselines/common/tensorboard_utils.py
index ec5799de4..b1b9fead3 100644
--- a/habitat_baselines/common/tensorboard_utils.py
+++ b/habitat_baselines/common/tensorboard_utils.py
@@ -8,6 +8,7 @@ from typing import Union
 
 import numpy as np
 import torch
+from torch.utils.tensorboard import SummaryWriter
 
 
 # TODO Add checks to replace DummyWriter
@@ -28,12 +29,6 @@ class DummyWriter:
         return lambda *args, **kwargs: None
 
 
-try:
-    from torch.utils.tensorboard import SummaryWriter
-except ImportError:
-    SummaryWriter = DummyWriter
-
-
 class TensorboardWriter(SummaryWriter):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py
index 54fb6a38f..eb78aa770 100644
--- a/habitat_baselines/common/utils.py
+++ b/habitat_baselines/common/utils.py
@@ -27,8 +27,8 @@ from habitat_baselines.common.tensorboard_utils import (
 
 
 def get_trainer(trainer_name: str, trainer_cfg: Config) -> BaseTrainer:
-    r"""
-    Create specific trainer instance according to name.
+    r"""Create specific trainer instance according to name.
+
     Args:
         trainer_name: name of registered trainer .
         trainer_cfg: config file for trainer.
@@ -79,7 +79,8 @@ class CategoricalNet(nn.Module):
 
 # TODO make this a  LRScheduler class
 def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
-    r"""Decreases the learning rate linearly
+    r"""Decreases the learning rate linearly.
+
     """
     lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
     for param_group in optimizer.param_groups:
@@ -87,9 +88,9 @@ def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
 
 
 def batch_obs(observations: List[Dict]) -> Dict:
-    r"""
-    Transpose a batch of observation dicts to a dict of batched
+    r"""Transpose a batch of observation dicts to a dict of batched
     observations.
+
     Args:
         observations:  list of dicts of observations.
 
@@ -143,8 +144,8 @@ def generate_video(
     tb_writer: Union[DummyWriter, TensorboardWriter],
     fps: int = 10,
 ) -> None:
-    r"""
-    Generate video according to specified information.
+    r"""Generate video according to specified information.
+    
     Args:
         config: config object that contains video_option and video_dir.
         images: list of images to be converted to video.
diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py
index 31d4b4782..602b22e82 100644
--- a/habitat_baselines/rl/ppo/ppo_trainer.py
+++ b/habitat_baselines/rl/ppo/ppo_trainer.py
@@ -34,9 +34,8 @@ from habitat_baselines.rl.ppo import PPO, Policy
 
 @baseline_registry.register_trainer(name="ppo")
 class PPOTrainer(BaseRLTrainer):
-    r"""
-    Trainer class for PPO algorithm
-    Paper: https://arxiv.org/abs/1707.06347
+    r"""Trainer class for PPO algorithm
+    Paper: https://arxiv.org/abs/1707.06347.
     """
     supported_tasks = ["Nav-v0"]
 
@@ -51,8 +50,8 @@ class PPOTrainer(BaseRLTrainer):
             logger.info(f"config: {config}")
 
     def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
-        r"""
-        Sets up actor critic and agent for PPO
+        r"""Sets up actor critic and agent for PPO.
+
         Args:
             ppo_cfg: config node with relevant params
 
@@ -82,8 +81,8 @@ class PPOTrainer(BaseRLTrainer):
         )
 
     def save_checkpoint(self, file_name: str) -> None:
-        r"""
-        Save checkpoint with specified name
+        r"""Save checkpoint with specified name.
+
         Args:
             file_name: file name for checkpoint
 
@@ -102,8 +101,8 @@ class PPOTrainer(BaseRLTrainer):
         )
 
     def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict:
-        r"""
-        Load checkpoint of specified path as a dict
+        r"""Load checkpoint of specified path as a dict.
+
         Args:
             checkpoint_path: path of target checkpoint
             *args: additional positional args
@@ -115,8 +114,8 @@ class PPOTrainer(BaseRLTrainer):
         return torch.load(checkpoint_path, map_location=self.device)
 
     def train(self) -> None:
-        r"""
-        Main method for training PPO
+        r"""Main method for training PPO.
+
         Returns:
             None
         """
@@ -330,8 +329,8 @@ class PPOTrainer(BaseRLTrainer):
                     count_checkpoints += 1
 
     def eval(self) -> None:
-        r"""
-        Main method of evaluating PPO
+        r"""Main method of evaluating PPO.
+
         Returns:
             None
         """
@@ -382,8 +381,8 @@ class PPOTrainer(BaseRLTrainer):
         writer: TensorboardWriter,
         cur_ckpt_idx: int = 0,
     ) -> None:
-        r"""
-        Evaluates a single checkpoint
+        r"""Evaluates a single checkpoint.
+
         Args:
             checkpoint_path: path of checkpoint
             writer: tensorboard writer object for logging to tensorboard
diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py
index 3175b7c2a..a4f1d663d 100644
--- a/test/test_baseline_agents.py
+++ b/test/test_baseline_agents.py
@@ -9,22 +9,21 @@ import os
 import pytest
 
 import habitat
-from habitat_baselines.agents import simple_agents
 
 try:
-    import torch  # noqa # pylint: disable=unused-import
+    from habitat_baselines.agents import ppo_agents
+    from habitat_baselines.agents import simple_agents
 
-    has_torch = True
+    baseline_installed = True
 except ImportError:
-    has_torch = False
-
-if has_torch:
-    from habitat_baselines.agents import ppo_agents
+    baseline_installed = False
 
 CFG_TEST = "configs/test/habitat_all_sensors_test.yaml"
 
 
-@pytest.mark.skipif(not has_torch, reason="Test needs torch")
+@pytest.mark.skipif(
+    not baseline_installed, reason="baseline sub-module not installed"
+)
 def test_ppo_agents():
     agent_config = ppo_agents.get_default_config()
     agent_config.MODEL_PATH = ""
@@ -50,6 +49,9 @@ def test_ppo_agents():
         habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))
 
 
+@pytest.mark.skipif(
+    not baseline_installed, reason="baseline sub-module not installed"
+)
 def test_simple_agents():
     config_env = habitat.get_config(config_paths=CFG_TEST)
 
-- 
GitLab