From 063e8e53ea0ea7f47c8c7dce43230d2d0d0ff3dd Mon Sep 17 00:00:00 2001 From: Erik Wijmans <ewijmans2@gmail.com> Date: Wed, 5 Jun 2019 19:36:11 -0700 Subject: [PATCH] Fixes the registry (#99) Our registry pattern is far from idea and thus we haven't used it consistent. This PR aims to resolve this by copying the registry pattern used in Pythia (as it is much nicer) and then registering all current measures and sensors and using the register to query for those instead of getattr as this is a more clearly extendable pattern (as compared to monkey patching). --- README.md | 2 + examples/register_new_sensors_and_measures.py | 105 +++++++++ habitat/__init__.py | 7 +- habitat/config/__init__.py | 1 + habitat/core/registry.py | 201 +++++++++++++++--- habitat/datasets/__init__.py | 18 +- habitat/datasets/eqa/mp3d_eqa_dataset.py | 2 + habitat/datasets/pointnav/pointnav_dataset.py | 2 + habitat/datasets/registration.py | 34 +-- habitat/sims/__init__.py | 8 +- habitat/sims/habitat_simulator.py | 27 +-- habitat/sims/registration.py | 43 ++-- habitat/tasks/__init__.py | 8 +- habitat/tasks/eqa/eqa_task.py | 2 + habitat/tasks/nav/nav_task.py | 59 +++-- habitat/tasks/registration.py | 43 +--- habitat/utils/visualizations/__init__.py | 3 +- habitat_baselines/rl/ppo/__init__.py | 2 +- habitat_baselines/slambased/path_planners.py | 2 +- test/test_habitat_example.py | 15 +- 20 files changed, 385 insertions(+), 199 deletions(-) create mode 100644 examples/register_new_sensors_and_measures.py diff --git a/README.md b/README.md index 1144b7a27..a2a7f475b 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,8 @@ while not env.episode_over: ``` +See [`examples/register_new_sensors_and_measures.py`](examples/register_new_sensors_and_measures) for an example of how to extend habitat-api from _outside_ the source code + ## Docker Setup We also provide a docker setup for habitat. This works on machines with an NVIDIA GPU and requires users to install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). The following [Dockerfile](Dockerfile) was used to build the habitat docker. To setup the habitat stack using docker follow the below steps: diff --git a/examples/register_new_sensors_and_measures.py b/examples/register_new_sensors_and_measures.py new file mode 100644 index 000000000..a7bb634cc --- /dev/null +++ b/examples/register_new_sensors_and_measures.py @@ -0,0 +1,105 @@ +from typing import Any + +import numpy as np +from gym import spaces + +import habitat +from habitat.config import Config as CN + + +# Define the measure and register it with habitat +# By default, the things are registered with the class name +@habitat.registry.register_measure +class EpisodeInfo(habitat.Measure): + def __init__(self, sim, config): + # This measure only needs the config + self._config = config + + super().__init__() + + # Defines the name of the measure in the measurements dictionary + def _get_uuid(self, *args: Any, **kwargs: Any): + return "episode_info" + + # This is called whenver the environment is reset + def reset_metric(self, episode): + # Our measure always contains all the attributes of the episode + self._metric = vars(episode).copy() + # But only on reset, it has an additional field of my_value + self._metric["my_value"] = self._config.VALUE + + # This is called whenver an action is taken in the environment + def update_metric(self, episode, action): + # Now the measure will just have all the attributes of the episode + self._metric = vars(episode).copy() + + +# Define the sensor and register it with habitat +# For the sensor, we will register it with a custom name +@habitat.registry.register_sensor(name="my_supercool_sensor") +class AgentPositionSensor(habitat.Sensor): + def __init__(self, sim, config): + super().__init__(config=config) + + self._sim = sim + # Prints out the answer to life on init + print("The answer to life is", self.config.ANSWER_TO_LIFE) + + # Defines the name of the sensor in the sensor suite dictionary + def _get_uuid(self, *args: Any, **kwargs: Any): + return "agent_position" + + # Defines the type of the sensor + def _get_sensor_type(self, *args: Any, **kwargs: Any): + return habitat.SensorTypes.POSITION + + # Defines the size and range of the observations of the sensor + def _get_observation_space(self, *args: Any, **kwargs: Any): + return spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(3,), + dtype=np.float32, + ) + + # This is called whenver reset is called or an action is taken + def get_observation(self, observations, episode): + return self._sim.get_agent_state().position + + +def main(): + # Get the default config node + config = habitat.get_config(config_paths="configs/tasks/pointnav.yaml") + config.defrost() + + # Add things to the config to for the measure + config.TASK.EPISODE_INFO = CN() + # The type field is used to look-up the measure in the registry. + # By default, the things are registered with the class name + config.TASK.EPISODE_INFO.TYPE = "EpisodeInfo" + config.TASK.EPISODE_INFO.VALUE = 5 + # Add the measure to the list of measures in use + config.TASK.MEASUREMENTS.append("EPISODE_INFO") + + # Now define the config for the sensor + config.TASK.AGENT_POSITION_SENSOR = CN() + # Use the custom name + config.TASK.AGENT_POSITION_SENSOR.TYPE = "my_supercool_sensor" + config.TASK.AGENT_POSITION_SENSOR.ANSWER_TO_LIFE = 42 + # Add the sensor to the list of sensors in use + config.TASK.SENSORS.append("AGENT_POSITION_SENSOR") + config.freeze() + + env = habitat.Env(config=config) + print(env.reset()["agent_position"]) + print(env.get_metrics()["episode_info"]) + print( + env.step( + habitat.sims.habitat_simulator.SimulatorActions.MOVE_FORWARD.value + )["agent_position"] + ) + print(env.get_metrics()["episode_info"]) + + +if __name__ == "__main__": + main() diff --git a/habitat/__init__.py b/habitat/__init__.py index 847fb295c..a37e4dc41 100644 --- a/habitat/__init__.py +++ b/habitat/__init__.py @@ -4,16 +4,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from habitat.config import Config, get_config from habitat.core.agent import Agent from habitat.core.benchmark import Benchmark from habitat.core.challenge import Challenge -from habitat.config import Config, get_config from habitat.core.dataset import Dataset from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements from habitat.core.env import Env, RLEnv from habitat.core.logging import logger -from habitat.core.simulator import SensorTypes, Sensor, SensorSuite, Simulator -from habitat.core.vector_env import VectorEnv, ThreadedVectorEnv +from habitat.core.registry import registry +from habitat.core.simulator import Sensor, SensorSuite, SensorTypes, Simulator +from habitat.core.vector_env import ThreadedVectorEnv, VectorEnv from habitat.datasets import make_dataset from habitat.version import VERSION as __version__ # noqa diff --git a/habitat/config/__init__.py b/habitat/config/__init__.py index 14f2134a0..6734e3fc2 100644 --- a/habitat/config/__init__.py +++ b/habitat/config/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from yacs.config import CfgNode as Config + from habitat.config.default import get_config __all__ = ["Config", "get_config"] diff --git a/habitat/core/registry.py b/habitat/core/registry.py index 0aedfbffc..78e71dd4b 100644 --- a/habitat/core/registry.py +++ b/habitat/core/registry.py @@ -4,47 +4,186 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +""" +Registry is central source of truth in Habitat. +Shamelessly taken from Pythia, it is inspired from Redux's +concept of global store, Registry maintains mappings of various information +to unique keys. Special functions in registry can be used as decorators to +register different kind of classes. +Import the global registry object using -def load(name: str) -> type: - import pkg_resources +``from habitat.core.registry import registry`` - entry_point = pkg_resources.EntryPoint.parse("x={}".format(name)) - result = entry_point.resolve() - return result +Various decorators for registry different kind of classes with unique keys +- Register a task: ``@registry.register_task`` +- Register a simulator: ``@registry.register_simulator`` +- Register a sensor: ``@registry.register_sensor`` +- Register a measure: ``@registry.register_measure`` +- Register a dataset: ``@registry.register_dataset`` +""" -class Spec: - def __init__(self, id: str, entry_point: str, **kwargs: Any) -> None: - self.id = id - self._entry_point = entry_point +import collections +from typing import Optional - def make(self, **kwargs: Any) -> Any: - return load(self._entry_point)(**kwargs) - def __repr__(self): - return "{}({})".format(self.__class__.__name__, self.id) +class _Registry: + mapping = collections.defaultdict(dict) + @classmethod + def _register_impl(cls, _type, to_register, name, assert_type=None): + def wrap(to_register): + if assert_type is not None: + assert issubclass( + to_register, assert_type + ), "{} must be a subclass of {}".format( + to_register, assert_type + ) -class Registry: - def __init__(self): - self.specs = {} + cls.mapping[_type][ + to_register.__name__ if name is None else name + ] = to_register - def make(self, id: str, **kwargs: Any) -> Any: - spec = self.get_spec(id) - return spec.make(**kwargs) + return to_register - def all(self) -> Any: - return self.specs.values() + if to_register is None: + return wrap + else: + return wrap(to_register) - def get_spec(self, id: str) -> Spec: - spec = self.specs.get(id, None) - if spec is None: - raise KeyError( - "No registered specification with id: {}".format(id) - ) - return spec + @classmethod + def register_task(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a task to registry with key 'name' - def register(self, id: str, **kwargs: Any) -> None: - raise NotImplementedError + Args: + name: Key with which the task will be registered. + If None will use the name of the class + + + Usage:: + from habitat.core.registry import registry + from habitat.core.embodied_task import EmbodiedTask + + @registry.register_task + class MyTask(EmbodiedTask): + pass + + + # or + + @registry.register_task(name="MyTaskName") + class MyTask(EmbodiedTask): + pass + + """ + from habitat.core.embodied_task import EmbodiedTask + + return cls._register_impl( + "task", to_register, name, assert_type=EmbodiedTask + ) + + @classmethod + def register_simulator( + cls, to_register=None, *, name: Optional[str] = None + ): + r"""Register a simulator to registry with key 'name' + + Args: + name: Key with which the simulator will be registered. + If None will use the name of the class + + + Usage:: + from habitat.core.registry import registry + from habitat.core.simulator import Simulator + + @registry.register_simulator + class MySimulator(Simulator): + pass + + + # or + + @registry.register_simulator(name="MySimName") + class MySimulator(Simulator): + pass + + """ + from habitat.core.simulator import Simulator + + return cls._register_impl( + "sim", to_register, name, assert_type=Simulator + ) + + @classmethod + def register_sensor(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a sensor to registry with key 'name' + + Args: + name: Key with which the sensor will be registered. + If None will use the name of the class + + """ + from habitat.core.simulator import Sensor + + return cls._register_impl( + "sensor", to_register, name, assert_type=Sensor + ) + + @classmethod + def register_measure(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a measure to registry with key 'name' + + Args: + name: Key with which the measure will be registered. + If None will use the name of the class + + """ + from habitat.core.embodied_task import Measure + + return cls._register_impl( + "measure", to_register, name, assert_type=Measure + ) + + @classmethod + def register_dataset(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a dataset to registry with key 'name' + + Args: + name: Key with which the dataset will be registered. + If None will use the name of the class + + """ + from habitat.core.dataset import Dataset + + return cls._register_impl( + "dataset", to_register, name, assert_type=Dataset + ) + + @classmethod + def _get_impl(cls, _type, name): + return cls.mapping[_type].get(name, None) + + @classmethod + def get_task(cls, name): + return cls._get_impl("task", name) + + @classmethod + def get_simulator(cls, name): + return cls._get_impl("sim", name) + + @classmethod + def get_sensor(cls, name): + return cls._get_impl("sensor", name) + + @classmethod + def get_measure(cls, name): + return cls._get_impl("measure", name) + + @classmethod + def get_dataset(cls, name): + return cls._get_impl("dataset", name) + + +registry = _Registry() diff --git a/habitat/datasets/__init__.py b/habitat/datasets/__init__.py index 9d71bebcc..b0b567fbc 100644 --- a/habitat/datasets/__init__.py +++ b/habitat/datasets/__init__.py @@ -4,20 +4,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.datasets.registration import ( - dataset_registry, - register_dataset, - make_dataset, -) - -register_dataset( - id_dataset="MP3DEQA-v1", - entry_point="habitat.datasets.eqa.mp3d_eqa_dataset:Matterport3dDatasetV1", -) - -register_dataset( - id_dataset="PointNav-v1", - entry_point="habitat.datasets.pointnav.pointnav_dataset:PointNavDatasetV1", -) - -__all__ = ["dataset_registry", "register_dataset", "make_dataset"] +from habitat.datasets.registration import make_dataset diff --git a/habitat/datasets/eqa/mp3d_eqa_dataset.py b/habitat/datasets/eqa/mp3d_eqa_dataset.py index 1dcd4195d..dce969f52 100644 --- a/habitat/datasets/eqa/mp3d_eqa_dataset.py +++ b/habitat/datasets/eqa/mp3d_eqa_dataset.py @@ -11,6 +11,7 @@ from typing import List, Optional from habitat.config import Config from habitat.core.dataset import Dataset +from habitat.core.registry import registry from habitat.tasks.eqa.eqa_task import EQAEpisode, QuestionData from habitat.tasks.nav.nav_task import ObjectGoal, ShortestPathPoint @@ -26,6 +27,7 @@ def get_default_mp3d_v1_config(split: str = "val"): return config +@registry.register_dataset(name="MP3DEQA-v1") class Matterport3dDatasetV1(Dataset): """Class inherited from Dataset that loads Matterport3D Embodied Question Answering dataset. diff --git a/habitat/datasets/pointnav/pointnav_dataset.py b/habitat/datasets/pointnav/pointnav_dataset.py index c2ce3cbe2..be834379c 100644 --- a/habitat/datasets/pointnav/pointnav_dataset.py +++ b/habitat/datasets/pointnav/pointnav_dataset.py @@ -11,6 +11,7 @@ from typing import List, Optional from habitat.config import Config from habitat.core.dataset import Dataset +from habitat.core.registry import registry from habitat.tasks.nav.nav_task import ( NavigationEpisode, NavigationGoal, @@ -22,6 +23,7 @@ CONTENT_SCENES_PATH_FIELD = "content_scenes_path" DEFAULT_SCENE_PATH_PREFIX = "data/scene_datasets/" +@registry.register_dataset(name="PointNav-v1") class PointNavDatasetV1(Dataset): """ Class inherited from Dataset that loads Point Navigation dataset. diff --git a/habitat/datasets/registration.py b/habitat/datasets/registration.py index e96733f3a..265c43ef6 100644 --- a/habitat/datasets/registration.py +++ b/habitat/datasets/registration.py @@ -4,35 +4,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.core.registry import Registry, Spec - - -class DatasetSpec(Spec): - def __init__(self, id_dataset, entry_point): - super().__init__(id_dataset, entry_point) - - -class DatasetRegistry(Registry): - def register(self, id_dataset, **kwargs): - if id_dataset in self.specs: - raise ValueError( - "Cannot re-register dataset specification with id: {}".format( - id_dataset - ) - ) - self.specs[id_dataset] = DatasetSpec(id_dataset, **kwargs) - - -dataset_registry = DatasetRegistry() - - -def register_dataset(id_dataset, **kwargs): - dataset_registry.register(id_dataset, **kwargs) +from habitat.core.registry import registry +from habitat.datasets.eqa.mp3d_eqa_dataset import Matterport3dDatasetV1 +from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 def make_dataset(id_dataset, **kwargs): - return dataset_registry.make(id_dataset, **kwargs) - + _dataset = registry.get_dataset(id_dataset) + assert _dataset is not None, "Could not find dataset {}".format(id_dataset) -def get_spec_dataset(id_dataset): - return dataset_registry.get_spec(id_dataset) + return _dataset(**kwargs) diff --git a/habitat/sims/__init__.py b/habitat/sims/__init__.py index 345d13b0c..255612702 100644 --- a/habitat/sims/__init__.py +++ b/habitat/sims/__init__.py @@ -4,10 +4,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.sims.registration import sim_registry, register_sim, make_sim - -register_sim( - id_sim="Sim-v0", entry_point="habitat.sims.habitat_simulator:HabitatSim" -) - -__all__ = ["sim_registry", "register_sim", "make_sim"] +from habitat.sims.registration import make_sim diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py index 402e02f91..d6c0c3a09 100644 --- a/habitat/sims/habitat_simulator.py +++ b/habitat/sims/habitat_simulator.py @@ -10,17 +10,20 @@ from typing import Any, List, Optional import numpy as np from gym import Space, spaces -import habitat import habitat_sim -from habitat import Config, SensorSuite from habitat.core.logging import logger +from habitat.core.registry import registry from habitat.core.simulator import ( AgentState, + Config, DepthSensor, Observations, RGBSensor, SemanticSensor, + Sensor, + SensorSuite, ShortestPathPoint, + Simulator, ) RGBSENSOR_DIMENSION = 3 @@ -48,6 +51,7 @@ class SimulatorActions(Enum): LOOK_DOWN = 5 +@registry.register_sensor class HabitatSimRGBSensor(RGBSensor): sim_sensor_type: habitat_sim.SensorType @@ -71,6 +75,7 @@ class HabitatSimRGBSensor(RGBSensor): return obs +@registry.register_sensor class HabitatSimDepthSensor(DepthSensor): sim_sensor_type: habitat_sim.SensorType min_depth_value: float @@ -110,6 +115,7 @@ class HabitatSimDepthSensor(DepthSensor): return obs +@registry.register_sensor class HabitatSimSemanticSensor(SemanticSensor): sim_sensor_type: habitat_sim.SensorType @@ -131,7 +137,8 @@ class HabitatSimSemanticSensor(SemanticSensor): return obs -class HabitatSim(habitat.Simulator): +@registry.register_simulator(name="Sim-v0") +class HabitatSim(Simulator): """Simulator wrapper over habitat-sim habitat-sim repo: https://github.com/facebookresearch/habitat-sim @@ -147,18 +154,12 @@ class HabitatSim(habitat.Simulator): sim_sensors = [] for sensor_name in agent_config.SENSORS: sensor_cfg = getattr(self.config, sensor_name) - is_valid_sensor = hasattr( - habitat.sims.habitat_simulator, sensor_cfg.TYPE # type: ignore - ) - assert is_valid_sensor, "invalid sensor type {}".format( + sensor_type = registry.get_sensor(sensor_cfg.TYPE) + + assert sensor_type is not None, "invalid sensor type {}".format( sensor_cfg.TYPE ) - sim_sensors.append( - getattr( - habitat.sims.habitat_simulator, - sensor_cfg.TYPE, # type: ignore - )(sensor_cfg) - ) + sim_sensors.append(sensor_type(sensor_cfg)) self._sensor_suite = SensorSuite(sim_sensors) self.sim_config = self.create_sim_config(self._sensor_suite) diff --git a/habitat/sims/registration.py b/habitat/sims/registration.py index dd9f1b97d..0f59808d5 100644 --- a/habitat/sims/registration.py +++ b/habitat/sims/registration.py @@ -5,35 +5,36 @@ # LICENSE file in the root directory of this source tree. from habitat.core.logging import logger -from habitat.core.registry import Registry, Spec +from habitat.core.registry import registry +from habitat.core.simulator import Simulator -class SimSpec(Spec): - def __init__(self, id_sim, entry_point): - super().__init__(id_sim, entry_point) +def _try_register_habitat_sim(): + try: + import habitat_sim + has_habitat_sim = True + except ImportError as e: + has_habitat_sim = False + habitat_sim_import_error = e -class SimRegistry(Registry): - def register(self, id_sim, **kwargs): - if id_sim in self.specs: - raise ValueError( - "Cannot re-register sim" - " specification with id: {}".format(id_sim) - ) - self.specs[id_sim] = SimSpec(id_sim, **kwargs) + if has_habitat_sim: + from habitat.sims.habitat_simulator import HabitatSim + else: - -sim_registry = SimRegistry() - - -def register_sim(id_sim, **kwargs): - sim_registry.register(id_sim, **kwargs) + @registry.register_simulator(name="Sim-v0") + class HabitatSimImportError(Simulator): + def __init__(self, *args, **kwargs): + raise habitat_sim_import_error def make_sim(id_sim, **kwargs): logger.info("initializing sim {}".format(id_sim)) - return sim_registry.make(id_sim, **kwargs) + _sim = registry.get_simulator(id_sim) + assert _sim is not None, "Could not find simulator with name {}".format( + id_sim + ) + return _sim(**kwargs) -def get_spec_sim(id_sim): - return sim_registry.get_spec(id_sim) +_try_register_habitat_sim() diff --git a/habitat/tasks/__init__.py b/habitat/tasks/__init__.py index 054774275..8a0437e90 100644 --- a/habitat/tasks/__init__.py +++ b/habitat/tasks/__init__.py @@ -4,10 +4,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.tasks.registration import task_registry, register_task, make_task - -register_task(id_task="EQA-v0", entry_point="habitat.tasks.eqa:EQATask") - -register_task(id_task="Nav-v0", entry_point="habitat.tasks.nav:NavigationTask") - -__all__ = ["task_registry", "register_task", "make_task"] +from habitat.tasks.registration import make_task diff --git a/habitat/tasks/eqa/eqa_task.py b/habitat/tasks/eqa/eqa_task.py index 6898db2a2..dbb0c1f0d 100644 --- a/habitat/tasks/eqa/eqa_task.py +++ b/habitat/tasks/eqa/eqa_task.py @@ -10,6 +10,7 @@ import attr import numpy as np from gym import spaces +from habitat.core.registry import registry from habitat.core.simulator import ( Observations, Sensor, @@ -112,6 +113,7 @@ class RewardSensor(Sensor): return self._get_observation(**kwargs) +@registry.register_task(name="EQA-v0") class EQATask(NavigationTask): _sensor_suite: SensorSuite diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index c46cf2e8b..578cccc5b 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -11,11 +11,12 @@ import cv2 import numpy as np from gym import spaces -import habitat from habitat.config import Config from habitat.core.dataset import Dataset, Episode -from habitat.core.embodied_task import Measurements +from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements +from habitat.core.registry import registry from habitat.core.simulator import ( + Sensor, SensorSuite, SensorTypes, ShortestPathPoint, @@ -105,7 +106,8 @@ class NavigationEpisode(Episode): shortest_paths: Optional[List[ShortestPathPoint]] = None -class PointGoalSensor(habitat.Sensor): +@registry.register_sensor +class PointGoalSensor(Sensor): """ Sensor for PointGoal observations which are used in the PointNav task. For the agent in simulator the forward direction is along negative-z. @@ -171,7 +173,8 @@ class PointGoalSensor(habitat.Sensor): return direction_vector_agent -class StaticPointGoalSensor(habitat.Sensor): +@registry.register_sensor +class StaticPointGoalSensor(Sensor): """ Sensor for PointGoal observations which are used in the StaticPointNav task. For the agent in simulator the forward direction is along negative-z. @@ -243,7 +246,8 @@ class StaticPointGoalSensor(habitat.Sensor): return self._initial_vector -class HeadingSensor(habitat.Sensor): +@registry.register_sensor +class HeadingSensor(Sensor): """ Sensor for observing the agent's heading in the global coordinate frame. @@ -279,7 +283,8 @@ class HeadingSensor(habitat.Sensor): return np.array(phi) -class ProximitySensor(habitat.Sensor): +@registry.register_sensor +class ProximitySensor(Sensor): """ Sensor for observing the distance to the closest obstacle @@ -317,7 +322,8 @@ class ProximitySensor(habitat.Sensor): ) -class SPL(habitat.Measure): +@registry.register_measure +class SPL(Measure): """SPL (Success weighted by Path Length) ref: On Evaluation of Embodied Agents - Anderson et. al @@ -375,7 +381,8 @@ class SPL(habitat.Measure): ) -class Collisions(habitat.Measure): +@registry.register_measure +class Collisions(Measure): def __init__(self, sim, config): self._sim = sim self._config = config @@ -401,7 +408,8 @@ class Collisions(habitat.Measure): self._metric += 1 -class TopDownMap(habitat.Measure): +@registry.register_measure +class TopDownMap(Measure): """Top Down Map measure """ @@ -547,7 +555,8 @@ class TopDownMap(habitat.Measure): return self._top_down_map, a_x, a_y -class NavigationTask(habitat.EmbodiedTask): +@registry.register_task(name="Nav-v0") +class NavigationTask(EmbodiedTask): def __init__( self, task_config: Config, @@ -558,35 +567,21 @@ class NavigationTask(habitat.EmbodiedTask): task_measurements = [] for measurement_name in task_config.MEASUREMENTS: measurement_cfg = getattr(task_config, measurement_name) - is_valid_measurement = hasattr( - habitat.tasks.nav.nav_task, # type: ignore - measurement_cfg.TYPE, - ) - assert is_valid_measurement, "invalid measurement type {}".format( - measurement_cfg.TYPE - ) - task_measurements.append( - getattr( - habitat.tasks.nav.nav_task, # type: ignore - measurement_cfg.TYPE, - )(sim, measurement_cfg) - ) + measure_type = registry.get_measure(measurement_cfg.TYPE) + assert ( + measure_type is not None + ), "invalid measurement type {}".format(measurement_cfg.TYPE) + task_measurements.append(measure_type(sim, measurement_cfg)) self.measurements = Measurements(task_measurements) task_sensors = [] for sensor_name in task_config.SENSORS: sensor_cfg = getattr(task_config, sensor_name) - is_valid_sensor = hasattr( - habitat.tasks.nav.nav_task, sensor_cfg.TYPE # type: ignore - ) - assert is_valid_sensor, "invalid sensor type {}".format( + sensor_type = registry.get_sensor(sensor_cfg.TYPE) + assert sensor_type is not None, "invalid sensor type {}".format( sensor_cfg.TYPE ) - task_sensors.append( - getattr( - habitat.tasks.nav.nav_task, sensor_cfg.TYPE # type: ignore - )(sim, sensor_cfg) - ) + task_sensors.append(sensor_type(sim, sensor_cfg)) self.sensor_suite = SensorSuite(task_sensors) super().__init__(config=task_config, sim=sim, dataset=dataset) diff --git a/habitat/tasks/registration.py b/habitat/tasks/registration.py index 0c7543fb5..90400ce07 100644 --- a/habitat/tasks/registration.py +++ b/habitat/tasks/registration.py @@ -5,43 +5,16 @@ # LICENSE file in the root directory of this source tree. from habitat.core.logging import logger -from habitat.core.registry import Registry, Spec - - -class TaskSpec(Spec): - def __init__(self, id_task, entry_point): - super().__init__(id_task, entry_point) - - -class TaskRegistry(Registry): - """Registry for maintaining tasks. - - Args: - id_task: id for task being registered. - kwargs: arguments to be passed to task constructor. - """ - - def register(self, id_task, **kwargs): - if id_task in self.specs: - raise ValueError( - "Cannot re-register task specification with id: {}".format( - id_task - ) - ) - self.specs[id_task] = TaskSpec(id_task, **kwargs) - - -task_registry = TaskRegistry() - - -def register_task(id_task, **kwargs): - task_registry.register(id_task, **kwargs) +from habitat.core.registry import registry +from habitat.tasks.eqa.eqa_task import EQATask +from habitat.tasks.nav.nav_task import NavigationTask def make_task(id_task, **kwargs): logger.info("initializing task {}".format(id_task)) - return task_registry.make(id_task, **kwargs) - + _task = registry.get_task(id_task) + assert _task is not None, "Could not find task with name {}".format( + id_task + ) -def get_spec_task(id_task): - return task_registry.get_spec(id_task) + return _task(**kwargs) diff --git a/habitat/utils/visualizations/__init__.py b/habitat/utils/visualizations/__init__.py index 74b7b453c..c27ce957c 100644 --- a/habitat/utils/visualizations/__init__.py +++ b/habitat/utils/visualizations/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.utils.visualizations import maps -from habitat.utils.visualizations import utils +from habitat.utils.visualizations import maps, utils __all__ = ["maps", "utils"] diff --git a/habitat_baselines/rl/ppo/__init__.py b/habitat_baselines/rl/ppo/__init__.py index fad59c373..128cc360c 100644 --- a/habitat_baselines/rl/ppo/__init__.py +++ b/habitat_baselines/rl/ppo/__init__.py @@ -4,8 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat_baselines.rl.ppo.ppo import PPO from habitat_baselines.rl.ppo.policy import Policy +from habitat_baselines.rl.ppo.ppo import PPO from habitat_baselines.rl.ppo.utils import RolloutStorage __all__ = ["PPO", "Policy", "RolloutStorage"] diff --git a/habitat_baselines/slambased/path_planners.py b/habitat_baselines/slambased/path_planners.py index 9e15e1574..f56776df4 100644 --- a/habitat_baselines/slambased/path_planners.py +++ b/habitat_baselines/slambased/path_planners.py @@ -1,9 +1,9 @@ -import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import matplotlib.pyplot as plt from habitat_baselines.slambased.utils import generate_2dgrid diff --git a/test/test_habitat_example.py b/test/test_habitat_example.py index 1f95c35f4..d4884c122 100644 --- a/test/test_habitat_example.py +++ b/test/test_habitat_example.py @@ -7,7 +7,11 @@ import pytest import habitat -from examples import shortest_path_follower_example, visualization_examples +from examples import ( + register_new_sensors_and_measures, + shortest_path_follower_example, + visualization_examples, +) from examples.example import example from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 @@ -34,3 +38,12 @@ def test_shortest_path_follower_example(): ): pytest.skip("Please download Habitat test data to data folder.") shortest_path_follower_example.main() + + +def test_register_new_sensors_and_measures(): + if not PointNavDatasetV1.check_config_paths_exist( + config=habitat.get_config().DATASET + ): + pytest.skip("Please download Habitat test data to data folder.") + + register_new_sensors_and_measures.main() -- GitLab