Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_baseline_agents.py 2.32 KiB
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import pytest

import habitat

try:
    from habitat_baselines.agents import ppo_agents
    from habitat_baselines.agents import simple_agents

    baseline_installed = True
except ImportError:
    baseline_installed = False

CFG_TEST = "configs/test/habitat_all_sensors_test.yaml"


@pytest.mark.skipif(
    not baseline_installed, reason="baseline sub-module not installed"
)
def test_ppo_agents():
    agent_config = ppo_agents.get_default_config()
    agent_config.MODEL_PATH = ""
    config_env = habitat.get_config(config_paths=CFG_TEST)
    if not os.path.exists(config_env.SIMULATOR.SCENE):
        pytest.skip("Please download Habitat test data to data folder.")

    benchmark = habitat.Benchmark(config_paths=CFG_TEST)

    for input_type in ["blind", "rgb", "depth", "rgbd"]:
        config_env.defrost()
        config_env.SIMULATOR.AGENT_0.SENSORS = []
        if input_type in ["rgb", "rgbd"]:
            config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"]
        if input_type in ["depth", "rgbd"]:
            config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"]
        config_env.freeze()
        del benchmark._env
        benchmark._env = habitat.Env(config=config_env)
        agent_config.INPUT_TYPE = input_type

        agent = ppo_agents.PPOAgent(agent_config)
        habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))


@pytest.mark.skipif(
    not baseline_installed, reason="baseline sub-module not installed"
)
def test_simple_agents():
    config_env = habitat.get_config(config_paths=CFG_TEST)

    if not os.path.exists(config_env.SIMULATOR.SCENE):
        pytest.skip("Please download Habitat test data to data folder.")

    benchmark = habitat.Benchmark(config_paths=CFG_TEST)

    for agent_class in [
        simple_agents.ForwardOnlyAgent,
        simple_agents.GoalFollower,
        simple_agents.RandomAgent,
        simple_agents.RandomForwardAgent,
    ]:
        agent = agent_class(
            config_env.TASK.SUCCESS_DISTANCE, config_env.TASK.GOAL_SENSOR_UUID
        )
        habitat.logger.info(agent_class.__name__)
        habitat.logger.info(benchmark.evaluate(agent, num_episodes=100))