From d3d6a5eb2fa8d1a0878e647e4f965834898af6cc Mon Sep 17 00:00:00 2001 From: JasonJiazhiZhang <21229070+JasonJiazhiZhang@users.noreply.github.com> Date: Thu, 25 Jul 2019 00:51:51 -0700 Subject: [PATCH] Update baseline test import and fix docstring (#170) Update baseline test import and fix docstring --- .circleci/config.yml | 2 +- habitat_baselines/common/base_trainer.py | 6 ++-- habitat_baselines/common/baseline_registry.py | 1 + habitat_baselines/common/env_utils.py | 8 ++--- habitat_baselines/common/rollout_storage.py | 8 ++--- habitat_baselines/common/tensorboard_utils.py | 7 +---- habitat_baselines/common/utils.py | 15 +++++----- habitat_baselines/rl/ppo/ppo_trainer.py | 29 +++++++++---------- test/test_baseline_agents.py | 18 +++++++----- 9 files changed, 45 insertions(+), 49 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 488f76040..6bd40c6fd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -194,8 +194,8 @@ jobs: command: | export PATH=$HOME/miniconda/bin:$PATH . activate habitat; cd habitat-api - python setup.py test python setup.py develop --all + python setup.py test python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo_train_test.yaml --run-type train diff --git a/habitat_baselines/common/base_trainer.py b/habitat_baselines/common/base_trainer.py index cc8ffed86..9bbf129f6 100644 --- a/habitat_baselines/common/base_trainer.py +++ b/habitat_baselines/common/base_trainer.py @@ -8,8 +8,7 @@ from typing import ClassVar, Dict, List class BaseTrainer: - """ - Most generic trainer class that serves as a base template for more + r"""Generic trainer class that serves as a base template for more specific trainer classes like RL trainer, SLAM or imitation learner. Includes only the most basic functionality. """ @@ -30,8 +29,7 @@ class BaseTrainer: class BaseRLTrainer(BaseTrainer): - """ - Base trainer class for RL based trainers. Future RL-specific + r"""Base trainer class for RL trainers. Future RL-specific methods should be hosted here. """ diff --git a/habitat_baselines/common/baseline_registry.py b/habitat_baselines/common/baseline_registry.py index fe54971f9..d49c787ad 100644 --- a/habitat_baselines/common/baseline_registry.py +++ b/habitat_baselines/common/baseline_registry.py @@ -47,6 +47,7 @@ class BaselineRegistry(Registry): def register_env(cls, to_register=None, *, name: Optional[str] = None): r"""Register a environment to registry with key 'name' currently only support subclass of RLEnv. + Args: name: Key with which the env will be registered. If None will use the name of the class. diff --git a/habitat_baselines/common/env_utils.py b/habitat_baselines/common/env_utils.py index 89d3d4c2a..b656a1abb 100644 --- a/habitat_baselines/common/env_utils.py +++ b/habitat_baselines/common/env_utils.py @@ -14,9 +14,9 @@ from habitat import Config, Env, VectorEnv, make_dataset def make_env_fn( task_config: Config, rl_env_config: Config, env_class: Type, rank: int ) -> Env: - r""" - Creates an env of type env_class with specified config and rank. + r"""Creates an env of type env_class with specified config and rank. This is to be passed in as an argument when creating VectorEnv. + Args: task_config: task config file for creating env. rl_env_config: RL env config for creating env. @@ -37,10 +37,10 @@ def make_env_fn( def construct_envs(config: Config, env_class: Type) -> VectorEnv: - r""" - Create VectorEnv object with specified config and env class type. + r"""Create VectorEnv object with specified config and env class type. To allow better performance, dataset are split into small ones for each individual env, grouped by scenes. + Args: config: configs that contain num_processes as well as information necessary to create individual environments. diff --git a/habitat_baselines/common/rollout_storage.py b/habitat_baselines/common/rollout_storage.py index 908e50b63..3f9b5dd26 100644 --- a/habitat_baselines/common/rollout_storage.py +++ b/habitat_baselines/common/rollout_storage.py @@ -10,8 +10,8 @@ import torch class RolloutStorage: - r""" - Class for storing rollout information for RL trainers + r"""Class for storing rollout information for RL trainers. + """ def __init__( @@ -210,8 +210,8 @@ class RolloutStorage: @staticmethod def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor: - r""" - Given a tensor of size (t, n, ..), flatten it to size (t*n, ...). + r"""Given a tensor of size (t, n, ..), flatten it to size (t*n, ...). + Args: t: first dimension of tensor. n: second dimension of tensor. diff --git a/habitat_baselines/common/tensorboard_utils.py b/habitat_baselines/common/tensorboard_utils.py index ec5799de4..b1b9fead3 100644 --- a/habitat_baselines/common/tensorboard_utils.py +++ b/habitat_baselines/common/tensorboard_utils.py @@ -8,6 +8,7 @@ from typing import Union import numpy as np import torch +from torch.utils.tensorboard import SummaryWriter # TODO Add checks to replace DummyWriter @@ -28,12 +29,6 @@ class DummyWriter: return lambda *args, **kwargs: None -try: - from torch.utils.tensorboard import SummaryWriter -except ImportError: - SummaryWriter = DummyWriter - - class TensorboardWriter(SummaryWriter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py index 54fb6a38f..eb78aa770 100644 --- a/habitat_baselines/common/utils.py +++ b/habitat_baselines/common/utils.py @@ -27,8 +27,8 @@ from habitat_baselines.common.tensorboard_utils import ( def get_trainer(trainer_name: str, trainer_cfg: Config) -> BaseTrainer: - r""" - Create specific trainer instance according to name. + r"""Create specific trainer instance according to name. + Args: trainer_name: name of registered trainer . trainer_cfg: config file for trainer. @@ -79,7 +79,8 @@ class CategoricalNet(nn.Module): # TODO make this a LRScheduler class def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): - r"""Decreases the learning rate linearly + r"""Decreases the learning rate linearly. + """ lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) for param_group in optimizer.param_groups: @@ -87,9 +88,9 @@ def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): def batch_obs(observations: List[Dict]) -> Dict: - r""" - Transpose a batch of observation dicts to a dict of batched + r"""Transpose a batch of observation dicts to a dict of batched observations. + Args: observations: list of dicts of observations. @@ -143,8 +144,8 @@ def generate_video( tb_writer: Union[DummyWriter, TensorboardWriter], fps: int = 10, ) -> None: - r""" - Generate video according to specified information. + r"""Generate video according to specified information. + Args: config: config object that contains video_option and video_dir. images: list of images to be converted to video. diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py index 31d4b4782..602b22e82 100644 --- a/habitat_baselines/rl/ppo/ppo_trainer.py +++ b/habitat_baselines/rl/ppo/ppo_trainer.py @@ -34,9 +34,8 @@ from habitat_baselines.rl.ppo import PPO, Policy @baseline_registry.register_trainer(name="ppo") class PPOTrainer(BaseRLTrainer): - r""" - Trainer class for PPO algorithm - Paper: https://arxiv.org/abs/1707.06347 + r"""Trainer class for PPO algorithm + Paper: https://arxiv.org/abs/1707.06347. """ supported_tasks = ["Nav-v0"] @@ -51,8 +50,8 @@ class PPOTrainer(BaseRLTrainer): logger.info(f"config: {config}") def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: - r""" - Sets up actor critic and agent for PPO + r"""Sets up actor critic and agent for PPO. + Args: ppo_cfg: config node with relevant params @@ -82,8 +81,8 @@ class PPOTrainer(BaseRLTrainer): ) def save_checkpoint(self, file_name: str) -> None: - r""" - Save checkpoint with specified name + r"""Save checkpoint with specified name. + Args: file_name: file name for checkpoint @@ -102,8 +101,8 @@ class PPOTrainer(BaseRLTrainer): ) def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict: - r""" - Load checkpoint of specified path as a dict + r"""Load checkpoint of specified path as a dict. + Args: checkpoint_path: path of target checkpoint *args: additional positional args @@ -115,8 +114,8 @@ class PPOTrainer(BaseRLTrainer): return torch.load(checkpoint_path, map_location=self.device) def train(self) -> None: - r""" - Main method for training PPO + r"""Main method for training PPO. + Returns: None """ @@ -330,8 +329,8 @@ class PPOTrainer(BaseRLTrainer): count_checkpoints += 1 def eval(self) -> None: - r""" - Main method of evaluating PPO + r"""Main method of evaluating PPO. + Returns: None """ @@ -382,8 +381,8 @@ class PPOTrainer(BaseRLTrainer): writer: TensorboardWriter, cur_ckpt_idx: int = 0, ) -> None: - r""" - Evaluates a single checkpoint + r"""Evaluates a single checkpoint. + Args: checkpoint_path: path of checkpoint writer: tensorboard writer object for logging to tensorboard diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py index 3175b7c2a..a4f1d663d 100644 --- a/test/test_baseline_agents.py +++ b/test/test_baseline_agents.py @@ -9,22 +9,21 @@ import os import pytest import habitat -from habitat_baselines.agents import simple_agents try: - import torch # noqa # pylint: disable=unused-import + from habitat_baselines.agents import ppo_agents + from habitat_baselines.agents import simple_agents - has_torch = True + baseline_installed = True except ImportError: - has_torch = False - -if has_torch: - from habitat_baselines.agents import ppo_agents + baseline_installed = False CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" -@pytest.mark.skipif(not has_torch, reason="Test needs torch") +@pytest.mark.skipif( + not baseline_installed, reason="baseline sub-module not installed" +) def test_ppo_agents(): agent_config = ppo_agents.get_default_config() agent_config.MODEL_PATH = "" @@ -50,6 +49,9 @@ def test_ppo_agents(): habitat.logger.info(benchmark.evaluate(agent, num_episodes=10)) +@pytest.mark.skipif( + not baseline_installed, reason="baseline sub-module not installed" +) def test_simple_agents(): config_env = habitat.get_config(config_paths=CFG_TEST) -- GitLab