diff --git a/README.md b/README.md index 8921308d19ae19d0d2d1aff9730dbb0925b5a721..143012e771c5ad01ffb094774a286e93b32515f0 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ import habitat # Load embodied AI task (PointNav) and a pre-specified virtual robot env = habitat.Env( - config=habitat.get_config(config_file="tasks/pointnav.yaml") + config=habitat.get_config("configs/tasks/pointnav.yaml") ) observations = env.reset() diff --git a/baselines/README.md b/baselines/README.md index 42d1ccad68a9b14cd89ccfd6320e3de5b4513db2..9829d12c0df1c9e09189688462c7f36cb8d48bb8 100644 --- a/baselines/README.md +++ b/baselines/README.md @@ -35,7 +35,7 @@ python -u baselines/train_ppo.py \ --log-interval 5 \ --checkpoint-folder "data/checkpoints" \ --checkpoint-interval 50 \ - --task-config "tasks/pointnav.yaml" \ + --task-config "configs/tasks/pointnav.yaml" \ ``` @@ -48,7 +48,7 @@ python -u baselines/evaluate_ppo.py \ --pth-gpu-id 0 \ --num-processes 4 \ --count-test-episodes 100 \ - --task-config "tasks/pointnav.yaml" \ + --task-config "configs/tasks/pointnav.yaml" \ ``` diff --git a/baselines/agents/ppo_agents.py b/baselines/agents/ppo_agents.py index 0656eed499e21c5139807d1992afd61292ad8f9a..355dc9de7fffca70187f8941a4bcd4f8e2c0e1d8 100644 --- a/baselines/agents/ppo_agents.py +++ b/baselines/agents/ppo_agents.py @@ -131,7 +131,7 @@ def main(): ) parser.add_argument("--model-path", default="", type=str) parser.add_argument( - "--task-config", type=str, default="tasks/pointnav.yaml" + "--task-config", type=str, default="configs/tasks/pointnav.yaml" ) args = parser.parse_args() diff --git a/baselines/agents/simple_agents.py b/baselines/agents/simple_agents.py index 8010a8b4b06241379a5065958b6c50718d182634..e91b605b1d899aaffea680963eb3e1656b742466 100644 --- a/baselines/agents/simple_agents.py +++ b/baselines/agents/simple_agents.py @@ -122,7 +122,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--success-distance", type=float, default=0.2) parser.add_argument( - "--task-config", type=str, default="tasks/pointnav.yaml" + "--task-config", type=str, default="configs/tasks/pointnav.yaml" ) parser.add_argument("--agent-class", type=str, default="GoalFollower") args = parser.parse_args() diff --git a/baselines/agents/slam_agents.py b/baselines/agents/slam_agents.py index 0f6648847843e5b4ebacc13bc7b8769ec895e4f9..2ec49302745d2f41b0ee77adaaf4f9dffa85553c 100644 --- a/baselines/agents/slam_agents.py +++ b/baselines/agents/slam_agents.py @@ -26,7 +26,7 @@ from habitat.sims.habitat_simulator import SimulatorActions from baselines.slambased.mappers import DirectDepthMapper from baselines.slambased.path_planners import DifferentiableStarPlanner -from baselines.config.default import cfg +from baselines.config.default import get_config as cfg_baseline from habitat.config.default import get_config from baselines.slambased.monodepth import MonoDepthEstimator @@ -600,7 +600,7 @@ def main(): args = parser.parse_args() config = get_config() - agent_config = cfg() + agent_config = cfg_baseline() config.defrost() config.BASELINE = agent_config.BASELINE make_good_config_for_orbslam2(config) diff --git a/baselines/config/default.py b/baselines/config/default.py index 809691008174fbea2d9bf9545e2e08a06cd5d9f8..57b034158ce12d8d477e2e7277d713e8b89d4055 100644 --- a/baselines/config/default.py +++ b/baselines/config/default.py @@ -4,14 +4,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os import numpy as np -from typing import Optional +from typing import List, Optional, Union from habitat import get_config from habitat.config import Config as CN DEFAULT_CONFIG_DIR = "configs/" - +CONFIG_FILE_SEPARATOR = "," # ----------------------------------------------------------------------------- # Config definition # ----------------------------------------------------------------------------- @@ -60,10 +59,33 @@ _C.BASELINE.ORBSLAM2.DEPTH_DENORM = ( ) -def cfg( - config_file: Optional[str] = None, config_dir: str = DEFAULT_CONFIG_DIR +def get_config( + config_paths: Optional[Union[List[str], str]] = None, + opts: Optional[list] = None, ) -> CN: + """ + Create a unified config with default values overwritten by values from + `config_paths` and overwritten by options from `opts`. + Args: + config_paths: List of config paths or string that contains comma + separated list of config paths. + opts: Config options (keys, values) in a list (e.g., passed from + command line into the config. For example, `opts = ['FOO.BAR', + 0.5]`. Argument can be used for parameter sweeping or quick tests. + """ config = _C.clone() - if config_file: - config.merge_from_file(os.path.join(config_dir, config_file)) + if config_paths: + if isinstance(config_paths, str): + if CONFIG_FILE_SEPARATOR in config_paths: + config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) + else: + config_paths = [config_paths] + + for config_path in config_paths: + config.merge_from_file(config_path) + + if opts: + config.merge_from_list(opts) + + config.freeze() return config diff --git a/baselines/evaluate_ppo.py b/baselines/evaluate_ppo.py index b7277b8b3df368a0d0db38c98e40943e407135ae..43f340797fab36e9f5d0ccebbf398bf221ecef2f 100644 --- a/baselines/evaluate_ppo.py +++ b/baselines/evaluate_ppo.py @@ -10,7 +10,7 @@ import torch import habitat from habitat.config.default import get_config -from config.default import cfg as cfg_baseline +from config.default import get_config as cfg_baseline from train_ppo import make_env_fn from rl.ppo import PPO, Policy @@ -36,7 +36,7 @@ def main(): parser.add_argument( "--task-config", type=str, - default="tasks/pointnav.yaml", + default="configs/tasks/pointnav.yaml", help="path to config yaml containing information about task", ) args = parser.parse_args() @@ -47,7 +47,7 @@ def main(): baseline_configs = [] for _ in range(args.num_processes): - config_env = get_config(config_file=args.task_config) + config_env = get_config(config_paths=args.task_config) config_env.defrost() config_env.DATASET.SPLIT = "val" diff --git a/baselines/rl/ppo/utils.py b/baselines/rl/ppo/utils.py index db132eb318d8c81d3392aad482e3f8e7a7827d0a..ad29e04482acd2a8171c6461f53b9e94a40e4a8e 100644 --- a/baselines/rl/ppo/utils.py +++ b/baselines/rl/ppo/utils.py @@ -411,7 +411,7 @@ def ppo_args(): parser.add_argument( "--task-config", type=str, - default="tasks/pointnav.yaml", + default="configs/tasks/pointnav.yaml", help="path to config yaml containing information about task", ) parser.add_argument("--seed", type=int, default=100) diff --git a/baselines/train_ppo.py b/baselines/train_ppo.py index 2f396eed054c55ea71c8d67b27c5ae8f148e6721..b71238263630cbee285c20b15f9f36e3afe4092a 100644 --- a/baselines/train_ppo.py +++ b/baselines/train_ppo.py @@ -15,7 +15,7 @@ import habitat from habitat import logger from habitat.sims.habitat_simulator import SimulatorActions from habitat.config.default import get_config as cfg_env -from config.default import cfg as cfg_baseline +from config.default import get_config as cfg_baseline from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 from rl.ppo import PPO, Policy, RolloutStorage from rl.ppo.utils import update_linear_schedule, ppo_args, batch_obs @@ -109,7 +109,7 @@ def construct_envs(args): env_configs = [] baseline_configs = [] - basic_config = cfg_env(config_file=args.task_config) + basic_config = cfg_env(config_paths=args.task_config) scenes = PointNavDatasetV1.get_scenes_to_load(basic_config.DATASET) @@ -123,7 +123,7 @@ def construct_envs(args): scene_split_size = int(np.floor(len(scenes) / args.num_processes)) for i in range(args.num_processes): - config_env = cfg_env(config_file=args.task_config) + config_env = cfg_env(config_paths=args.task_config) config_env.defrost() if len(scenes) > 0: diff --git a/examples/benchmark.py b/examples/benchmark.py index 470abddf37724b89a385b2c26dde9918cbf93b1b..9a680ec3bcc4bca9c2ce731656c1d0cc305ccd9b 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -22,7 +22,7 @@ class ForwardOnlyAgent(habitat.Agent): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--task-config", type=str, default="tasks/pointnav.yaml" + "--task-config", type=str, default="configs/tasks/pointnav.yaml" ) args = parser.parse_args() diff --git a/examples/example.py b/examples/example.py index cc1fb3687a0e65da725a414445a356da1a315263..6d627d07099c0c5b61cb271b6467c79bce7ddf45 100644 --- a/examples/example.py +++ b/examples/example.py @@ -9,7 +9,7 @@ import habitat def example(): env = habitat.Env( - config=habitat.get_config(config_file="tasks/pointnav.yaml") + config=habitat.get_config("configs/tasks/pointnav.yaml") ) print("Environment creation successful") diff --git a/examples/shortest_path_follower_example.py b/examples/shortest_path_follower_example.py index b3f8ca7b9d2636514e4d1b4f29942a501fab9422..2cd170f36a19b036d3624375c5b85ec034eb867c 100644 --- a/examples/shortest_path_follower_example.py +++ b/examples/shortest_path_follower_example.py @@ -57,7 +57,7 @@ def draw_top_down_map(info, heading, output_size): def shortest_path_example(mode): - config = habitat.get_config(config_file="tasks/pointnav.yaml") + config = habitat.get_config(config_paths="configs/tasks/pointnav.yaml") config.TASK.MEASUREMENTS.append("TOP_DOWN_MAP") config.TASK.SENSORS.append("HEADING_SENSOR") env = SimpleRLEnv(config=config) diff --git a/examples/visualization_examples.py b/examples/visualization_examples.py index e5be1bd296b7095c75f6ab4db8df19545d5f4d2f..2a02fe9309ac8b9ec7f3f28d452e18ba79b22141 100644 --- a/examples/visualization_examples.py +++ b/examples/visualization_examples.py @@ -80,7 +80,7 @@ def example_pointnav_draw_target_birdseye_view_agent_on_border(): def example_get_topdown_map(): - config = habitat.get_config(config_file="tasks/pointnav.yaml") + config = habitat.get_config(config_paths="configs/tasks/pointnav.yaml") dataset = habitat.make_dataset( id_dataset=config.DATASET.TYPE, config=config.DATASET ) diff --git a/habitat/config/default.py b/habitat/config/default.py index a4e191734720f50ccc2ca550baa19038985f51e9..3a165ceca2487976daa5b72e16c6bf6b0d6482aa 100644 --- a/habitat/config/default.py +++ b/habitat/config/default.py @@ -4,12 +4,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os -from typing import Optional +from typing import List, Optional, Union from habitat.config import Config as CN # type: ignore DEFAULT_CONFIG_DIR = "configs/" +CONFIG_FILE_SEPARATOR = "," # ----------------------------------------------------------------------------- # Config definition @@ -135,6 +135,7 @@ _C.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = 0 _C.DATASET = CN() _C.DATASET.TYPE = "PointNav-v1" _C.DATASET.SPLIT = "train" +_C.DATASET.SCENES_DIR = "data/scene_datasets" # ----------------------------------------------------------------------------- # MP3DEQAV1 DATASET # ----------------------------------------------------------------------------- @@ -156,10 +157,32 @@ _C.DATASET.POINTNAVV1.CONTENT_SCENES = ["*"] def get_config( - config_file: Optional[str] = None, config_dir: str = DEFAULT_CONFIG_DIR + config_paths: Optional[Union[List[str], str]] = None, + opts: Optional[list] = None, ) -> CN: + """ + Create a unified config with default values overwritten by values from + `config_paths` and overwritten by options from `opts`. + Args: + config_paths: List of config paths or string that contains comma + separated list of config paths. + opts: Config options (keys, values) in a list (e.g., passed from + command line into the config. For example, `opts = ['FOO.BAR', + 0.5]`. Argument can be used for parameter sweeping or quick tests. + """ config = _C.clone() - if config_file: - config.merge_from_file(os.path.join(config_dir, config_file)) + if config_paths: + if isinstance(config_paths, str): + if CONFIG_FILE_SEPARATOR in config_paths: + config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) + else: + config_paths = [config_paths] + + for config_path in config_paths: + config.merge_from_file(config_path) + + if opts: + config.merge_from_list(opts) + config.freeze() return config diff --git a/habitat/core/benchmark.py b/habitat/core/benchmark.py index cb49a6b8142df6e0405b24405c060da6f8984f22..a3dd090ede0f78a651bbea87e9aa673a027d5bd1 100644 --- a/habitat/core/benchmark.py +++ b/habitat/core/benchmark.py @@ -17,16 +17,12 @@ class Benchmark: Args: - config_file: file to be used for creating the environment. - config_dir: directory where config_file is located. + config_paths: file to be used for creating the environment. + config_dir: directory where config_paths is located. """ - def __init__( - self, - config_file: Optional[str] = None, - config_dir: str = DEFAULT_CONFIG_DIR, - ) -> None: - config_env = get_config(config_file=config_file, config_dir=config_dir) + def __init__(self, config_paths: Optional[str] = None) -> None: + config_env = get_config(config_paths) self._env = Env(config=config_env) def evaluate( diff --git a/habitat/core/challenge.py b/habitat/core/challenge.py index 5094333e5514cea7347e4116e400bfaeb8e4b43e..7dcbd7878c2bc48ab877d0860b3aa955936d00f9 100644 --- a/habitat/core/challenge.py +++ b/habitat/core/challenge.py @@ -12,8 +12,8 @@ from habitat.core.logging import logger class Challenge(Benchmark): def __init__(self): - config_file = os.environ["CHALLENGE_CONFIG_FILE"] - super().__init__(config_file) + config_paths = os.environ["CHALLENGE_CONFIG_FILE"] + super().__init__(config_paths) def submit(self, agent): metrics = super().evaluate(agent) diff --git a/habitat/core/dataset.py b/habitat/core/dataset.py index a7841dc07dddb3af2b23566e412853278ee79b9f..535f0b55b9928686b4ff58145741af6c86a665fe 100644 --- a/habitat/core/dataset.py +++ b/habitat/core/dataset.py @@ -6,7 +6,6 @@ import copy import json -import random from typing import Dict, List, Type, TypeVar, Generic, Optional, Callable import numpy as np @@ -103,7 +102,17 @@ class Dataset(Generic[T]): result = DatasetJSONEncoder().encode(self) return result - def from_json(self, json_str: str) -> None: + def from_json( + self, json_str: str, scenes_dir: Optional[str] = None + ) -> None: + """ + Parses passed JSON string and creates dataset based on that. + Function is used as deserialization method for Dataset. + Args: + json_str: JSON dump of Dataset instance. + scenes_dir: Path to directory with scenes assets such as *.glb + files. + """ raise NotImplementedError def filter_episodes( diff --git a/habitat/datasets/eqa/mp3d_eqa_dataset.py b/habitat/datasets/eqa/mp3d_eqa_dataset.py index d709b6ec518630a24fdd5b6774bbb310ebd83347..1dcd4195d77b345e41ff946a54160b8b142204cf 100644 --- a/habitat/datasets/eqa/mp3d_eqa_dataset.py +++ b/habitat/datasets/eqa/mp3d_eqa_dataset.py @@ -7,7 +7,7 @@ import gzip import json import os -from typing import List +from typing import List, Optional from habitat.config import Config from habitat.core.dataset import Dataset @@ -15,13 +15,13 @@ from habitat.tasks.eqa.eqa_task import EQAEpisode, QuestionData from habitat.tasks.nav.nav_task import ObjectGoal, ShortestPathPoint EQA_MP3D_V1_VAL_EPISODE_COUNT = 1950 +DEFAULT_SCENE_PATH_PREFIX = "data/scene_datasets/" def get_default_mp3d_v1_config(split: str = "val"): config = Config() config.name = "MP3DEQA-v1" config.DATA_PATH = "data/datasets/eqa/mp3d/v1/{split}.json.gz" - config.DATA_PATH = "data/scene_datasets/mp3d" config.SPLIT = split return config @@ -54,11 +54,19 @@ class Matterport3dDatasetV1(Dataset): ) as f: self.from_json(f.read()) - def from_json(self, json_str: str) -> None: + def from_json( + self, json_str: str, scenes_dir: Optional[str] = None + ) -> None: deserialized = json.loads(json_str) self.__dict__.update(deserialized) for ep_index, episode in enumerate(deserialized["episodes"]): episode = EQAEpisode(**episode) + if scenes_dir is not None: + if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX): + episode.scene_id = episode.scene_id[ + len(DEFAULT_SCENE_PATH_PREFIX) : + ] + episode.scene_id = os.path.join(scenes_dir, episode.scene_id) episode.question = QuestionData(**episode.question) for g_index, goal in enumerate(episode.goals): episode.goals[g_index] = ObjectGoal(**goal) diff --git a/habitat/datasets/pointnav/pointnav_dataset.py b/habitat/datasets/pointnav/pointnav_dataset.py index 256ed5fffca53b0508ce91d8c599c9b658b1863f..2795b2ebff4b28bef75eee20983f0ca344f63fdc 100644 --- a/habitat/datasets/pointnav/pointnav_dataset.py +++ b/habitat/datasets/pointnav/pointnav_dataset.py @@ -7,7 +7,7 @@ import gzip import json import os -from typing import List +from typing import List, Optional from habitat.config import Config from habitat.core.dataset import Dataset @@ -19,6 +19,7 @@ from habitat.tasks.nav.nav_task import ( ALL_SCENES_MASK = "*" CONTENT_SCENES_PATH_FIELD = "content_scenes_path" +DEFAULT_SCENE_PATH_PREFIX = "data/scene_datasets/" class PointNavDatasetV1(Dataset): @@ -70,7 +71,7 @@ class PointNavDatasetV1(Dataset): scenes.sort() return scenes - def __init__(self, config: Config = None) -> None: + def __init__(self, config: Optional[Config] = None) -> None: self.episodes = [] if config is None: @@ -80,7 +81,7 @@ class PointNavDatasetV1(Dataset): split=config.SPLIT ) with gzip.open(datasetfile_path, "rt") as f: - self.from_json(f.read()) + self.from_json(f.read(), scenes_dir=config.SCENES_DIR) # Read separate file for each scene dataset_dir = os.path.dirname(datasetfile_path) @@ -96,15 +97,26 @@ class PointNavDatasetV1(Dataset): data_path=dataset_dir, scene=scene ) with gzip.open(scene_filename, "rt") as f: - self.from_json(f.read()) + self.from_json(f.read(), scenes_dir=config.SCENES_DIR) - def from_json(self, json_str: str) -> None: + def from_json( + self, json_str: str, scenes_dir: Optional[str] = None + ) -> None: deserialized = json.loads(json_str) if CONTENT_SCENES_PATH_FIELD in deserialized: self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD] for episode in deserialized["episodes"]: episode = NavigationEpisode(**episode) + + if scenes_dir is not None: + if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX): + episode.scene_id = episode.scene_id[ + len(DEFAULT_SCENE_PATH_PREFIX) : + ] + + episode.scene_id = os.path.join(scenes_dir, episode.scene_id) + for g_index, goal in enumerate(episode.goals): episode.goals[g_index] = NavigationGoal(**goal) if episode.shortest_paths is not None: diff --git a/habitat/tasks/nav/shortest_path_follower.py b/habitat/tasks/nav/shortest_path_follower.py index 4084b0416466fc2c04b122252f48c475cbb22a4b..66a53b64f9b1bc40df77e020a4aa193532512051 100644 --- a/habitat/tasks/nav/shortest_path_follower.py +++ b/habitat/tasks/nav/shortest_path_follower.py @@ -1,3 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from typing import Union import habitat_sim @@ -8,7 +14,6 @@ from habitat.sims.habitat_simulator import SimulatorActions from habitat.utils.geometry_utils import ( angle_between_quaternions, quaternion_from_two_vectors, - quaternion_xyzw_to_wxyz, ) diff --git a/habitat/utils/geometry_utils.py b/habitat/utils/geometry_utils.py index e290f1d14abc14b4b7284454005272ca9a8e6ef9..9bbf53e0cf82aa7274032a6b5e6f852025ade673 100644 --- a/habitat/utils/geometry_utils.py +++ b/habitat/utils/geometry_utils.py @@ -1,3 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import numpy as np import quaternion diff --git a/test/test_baseline_agents.py b/test/test_baseline_agents.py index 81dad3b99bac6b3b1d993f6d575b1e2fd51a568f..b8d1cfed68554232a15ff25d837ef98c729620e3 100644 --- a/test/test_baseline_agents.py +++ b/test/test_baseline_agents.py @@ -19,19 +19,19 @@ except ImportError: if has_torch: from baselines.agents import ppo_agents -CFG_TEST = "test/habitat_all_sensors_test.yaml" +CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" @pytest.mark.skipif(not has_torch, reason="Test needs torch") def test_ppo_agents(): config = ppo_agents.get_defaut_config() config.MODEL_PATH = "" - config_env = habitat.get_config(config_file=CFG_TEST) + config_env = habitat.get_config(config_paths=CFG_TEST) config_env.defrost() if not os.path.exists(config_env.SIMULATOR.SCENE): pytest.skip("Please download Habitat test data to data folder.") - benchmark = habitat.Benchmark(config_file=CFG_TEST, config_dir="configs") + benchmark = habitat.Benchmark(config_paths=CFG_TEST) for input_type in ["blind", "rgb", "depth", "rgbd"]: config_env.defrost() @@ -50,12 +50,12 @@ def test_ppo_agents(): def test_simple_agents(): - config_env = habitat.get_config(config_file=CFG_TEST) + config_env = habitat.get_config(config_paths=CFG_TEST) if not os.path.exists(config_env.SIMULATOR.SCENE): pytest.skip("Please download Habitat test data to data folder.") - benchmark = habitat.Benchmark(config_file=CFG_TEST, config_dir="configs") + benchmark = habitat.Benchmark(config_paths=CFG_TEST) for agent_class in [ simple_agents.ForwardOnlyAgent, diff --git a/test/test_config.py b/test/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a113e82e91354f6d501c103db46bfdcb73edb91f --- /dev/null +++ b/test/test_config.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from habitat.config.default import get_config + +CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" +CFG_EQA = "configs/test/habitat_mp3d_eqa_test.yaml" +MAX_TEST_STEPS_LIMIT = 3 + + +def test_merged_configs(): + test_config = get_config(CFG_TEST) + eqa_config = get_config(CFG_EQA) + merged_config = get_config("{},{}".format(CFG_TEST, CFG_EQA)) + assert merged_config.TASK.TYPE == eqa_config.TASK.TYPE + assert ( + merged_config.ENVIRONMENT.MAX_EPISODE_STEPS + == test_config.ENVIRONMENT.MAX_EPISODE_STEPS + ) + + +def test_overwrite_options(): + for steps_limit in range(MAX_TEST_STEPS_LIMIT): + config = get_config( + config_paths=CFG_TEST, + opts=["ENVIRONMENT.MAX_EPISODE_STEPS", steps_limit], + ) + assert ( + config.ENVIRONMENT.MAX_EPISODE_STEPS == steps_limit + ), "Overwriting of config options failed." diff --git a/test/test_habitat_env.py b/test/test_habitat_env.py index bed5f6f764d47baedbcd04e738b940082bb28294..3229c9d989993ecaf8c44fc8463d12798b2ab87e 100644 --- a/test/test_habitat_env.py +++ b/test/test_habitat_env.py @@ -17,7 +17,7 @@ from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 from habitat.sims.habitat_simulator import SimulatorActions from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal -CFG_TEST = "test/habitat_all_sensors_test.yaml" +CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" NUM_ENVS = 4 diff --git a/test/test_mp3d_eqa.py b/test/test_mp3d_eqa.py index 08f762d4afb8eff098877538d12675d82461811a..eef9cda3cac84a59f1142b7fffb6e274a50a16fd 100644 --- a/test/test_mp3d_eqa.py +++ b/test/test_mp3d_eqa.py @@ -16,7 +16,7 @@ from habitat.core.embodied_task import Episode from habitat.core.logging import logger from habitat.datasets import make_dataset -CFG_TEST = "test/habitat_mp3d_eqa_test.yaml" +CFG_TEST = "configs/test/habitat_mp3d_eqa_test.yaml" CLOSE_STEP_THRESHOLD = 0.028 # List of episodes each from unique house @@ -193,7 +193,10 @@ def test_mp3d_eqa_sim_correspondence(): "cur_state.rotation: {} shortest_path.rotation: {} action: {}" "".format( cur_state.position - point.position, - cur_state.rotation - point.rotation, + cur_state.rotation + - habitat.utils.geometry_utils.quaternion_wxyz_to_xyzw( + point.rotation + ), cur_state.position, point.position, cur_state.rotation, diff --git a/test/test_pointnav_dataset.py b/test/test_pointnav_dataset.py index 8271b3250e31e1273414ba26ef6aed849c6b66ce..8778204f01e2a836cdfbf668940df8e6e1abebd1 100644 --- a/test/test_pointnav_dataset.py +++ b/test/test_pointnav_dataset.py @@ -4,7 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import os import time + import pytest import habitat @@ -12,9 +14,12 @@ from habitat.config.default import get_config from habitat.core.embodied_task import Episode from habitat.core.logging import logger from habitat.datasets import make_dataset -from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 +from habitat.datasets.pointnav.pointnav_dataset import ( + PointNavDatasetV1, + DEFAULT_SCENE_PATH_PREFIX, +) -CFG_TEST = "datasets/pointnav/gibson.yaml" +CFG_TEST = "configs/datasets/pointnav/gibson.yaml" PARTIAL_LOAD_SCENES = 3 @@ -50,6 +55,34 @@ def test_single_pointnav_dataset(): check_json_serializaiton(dataset) +def test_multiple_files_scene_path(): + dataset_config = get_config(CFG_TEST).DATASET + if not PointNavDatasetV1.check_config_paths_exist(dataset_config): + pytest.skip("Test skipped as dataset files are missing.") + scenes = PointNavDatasetV1.get_scenes_to_load(config=dataset_config) + assert ( + len(scenes) > 0 + ), "Expected dataset contains separate episode file per scene." + dataset_config.defrost() + dataset_config.POINTNAVV1.CONTENT_SCENES = scenes[:PARTIAL_LOAD_SCENES] + dataset_config.SCENES_DIR = os.path.join( + os.getcwd(), DEFAULT_SCENE_PATH_PREFIX + ) + dataset_config.freeze() + partial_dataset = make_dataset( + id_dataset=dataset_config.TYPE, config=dataset_config + ) + assert ( + len(partial_dataset.scene_ids) == PARTIAL_LOAD_SCENES + ), "Number of loaded scenes doesn't correspond." + print(partial_dataset.episodes[0].scene_id) + assert os.path.exists( + partial_dataset.episodes[0].scene_id + ), "Scene file {} doesn't exist using absolute path".format( + partial_dataset.episodes[0].scene_id + ) + + def test_multiple_files_pointnav_dataset(): dataset_config = get_config(CFG_TEST).DATASET if not PointNavDatasetV1.check_config_paths_exist(dataset_config): diff --git a/test/test_sensors.py b/test/test_sensors.py index dbdaa208d8e598b6c05adec55189cf1982d66d83..15343f3ae521a790941238540aae3fdd259e9600 100644 --- a/test/test_sensors.py +++ b/test/test_sensors.py @@ -17,7 +17,7 @@ from habitat.tasks.nav.nav_task import ( ) from habitat.sims.habitat_simulator import SimulatorActions -CFG_TEST = "test/habitat_all_sensors_test.yaml" +CFG_TEST = "configs/test/habitat_all_sensors_test.yaml" def _random_episode(env, config):