diff --git a/README.md b/README.md index 1144b7a27f7ab91fb82baf39af5785f9bffeaed0..a2a7f475bc59cb7ca5348d93a5c74c0b33bda795 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,8 @@ while not env.episode_over: ``` +See [`examples/register_new_sensors_and_measures.py`](examples/register_new_sensors_and_measures) for an example of how to extend habitat-api from _outside_ the source code + ## Docker Setup We also provide a docker setup for habitat. This works on machines with an NVIDIA GPU and requires users to install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). The following [Dockerfile](Dockerfile) was used to build the habitat docker. To setup the habitat stack using docker follow the below steps: diff --git a/examples/register_new_sensors_and_measures.py b/examples/register_new_sensors_and_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bb634cc4b7f1ce18ea0320c9c1314eef13ec10 --- /dev/null +++ b/examples/register_new_sensors_and_measures.py @@ -0,0 +1,105 @@ +from typing import Any + +import numpy as np +from gym import spaces + +import habitat +from habitat.config import Config as CN + + +# Define the measure and register it with habitat +# By default, the things are registered with the class name +@habitat.registry.register_measure +class EpisodeInfo(habitat.Measure): + def __init__(self, sim, config): + # This measure only needs the config + self._config = config + + super().__init__() + + # Defines the name of the measure in the measurements dictionary + def _get_uuid(self, *args: Any, **kwargs: Any): + return "episode_info" + + # This is called whenver the environment is reset + def reset_metric(self, episode): + # Our measure always contains all the attributes of the episode + self._metric = vars(episode).copy() + # But only on reset, it has an additional field of my_value + self._metric["my_value"] = self._config.VALUE + + # This is called whenver an action is taken in the environment + def update_metric(self, episode, action): + # Now the measure will just have all the attributes of the episode + self._metric = vars(episode).copy() + + +# Define the sensor and register it with habitat +# For the sensor, we will register it with a custom name +@habitat.registry.register_sensor(name="my_supercool_sensor") +class AgentPositionSensor(habitat.Sensor): + def __init__(self, sim, config): + super().__init__(config=config) + + self._sim = sim + # Prints out the answer to life on init + print("The answer to life is", self.config.ANSWER_TO_LIFE) + + # Defines the name of the sensor in the sensor suite dictionary + def _get_uuid(self, *args: Any, **kwargs: Any): + return "agent_position" + + # Defines the type of the sensor + def _get_sensor_type(self, *args: Any, **kwargs: Any): + return habitat.SensorTypes.POSITION + + # Defines the size and range of the observations of the sensor + def _get_observation_space(self, *args: Any, **kwargs: Any): + return spaces.Box( + low=np.finfo(np.float32).min, + high=np.finfo(np.float32).max, + shape=(3,), + dtype=np.float32, + ) + + # This is called whenver reset is called or an action is taken + def get_observation(self, observations, episode): + return self._sim.get_agent_state().position + + +def main(): + # Get the default config node + config = habitat.get_config(config_paths="configs/tasks/pointnav.yaml") + config.defrost() + + # Add things to the config to for the measure + config.TASK.EPISODE_INFO = CN() + # The type field is used to look-up the measure in the registry. + # By default, the things are registered with the class name + config.TASK.EPISODE_INFO.TYPE = "EpisodeInfo" + config.TASK.EPISODE_INFO.VALUE = 5 + # Add the measure to the list of measures in use + config.TASK.MEASUREMENTS.append("EPISODE_INFO") + + # Now define the config for the sensor + config.TASK.AGENT_POSITION_SENSOR = CN() + # Use the custom name + config.TASK.AGENT_POSITION_SENSOR.TYPE = "my_supercool_sensor" + config.TASK.AGENT_POSITION_SENSOR.ANSWER_TO_LIFE = 42 + # Add the sensor to the list of sensors in use + config.TASK.SENSORS.append("AGENT_POSITION_SENSOR") + config.freeze() + + env = habitat.Env(config=config) + print(env.reset()["agent_position"]) + print(env.get_metrics()["episode_info"]) + print( + env.step( + habitat.sims.habitat_simulator.SimulatorActions.MOVE_FORWARD.value + )["agent_position"] + ) + print(env.get_metrics()["episode_info"]) + + +if __name__ == "__main__": + main() diff --git a/habitat/__init__.py b/habitat/__init__.py index 847fb295c3c0ef3e362af0b3216b39d1e32c2eda..a37e4dc4159f66036a63e90bbc5cb19fecfc1a6d 100644 --- a/habitat/__init__.py +++ b/habitat/__init__.py @@ -4,16 +4,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from habitat.config import Config, get_config from habitat.core.agent import Agent from habitat.core.benchmark import Benchmark from habitat.core.challenge import Challenge -from habitat.config import Config, get_config from habitat.core.dataset import Dataset from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements from habitat.core.env import Env, RLEnv from habitat.core.logging import logger -from habitat.core.simulator import SensorTypes, Sensor, SensorSuite, Simulator -from habitat.core.vector_env import VectorEnv, ThreadedVectorEnv +from habitat.core.registry import registry +from habitat.core.simulator import Sensor, SensorSuite, SensorTypes, Simulator +from habitat.core.vector_env import ThreadedVectorEnv, VectorEnv from habitat.datasets import make_dataset from habitat.version import VERSION as __version__ # noqa diff --git a/habitat/config/__init__.py b/habitat/config/__init__.py index 14f2134a0767b8f33b2bea6f610333e015f32746..6734e3fc2cfeb467fbe56f92f09e5d9c9b7e6d8a 100644 --- a/habitat/config/__init__.py +++ b/habitat/config/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from yacs.config import CfgNode as Config + from habitat.config.default import get_config __all__ = ["Config", "get_config"] diff --git a/habitat/core/registry.py b/habitat/core/registry.py index 0aedfbffcf509ef7b7117b03ae15f85dc911d1e8..78e71dd4bae76f54546cd72ca7db088fdd7b4046 100644 --- a/habitat/core/registry.py +++ b/habitat/core/registry.py @@ -4,47 +4,186 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +""" +Registry is central source of truth in Habitat. +Shamelessly taken from Pythia, it is inspired from Redux's +concept of global store, Registry maintains mappings of various information +to unique keys. Special functions in registry can be used as decorators to +register different kind of classes. +Import the global registry object using -def load(name: str) -> type: - import pkg_resources +``from habitat.core.registry import registry`` - entry_point = pkg_resources.EntryPoint.parse("x={}".format(name)) - result = entry_point.resolve() - return result +Various decorators for registry different kind of classes with unique keys +- Register a task: ``@registry.register_task`` +- Register a simulator: ``@registry.register_simulator`` +- Register a sensor: ``@registry.register_sensor`` +- Register a measure: ``@registry.register_measure`` +- Register a dataset: ``@registry.register_dataset`` +""" -class Spec: - def __init__(self, id: str, entry_point: str, **kwargs: Any) -> None: - self.id = id - self._entry_point = entry_point +import collections +from typing import Optional - def make(self, **kwargs: Any) -> Any: - return load(self._entry_point)(**kwargs) - def __repr__(self): - return "{}({})".format(self.__class__.__name__, self.id) +class _Registry: + mapping = collections.defaultdict(dict) + @classmethod + def _register_impl(cls, _type, to_register, name, assert_type=None): + def wrap(to_register): + if assert_type is not None: + assert issubclass( + to_register, assert_type + ), "{} must be a subclass of {}".format( + to_register, assert_type + ) -class Registry: - def __init__(self): - self.specs = {} + cls.mapping[_type][ + to_register.__name__ if name is None else name + ] = to_register - def make(self, id: str, **kwargs: Any) -> Any: - spec = self.get_spec(id) - return spec.make(**kwargs) + return to_register - def all(self) -> Any: - return self.specs.values() + if to_register is None: + return wrap + else: + return wrap(to_register) - def get_spec(self, id: str) -> Spec: - spec = self.specs.get(id, None) - if spec is None: - raise KeyError( - "No registered specification with id: {}".format(id) - ) - return spec + @classmethod + def register_task(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a task to registry with key 'name' - def register(self, id: str, **kwargs: Any) -> None: - raise NotImplementedError + Args: + name: Key with which the task will be registered. + If None will use the name of the class + + + Usage:: + from habitat.core.registry import registry + from habitat.core.embodied_task import EmbodiedTask + + @registry.register_task + class MyTask(EmbodiedTask): + pass + + + # or + + @registry.register_task(name="MyTaskName") + class MyTask(EmbodiedTask): + pass + + """ + from habitat.core.embodied_task import EmbodiedTask + + return cls._register_impl( + "task", to_register, name, assert_type=EmbodiedTask + ) + + @classmethod + def register_simulator( + cls, to_register=None, *, name: Optional[str] = None + ): + r"""Register a simulator to registry with key 'name' + + Args: + name: Key with which the simulator will be registered. + If None will use the name of the class + + + Usage:: + from habitat.core.registry import registry + from habitat.core.simulator import Simulator + + @registry.register_simulator + class MySimulator(Simulator): + pass + + + # or + + @registry.register_simulator(name="MySimName") + class MySimulator(Simulator): + pass + + """ + from habitat.core.simulator import Simulator + + return cls._register_impl( + "sim", to_register, name, assert_type=Simulator + ) + + @classmethod + def register_sensor(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a sensor to registry with key 'name' + + Args: + name: Key with which the sensor will be registered. + If None will use the name of the class + + """ + from habitat.core.simulator import Sensor + + return cls._register_impl( + "sensor", to_register, name, assert_type=Sensor + ) + + @classmethod + def register_measure(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a measure to registry with key 'name' + + Args: + name: Key with which the measure will be registered. + If None will use the name of the class + + """ + from habitat.core.embodied_task import Measure + + return cls._register_impl( + "measure", to_register, name, assert_type=Measure + ) + + @classmethod + def register_dataset(cls, to_register=None, *, name: Optional[str] = None): + r"""Register a dataset to registry with key 'name' + + Args: + name: Key with which the dataset will be registered. + If None will use the name of the class + + """ + from habitat.core.dataset import Dataset + + return cls._register_impl( + "dataset", to_register, name, assert_type=Dataset + ) + + @classmethod + def _get_impl(cls, _type, name): + return cls.mapping[_type].get(name, None) + + @classmethod + def get_task(cls, name): + return cls._get_impl("task", name) + + @classmethod + def get_simulator(cls, name): + return cls._get_impl("sim", name) + + @classmethod + def get_sensor(cls, name): + return cls._get_impl("sensor", name) + + @classmethod + def get_measure(cls, name): + return cls._get_impl("measure", name) + + @classmethod + def get_dataset(cls, name): + return cls._get_impl("dataset", name) + + +registry = _Registry() diff --git a/habitat/datasets/__init__.py b/habitat/datasets/__init__.py index 9d71bebcc24ce659ecf71e20240dfb874d2e2ab9..b0b567fbcee2393806eb0afa10ae7f8ec97f1e0c 100644 --- a/habitat/datasets/__init__.py +++ b/habitat/datasets/__init__.py @@ -4,20 +4,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.datasets.registration import ( - dataset_registry, - register_dataset, - make_dataset, -) - -register_dataset( - id_dataset="MP3DEQA-v1", - entry_point="habitat.datasets.eqa.mp3d_eqa_dataset:Matterport3dDatasetV1", -) - -register_dataset( - id_dataset="PointNav-v1", - entry_point="habitat.datasets.pointnav.pointnav_dataset:PointNavDatasetV1", -) - -__all__ = ["dataset_registry", "register_dataset", "make_dataset"] +from habitat.datasets.registration import make_dataset diff --git a/habitat/datasets/eqa/mp3d_eqa_dataset.py b/habitat/datasets/eqa/mp3d_eqa_dataset.py index 1dcd4195d77b345e41ff946a54160b8b142204cf..dce969f52cdf3c1d650aaf2f7ce27a246e658037 100644 --- a/habitat/datasets/eqa/mp3d_eqa_dataset.py +++ b/habitat/datasets/eqa/mp3d_eqa_dataset.py @@ -11,6 +11,7 @@ from typing import List, Optional from habitat.config import Config from habitat.core.dataset import Dataset +from habitat.core.registry import registry from habitat.tasks.eqa.eqa_task import EQAEpisode, QuestionData from habitat.tasks.nav.nav_task import ObjectGoal, ShortestPathPoint @@ -26,6 +27,7 @@ def get_default_mp3d_v1_config(split: str = "val"): return config +@registry.register_dataset(name="MP3DEQA-v1") class Matterport3dDatasetV1(Dataset): """Class inherited from Dataset that loads Matterport3D Embodied Question Answering dataset. diff --git a/habitat/datasets/pointnav/pointnav_dataset.py b/habitat/datasets/pointnav/pointnav_dataset.py index c2ce3cbe23dad71933c632c355dd7c70070d87a2..be834379cd60b5370c4a060ea6ebd88d51ea8afd 100644 --- a/habitat/datasets/pointnav/pointnav_dataset.py +++ b/habitat/datasets/pointnav/pointnav_dataset.py @@ -11,6 +11,7 @@ from typing import List, Optional from habitat.config import Config from habitat.core.dataset import Dataset +from habitat.core.registry import registry from habitat.tasks.nav.nav_task import ( NavigationEpisode, NavigationGoal, @@ -22,6 +23,7 @@ CONTENT_SCENES_PATH_FIELD = "content_scenes_path" DEFAULT_SCENE_PATH_PREFIX = "data/scene_datasets/" +@registry.register_dataset(name="PointNav-v1") class PointNavDatasetV1(Dataset): """ Class inherited from Dataset that loads Point Navigation dataset. diff --git a/habitat/datasets/registration.py b/habitat/datasets/registration.py index e96733f3a5f12180f2b97370c03a561d992c7ab4..265c43ef6ef2e2afcd4a141fab6c8c15fd9dd92a 100644 --- a/habitat/datasets/registration.py +++ b/habitat/datasets/registration.py @@ -4,35 +4,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.core.registry import Registry, Spec - - -class DatasetSpec(Spec): - def __init__(self, id_dataset, entry_point): - super().__init__(id_dataset, entry_point) - - -class DatasetRegistry(Registry): - def register(self, id_dataset, **kwargs): - if id_dataset in self.specs: - raise ValueError( - "Cannot re-register dataset specification with id: {}".format( - id_dataset - ) - ) - self.specs[id_dataset] = DatasetSpec(id_dataset, **kwargs) - - -dataset_registry = DatasetRegistry() - - -def register_dataset(id_dataset, **kwargs): - dataset_registry.register(id_dataset, **kwargs) +from habitat.core.registry import registry +from habitat.datasets.eqa.mp3d_eqa_dataset import Matterport3dDatasetV1 +from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 def make_dataset(id_dataset, **kwargs): - return dataset_registry.make(id_dataset, **kwargs) - + _dataset = registry.get_dataset(id_dataset) + assert _dataset is not None, "Could not find dataset {}".format(id_dataset) -def get_spec_dataset(id_dataset): - return dataset_registry.get_spec(id_dataset) + return _dataset(**kwargs) diff --git a/habitat/sims/__init__.py b/habitat/sims/__init__.py index 345d13b0c024abfd34c4beac615fc6b9c6cb6305..255612702b58cfec57abf883c4996e450120f9e0 100644 --- a/habitat/sims/__init__.py +++ b/habitat/sims/__init__.py @@ -4,10 +4,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.sims.registration import sim_registry, register_sim, make_sim - -register_sim( - id_sim="Sim-v0", entry_point="habitat.sims.habitat_simulator:HabitatSim" -) - -__all__ = ["sim_registry", "register_sim", "make_sim"] +from habitat.sims.registration import make_sim diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py index 402e02f91471e47a5ba5cb27e086df0537c445f4..d6c0c3a093ebfa5452843890740b2f27197aa0ac 100644 --- a/habitat/sims/habitat_simulator.py +++ b/habitat/sims/habitat_simulator.py @@ -10,17 +10,20 @@ from typing import Any, List, Optional import numpy as np from gym import Space, spaces -import habitat import habitat_sim -from habitat import Config, SensorSuite from habitat.core.logging import logger +from habitat.core.registry import registry from habitat.core.simulator import ( AgentState, + Config, DepthSensor, Observations, RGBSensor, SemanticSensor, + Sensor, + SensorSuite, ShortestPathPoint, + Simulator, ) RGBSENSOR_DIMENSION = 3 @@ -48,6 +51,7 @@ class SimulatorActions(Enum): LOOK_DOWN = 5 +@registry.register_sensor class HabitatSimRGBSensor(RGBSensor): sim_sensor_type: habitat_sim.SensorType @@ -71,6 +75,7 @@ class HabitatSimRGBSensor(RGBSensor): return obs +@registry.register_sensor class HabitatSimDepthSensor(DepthSensor): sim_sensor_type: habitat_sim.SensorType min_depth_value: float @@ -110,6 +115,7 @@ class HabitatSimDepthSensor(DepthSensor): return obs +@registry.register_sensor class HabitatSimSemanticSensor(SemanticSensor): sim_sensor_type: habitat_sim.SensorType @@ -131,7 +137,8 @@ class HabitatSimSemanticSensor(SemanticSensor): return obs -class HabitatSim(habitat.Simulator): +@registry.register_simulator(name="Sim-v0") +class HabitatSim(Simulator): """Simulator wrapper over habitat-sim habitat-sim repo: https://github.com/facebookresearch/habitat-sim @@ -147,18 +154,12 @@ class HabitatSim(habitat.Simulator): sim_sensors = [] for sensor_name in agent_config.SENSORS: sensor_cfg = getattr(self.config, sensor_name) - is_valid_sensor = hasattr( - habitat.sims.habitat_simulator, sensor_cfg.TYPE # type: ignore - ) - assert is_valid_sensor, "invalid sensor type {}".format( + sensor_type = registry.get_sensor(sensor_cfg.TYPE) + + assert sensor_type is not None, "invalid sensor type {}".format( sensor_cfg.TYPE ) - sim_sensors.append( - getattr( - habitat.sims.habitat_simulator, - sensor_cfg.TYPE, # type: ignore - )(sensor_cfg) - ) + sim_sensors.append(sensor_type(sensor_cfg)) self._sensor_suite = SensorSuite(sim_sensors) self.sim_config = self.create_sim_config(self._sensor_suite) diff --git a/habitat/sims/registration.py b/habitat/sims/registration.py index dd9f1b97d4145c6880f338526f85ceec892d6d8b..0f59808d5707c7d2c13706fe2b6d53eac392d7c7 100644 --- a/habitat/sims/registration.py +++ b/habitat/sims/registration.py @@ -5,35 +5,36 @@ # LICENSE file in the root directory of this source tree. from habitat.core.logging import logger -from habitat.core.registry import Registry, Spec +from habitat.core.registry import registry +from habitat.core.simulator import Simulator -class SimSpec(Spec): - def __init__(self, id_sim, entry_point): - super().__init__(id_sim, entry_point) +def _try_register_habitat_sim(): + try: + import habitat_sim + has_habitat_sim = True + except ImportError as e: + has_habitat_sim = False + habitat_sim_import_error = e -class SimRegistry(Registry): - def register(self, id_sim, **kwargs): - if id_sim in self.specs: - raise ValueError( - "Cannot re-register sim" - " specification with id: {}".format(id_sim) - ) - self.specs[id_sim] = SimSpec(id_sim, **kwargs) + if has_habitat_sim: + from habitat.sims.habitat_simulator import HabitatSim + else: - -sim_registry = SimRegistry() - - -def register_sim(id_sim, **kwargs): - sim_registry.register(id_sim, **kwargs) + @registry.register_simulator(name="Sim-v0") + class HabitatSimImportError(Simulator): + def __init__(self, *args, **kwargs): + raise habitat_sim_import_error def make_sim(id_sim, **kwargs): logger.info("initializing sim {}".format(id_sim)) - return sim_registry.make(id_sim, **kwargs) + _sim = registry.get_simulator(id_sim) + assert _sim is not None, "Could not find simulator with name {}".format( + id_sim + ) + return _sim(**kwargs) -def get_spec_sim(id_sim): - return sim_registry.get_spec(id_sim) +_try_register_habitat_sim() diff --git a/habitat/tasks/__init__.py b/habitat/tasks/__init__.py index 05477427562d4f543e1a0bb4ec646d16a494ffe2..8a0437e9023a6b022acf051ae049fdb976ef39ad 100644 --- a/habitat/tasks/__init__.py +++ b/habitat/tasks/__init__.py @@ -4,10 +4,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.tasks.registration import task_registry, register_task, make_task - -register_task(id_task="EQA-v0", entry_point="habitat.tasks.eqa:EQATask") - -register_task(id_task="Nav-v0", entry_point="habitat.tasks.nav:NavigationTask") - -__all__ = ["task_registry", "register_task", "make_task"] +from habitat.tasks.registration import make_task diff --git a/habitat/tasks/eqa/eqa_task.py b/habitat/tasks/eqa/eqa_task.py index 6898db2a2bf3737a639491ed417c0ce9a23e7a7f..dbb0c1f0d6321fd154c410307f72529ba6c559a0 100644 --- a/habitat/tasks/eqa/eqa_task.py +++ b/habitat/tasks/eqa/eqa_task.py @@ -10,6 +10,7 @@ import attr import numpy as np from gym import spaces +from habitat.core.registry import registry from habitat.core.simulator import ( Observations, Sensor, @@ -112,6 +113,7 @@ class RewardSensor(Sensor): return self._get_observation(**kwargs) +@registry.register_task(name="EQA-v0") class EQATask(NavigationTask): _sensor_suite: SensorSuite diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py index c46cf2e8b52b88fce3ca36f6e1cc069e966619b1..578cccc5b7454bf20f82ca57a1a85edbed4bfbb7 100644 --- a/habitat/tasks/nav/nav_task.py +++ b/habitat/tasks/nav/nav_task.py @@ -11,11 +11,12 @@ import cv2 import numpy as np from gym import spaces -import habitat from habitat.config import Config from habitat.core.dataset import Dataset, Episode -from habitat.core.embodied_task import Measurements +from habitat.core.embodied_task import EmbodiedTask, Measure, Measurements +from habitat.core.registry import registry from habitat.core.simulator import ( + Sensor, SensorSuite, SensorTypes, ShortestPathPoint, @@ -105,7 +106,8 @@ class NavigationEpisode(Episode): shortest_paths: Optional[List[ShortestPathPoint]] = None -class PointGoalSensor(habitat.Sensor): +@registry.register_sensor +class PointGoalSensor(Sensor): """ Sensor for PointGoal observations which are used in the PointNav task. For the agent in simulator the forward direction is along negative-z. @@ -171,7 +173,8 @@ class PointGoalSensor(habitat.Sensor): return direction_vector_agent -class StaticPointGoalSensor(habitat.Sensor): +@registry.register_sensor +class StaticPointGoalSensor(Sensor): """ Sensor for PointGoal observations which are used in the StaticPointNav task. For the agent in simulator the forward direction is along negative-z. @@ -243,7 +246,8 @@ class StaticPointGoalSensor(habitat.Sensor): return self._initial_vector -class HeadingSensor(habitat.Sensor): +@registry.register_sensor +class HeadingSensor(Sensor): """ Sensor for observing the agent's heading in the global coordinate frame. @@ -279,7 +283,8 @@ class HeadingSensor(habitat.Sensor): return np.array(phi) -class ProximitySensor(habitat.Sensor): +@registry.register_sensor +class ProximitySensor(Sensor): """ Sensor for observing the distance to the closest obstacle @@ -317,7 +322,8 @@ class ProximitySensor(habitat.Sensor): ) -class SPL(habitat.Measure): +@registry.register_measure +class SPL(Measure): """SPL (Success weighted by Path Length) ref: On Evaluation of Embodied Agents - Anderson et. al @@ -375,7 +381,8 @@ class SPL(habitat.Measure): ) -class Collisions(habitat.Measure): +@registry.register_measure +class Collisions(Measure): def __init__(self, sim, config): self._sim = sim self._config = config @@ -401,7 +408,8 @@ class Collisions(habitat.Measure): self._metric += 1 -class TopDownMap(habitat.Measure): +@registry.register_measure +class TopDownMap(Measure): """Top Down Map measure """ @@ -547,7 +555,8 @@ class TopDownMap(habitat.Measure): return self._top_down_map, a_x, a_y -class NavigationTask(habitat.EmbodiedTask): +@registry.register_task(name="Nav-v0") +class NavigationTask(EmbodiedTask): def __init__( self, task_config: Config, @@ -558,35 +567,21 @@ class NavigationTask(habitat.EmbodiedTask): task_measurements = [] for measurement_name in task_config.MEASUREMENTS: measurement_cfg = getattr(task_config, measurement_name) - is_valid_measurement = hasattr( - habitat.tasks.nav.nav_task, # type: ignore - measurement_cfg.TYPE, - ) - assert is_valid_measurement, "invalid measurement type {}".format( - measurement_cfg.TYPE - ) - task_measurements.append( - getattr( - habitat.tasks.nav.nav_task, # type: ignore - measurement_cfg.TYPE, - )(sim, measurement_cfg) - ) + measure_type = registry.get_measure(measurement_cfg.TYPE) + assert ( + measure_type is not None + ), "invalid measurement type {}".format(measurement_cfg.TYPE) + task_measurements.append(measure_type(sim, measurement_cfg)) self.measurements = Measurements(task_measurements) task_sensors = [] for sensor_name in task_config.SENSORS: sensor_cfg = getattr(task_config, sensor_name) - is_valid_sensor = hasattr( - habitat.tasks.nav.nav_task, sensor_cfg.TYPE # type: ignore - ) - assert is_valid_sensor, "invalid sensor type {}".format( + sensor_type = registry.get_sensor(sensor_cfg.TYPE) + assert sensor_type is not None, "invalid sensor type {}".format( sensor_cfg.TYPE ) - task_sensors.append( - getattr( - habitat.tasks.nav.nav_task, sensor_cfg.TYPE # type: ignore - )(sim, sensor_cfg) - ) + task_sensors.append(sensor_type(sim, sensor_cfg)) self.sensor_suite = SensorSuite(task_sensors) super().__init__(config=task_config, sim=sim, dataset=dataset) diff --git a/habitat/tasks/registration.py b/habitat/tasks/registration.py index 0c7543fb5f499b537387f9419ef7353494723e44..90400ce079b75f17140be01f88c9e25764e5904b 100644 --- a/habitat/tasks/registration.py +++ b/habitat/tasks/registration.py @@ -5,43 +5,16 @@ # LICENSE file in the root directory of this source tree. from habitat.core.logging import logger -from habitat.core.registry import Registry, Spec - - -class TaskSpec(Spec): - def __init__(self, id_task, entry_point): - super().__init__(id_task, entry_point) - - -class TaskRegistry(Registry): - """Registry for maintaining tasks. - - Args: - id_task: id for task being registered. - kwargs: arguments to be passed to task constructor. - """ - - def register(self, id_task, **kwargs): - if id_task in self.specs: - raise ValueError( - "Cannot re-register task specification with id: {}".format( - id_task - ) - ) - self.specs[id_task] = TaskSpec(id_task, **kwargs) - - -task_registry = TaskRegistry() - - -def register_task(id_task, **kwargs): - task_registry.register(id_task, **kwargs) +from habitat.core.registry import registry +from habitat.tasks.eqa.eqa_task import EQATask +from habitat.tasks.nav.nav_task import NavigationTask def make_task(id_task, **kwargs): logger.info("initializing task {}".format(id_task)) - return task_registry.make(id_task, **kwargs) - + _task = registry.get_task(id_task) + assert _task is not None, "Could not find task with name {}".format( + id_task + ) -def get_spec_task(id_task): - return task_registry.get_spec(id_task) + return _task(**kwargs) diff --git a/habitat/utils/visualizations/__init__.py b/habitat/utils/visualizations/__init__.py index 74b7b453c0daa5c0b3bd61440c51f7de298ea539..c27ce957cafcc35cf003ddf618a3467f2b30a5d8 100644 --- a/habitat/utils/visualizations/__init__.py +++ b/habitat/utils/visualizations/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat.utils.visualizations import maps -from habitat.utils.visualizations import utils +from habitat.utils.visualizations import maps, utils __all__ = ["maps", "utils"] diff --git a/habitat_baselines/rl/ppo/__init__.py b/habitat_baselines/rl/ppo/__init__.py index fad59c373afdd59cb57b179c9cdcc91e6e331be0..128cc360c41ea8fc110d8f7c7e96e07345fc7afc 100644 --- a/habitat_baselines/rl/ppo/__init__.py +++ b/habitat_baselines/rl/ppo/__init__.py @@ -4,8 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat_baselines.rl.ppo.ppo import PPO from habitat_baselines.rl.ppo.policy import Policy +from habitat_baselines.rl.ppo.ppo import PPO from habitat_baselines.rl.ppo.utils import RolloutStorage __all__ = ["PPO", "Policy", "RolloutStorage"] diff --git a/habitat_baselines/slambased/path_planners.py b/habitat_baselines/slambased/path_planners.py index 9e15e15747e00d3fdede0ebfe76f2849e956a7d7..f56776df416f1036cd61e565f4d5ec39d8206132 100644 --- a/habitat_baselines/slambased/path_planners.py +++ b/habitat_baselines/slambased/path_planners.py @@ -1,9 +1,9 @@ -import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import matplotlib.pyplot as plt from habitat_baselines.slambased.utils import generate_2dgrid diff --git a/test/test_habitat_example.py b/test/test_habitat_example.py index 1f95c35f419e19ea401abe86eae9e37ee0408da2..d4884c1223d357a6840f9872a40af01c5a7b4350 100644 --- a/test/test_habitat_example.py +++ b/test/test_habitat_example.py @@ -7,7 +7,11 @@ import pytest import habitat -from examples import shortest_path_follower_example, visualization_examples +from examples import ( + register_new_sensors_and_measures, + shortest_path_follower_example, + visualization_examples, +) from examples.example import example from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1 @@ -34,3 +38,12 @@ def test_shortest_path_follower_example(): ): pytest.skip("Please download Habitat test data to data folder.") shortest_path_follower_example.main() + + +def test_register_new_sensors_and_measures(): + if not PointNavDatasetV1.check_config_paths_exist( + config=habitat.get_config().DATASET + ): + pytest.skip("Please download Habitat test data to data folder.") + + register_new_sensors_and_measures.main()