Skip to content
Snippets Groups Projects
Commit d3d6a5eb authored by JasonJiazhiZhang's avatar JasonJiazhiZhang Committed by Oleksandr
Browse files

Update baseline test import and fix docstring (#170)

Update baseline test import and fix docstring 
parent 31318f81
No related branches found
No related tags found
No related merge requests found
...@@ -194,8 +194,8 @@ jobs: ...@@ -194,8 +194,8 @@ jobs:
command: | command: |
export PATH=$HOME/miniconda/bin:$PATH export PATH=$HOME/miniconda/bin:$PATH
. activate habitat; cd habitat-api . activate habitat; cd habitat-api
python setup.py test
python setup.py develop --all python setup.py develop --all
python setup.py test
python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo_train_test.yaml --run-type train python -u habitat_baselines/run.py --exp-config habitat_baselines/config/pointnav/ppo_train_test.yaml --run-type train
......
...@@ -8,8 +8,7 @@ from typing import ClassVar, Dict, List ...@@ -8,8 +8,7 @@ from typing import ClassVar, Dict, List
class BaseTrainer: class BaseTrainer:
""" r"""Generic trainer class that serves as a base template for more
Most generic trainer class that serves as a base template for more
specific trainer classes like RL trainer, SLAM or imitation learner. specific trainer classes like RL trainer, SLAM or imitation learner.
Includes only the most basic functionality. Includes only the most basic functionality.
""" """
...@@ -30,8 +29,7 @@ class BaseTrainer: ...@@ -30,8 +29,7 @@ class BaseTrainer:
class BaseRLTrainer(BaseTrainer): class BaseRLTrainer(BaseTrainer):
""" r"""Base trainer class for RL trainers. Future RL-specific
Base trainer class for RL based trainers. Future RL-specific
methods should be hosted here. methods should be hosted here.
""" """
......
...@@ -47,6 +47,7 @@ class BaselineRegistry(Registry): ...@@ -47,6 +47,7 @@ class BaselineRegistry(Registry):
def register_env(cls, to_register=None, *, name: Optional[str] = None): def register_env(cls, to_register=None, *, name: Optional[str] = None):
r"""Register a environment to registry with key 'name' r"""Register a environment to registry with key 'name'
currently only support subclass of RLEnv. currently only support subclass of RLEnv.
Args: Args:
name: Key with which the env will be registered. name: Key with which the env will be registered.
If None will use the name of the class. If None will use the name of the class.
......
...@@ -14,9 +14,9 @@ from habitat import Config, Env, VectorEnv, make_dataset ...@@ -14,9 +14,9 @@ from habitat import Config, Env, VectorEnv, make_dataset
def make_env_fn( def make_env_fn(
task_config: Config, rl_env_config: Config, env_class: Type, rank: int task_config: Config, rl_env_config: Config, env_class: Type, rank: int
) -> Env: ) -> Env:
r""" r"""Creates an env of type env_class with specified config and rank.
Creates an env of type env_class with specified config and rank.
This is to be passed in as an argument when creating VectorEnv. This is to be passed in as an argument when creating VectorEnv.
Args: Args:
task_config: task config file for creating env. task_config: task config file for creating env.
rl_env_config: RL env config for creating env. rl_env_config: RL env config for creating env.
...@@ -37,10 +37,10 @@ def make_env_fn( ...@@ -37,10 +37,10 @@ def make_env_fn(
def construct_envs(config: Config, env_class: Type) -> VectorEnv: def construct_envs(config: Config, env_class: Type) -> VectorEnv:
r""" r"""Create VectorEnv object with specified config and env class type.
Create VectorEnv object with specified config and env class type.
To allow better performance, dataset are split into small ones for To allow better performance, dataset are split into small ones for
each individual env, grouped by scenes. each individual env, grouped by scenes.
Args: Args:
config: configs that contain num_processes as well as information config: configs that contain num_processes as well as information
necessary to create individual environments. necessary to create individual environments.
......
...@@ -10,8 +10,8 @@ import torch ...@@ -10,8 +10,8 @@ import torch
class RolloutStorage: class RolloutStorage:
r""" r"""Class for storing rollout information for RL trainers.
Class for storing rollout information for RL trainers
""" """
def __init__( def __init__(
...@@ -210,8 +210,8 @@ class RolloutStorage: ...@@ -210,8 +210,8 @@ class RolloutStorage:
@staticmethod @staticmethod
def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor: def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor:
r""" r"""Given a tensor of size (t, n, ..), flatten it to size (t*n, ...).
Given a tensor of size (t, n, ..), flatten it to size (t*n, ...).
Args: Args:
t: first dimension of tensor. t: first dimension of tensor.
n: second dimension of tensor. n: second dimension of tensor.
......
...@@ -8,6 +8,7 @@ from typing import Union ...@@ -8,6 +8,7 @@ from typing import Union
import numpy as np import numpy as np
import torch import torch
from torch.utils.tensorboard import SummaryWriter
# TODO Add checks to replace DummyWriter # TODO Add checks to replace DummyWriter
...@@ -28,12 +29,6 @@ class DummyWriter: ...@@ -28,12 +29,6 @@ class DummyWriter:
return lambda *args, **kwargs: None return lambda *args, **kwargs: None
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = DummyWriter
class TensorboardWriter(SummaryWriter): class TensorboardWriter(SummaryWriter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
......
...@@ -27,8 +27,8 @@ from habitat_baselines.common.tensorboard_utils import ( ...@@ -27,8 +27,8 @@ from habitat_baselines.common.tensorboard_utils import (
def get_trainer(trainer_name: str, trainer_cfg: Config) -> BaseTrainer: def get_trainer(trainer_name: str, trainer_cfg: Config) -> BaseTrainer:
r""" r"""Create specific trainer instance according to name.
Create specific trainer instance according to name.
Args: Args:
trainer_name: name of registered trainer . trainer_name: name of registered trainer .
trainer_cfg: config file for trainer. trainer_cfg: config file for trainer.
...@@ -79,7 +79,8 @@ class CategoricalNet(nn.Module): ...@@ -79,7 +79,8 @@ class CategoricalNet(nn.Module):
# TODO make this a LRScheduler class # TODO make this a LRScheduler class
def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
r"""Decreases the learning rate linearly r"""Decreases the learning rate linearly.
""" """
lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
...@@ -87,9 +88,9 @@ def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): ...@@ -87,9 +88,9 @@ def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
def batch_obs(observations: List[Dict]) -> Dict: def batch_obs(observations: List[Dict]) -> Dict:
r""" r"""Transpose a batch of observation dicts to a dict of batched
Transpose a batch of observation dicts to a dict of batched
observations. observations.
Args: Args:
observations: list of dicts of observations. observations: list of dicts of observations.
...@@ -143,8 +144,8 @@ def generate_video( ...@@ -143,8 +144,8 @@ def generate_video(
tb_writer: Union[DummyWriter, TensorboardWriter], tb_writer: Union[DummyWriter, TensorboardWriter],
fps: int = 10, fps: int = 10,
) -> None: ) -> None:
r""" r"""Generate video according to specified information.
Generate video according to specified information.
Args: Args:
config: config object that contains video_option and video_dir. config: config object that contains video_option and video_dir.
images: list of images to be converted to video. images: list of images to be converted to video.
......
...@@ -34,9 +34,8 @@ from habitat_baselines.rl.ppo import PPO, Policy ...@@ -34,9 +34,8 @@ from habitat_baselines.rl.ppo import PPO, Policy
@baseline_registry.register_trainer(name="ppo") @baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer): class PPOTrainer(BaseRLTrainer):
r""" r"""Trainer class for PPO algorithm
Trainer class for PPO algorithm Paper: https://arxiv.org/abs/1707.06347.
Paper: https://arxiv.org/abs/1707.06347
""" """
supported_tasks = ["Nav-v0"] supported_tasks = ["Nav-v0"]
...@@ -51,8 +50,8 @@ class PPOTrainer(BaseRLTrainer): ...@@ -51,8 +50,8 @@ class PPOTrainer(BaseRLTrainer):
logger.info(f"config: {config}") logger.info(f"config: {config}")
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
r""" r"""Sets up actor critic and agent for PPO.
Sets up actor critic and agent for PPO
Args: Args:
ppo_cfg: config node with relevant params ppo_cfg: config node with relevant params
...@@ -82,8 +81,8 @@ class PPOTrainer(BaseRLTrainer): ...@@ -82,8 +81,8 @@ class PPOTrainer(BaseRLTrainer):
) )
def save_checkpoint(self, file_name: str) -> None: def save_checkpoint(self, file_name: str) -> None:
r""" r"""Save checkpoint with specified name.
Save checkpoint with specified name
Args: Args:
file_name: file name for checkpoint file_name: file name for checkpoint
...@@ -102,8 +101,8 @@ class PPOTrainer(BaseRLTrainer): ...@@ -102,8 +101,8 @@ class PPOTrainer(BaseRLTrainer):
) )
def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict: def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict:
r""" r"""Load checkpoint of specified path as a dict.
Load checkpoint of specified path as a dict
Args: Args:
checkpoint_path: path of target checkpoint checkpoint_path: path of target checkpoint
*args: additional positional args *args: additional positional args
...@@ -115,8 +114,8 @@ class PPOTrainer(BaseRLTrainer): ...@@ -115,8 +114,8 @@ class PPOTrainer(BaseRLTrainer):
return torch.load(checkpoint_path, map_location=self.device) return torch.load(checkpoint_path, map_location=self.device)
def train(self) -> None: def train(self) -> None:
r""" r"""Main method for training PPO.
Main method for training PPO
Returns: Returns:
None None
""" """
...@@ -330,8 +329,8 @@ class PPOTrainer(BaseRLTrainer): ...@@ -330,8 +329,8 @@ class PPOTrainer(BaseRLTrainer):
count_checkpoints += 1 count_checkpoints += 1
def eval(self) -> None: def eval(self) -> None:
r""" r"""Main method of evaluating PPO.
Main method of evaluating PPO
Returns: Returns:
None None
""" """
...@@ -382,8 +381,8 @@ class PPOTrainer(BaseRLTrainer): ...@@ -382,8 +381,8 @@ class PPOTrainer(BaseRLTrainer):
writer: TensorboardWriter, writer: TensorboardWriter,
cur_ckpt_idx: int = 0, cur_ckpt_idx: int = 0,
) -> None: ) -> None:
r""" r"""Evaluates a single checkpoint.
Evaluates a single checkpoint
Args: Args:
checkpoint_path: path of checkpoint checkpoint_path: path of checkpoint
writer: tensorboard writer object for logging to tensorboard writer: tensorboard writer object for logging to tensorboard
......
...@@ -9,22 +9,21 @@ import os ...@@ -9,22 +9,21 @@ import os
import pytest import pytest
import habitat import habitat
from habitat_baselines.agents import simple_agents
try: try:
import torch # noqa # pylint: disable=unused-import from habitat_baselines.agents import ppo_agents
from habitat_baselines.agents import simple_agents
has_torch = True baseline_installed = True
except ImportError: except ImportError:
has_torch = False baseline_installed = False
if has_torch:
from habitat_baselines.agents import ppo_agents
CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" CFG_TEST = "configs/test/habitat_all_sensors_test.yaml"
@pytest.mark.skipif(not has_torch, reason="Test needs torch") @pytest.mark.skipif(
not baseline_installed, reason="baseline sub-module not installed"
)
def test_ppo_agents(): def test_ppo_agents():
agent_config = ppo_agents.get_default_config() agent_config = ppo_agents.get_default_config()
agent_config.MODEL_PATH = "" agent_config.MODEL_PATH = ""
...@@ -50,6 +49,9 @@ def test_ppo_agents(): ...@@ -50,6 +49,9 @@ def test_ppo_agents():
habitat.logger.info(benchmark.evaluate(agent, num_episodes=10)) habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))
@pytest.mark.skipif(
not baseline_installed, reason="baseline sub-module not installed"
)
def test_simple_agents(): def test_simple_agents():
config_env = habitat.get_config(config_paths=CFG_TEST) config_env = habitat.get_config(config_paths=CFG_TEST)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment