From 172746769dcff6053c012da2cb5f955efaaeb36e Mon Sep 17 00:00:00 2001
From: Oleksandr Maksymets <maksymets@gmail.com>
Date: Sun, 26 Jan 2020 17:01:28 -0800
Subject: [PATCH] Added ObjectNav task definition, dataset, metrics, goal
 sensor

---
 configs/tasks/obj_nav_mp3d.yaml               |  47 ++++++
 .../test/habitat_mp3d_object_nav_test.yaml    |  48 ++++++
 habitat/config/default.py                     |  12 +-
 habitat/core/utils.py                         |  82 ++++++++++
 habitat/datasets/eqa/mp3d_eqa_dataset.py      |   3 +-
 habitat/datasets/object_nav/__init__.py       |  38 +++++
 .../datasets/object_nav/object_nav_dataset.py |  97 ++++++++++++
 habitat/datasets/registration.py              |   2 +
 habitat/tasks/nav/nav.py                      | 126 +++++++++++----
 habitat/tasks/nav/object_nav_task.py          | 149 ++++++++++++++++++
 habitat/tasks/utils.py                        |   6 +
 test/test_object_nav_task.py                  |  84 ++++++++++
 12 files changed, 661 insertions(+), 33 deletions(-)
 create mode 100644 configs/tasks/obj_nav_mp3d.yaml
 create mode 100644 configs/test/habitat_mp3d_object_nav_test.yaml
 create mode 100644 habitat/datasets/object_nav/__init__.py
 create mode 100644 habitat/datasets/object_nav/object_nav_dataset.py
 create mode 100644 habitat/tasks/nav/object_nav_task.py
 create mode 100644 test/test_object_nav_task.py

diff --git a/configs/tasks/obj_nav_mp3d.yaml b/configs/tasks/obj_nav_mp3d.yaml
new file mode 100644
index 000000000..8fdcb4aa7
--- /dev/null
+++ b/configs/tasks/obj_nav_mp3d.yaml
@@ -0,0 +1,47 @@
+ENVIRONMENT:
+  MAX_EPISODE_STEPS: 750
+SIMULATOR:
+  TURN_ANGLE: 30
+  TILT_ANGLE: 30
+  AGENT_0:
+    SENSORS: ['RGB_SENSOR', 'DEPTH_SENSOR']
+    HEIGHT: 0.88
+    RADIUS: 0.18
+  HABITAT_SIM_V0:
+    GPU_DEVICE_ID: 0
+  SEMANTIC_SENSOR:
+    WIDTH: 640
+    HEIGHT: 480
+    HFOV: 59
+    POSITION: [0, 0.88, 0]
+  RGB_SENSOR:
+    WIDTH: 640
+    HEIGHT: 480
+    HFOV: 59
+    POSITION: [0, 0.88, 0]
+  DEPTH_SENSOR:
+    WIDTH: 640
+    HEIGHT: 480
+    HFOV: 59
+    MIN_DEPTH: 0.5
+    MAX_DEPTH: 3.86
+    POSITION: [0, 0.88, 0]
+TASK:
+  TYPE: ObjectNav-v0
+  SUCCESS_DISTANCE: 0.1
+
+  SENSORS: ['OBJECTGOAL_SENSOR']
+  GOAL_SENSOR_UUID: objectgoal
+
+  MEASUREMENTS: ['SPL']
+  SPL:
+    TYPE: SPL
+    DISTANCE_TO: VIEW_POINTS
+    SUCCESS_DISTANCE: 0.2
+
+DATASET:
+  TYPE: ObjectNav-v1
+  SPLIT: val
+  CONTENT_SCENES: []
+  DATA_PATH: "data/datasets/objectnav/mp3d/v1/{split}/{split}.json.gz"
+  SCENES_DIR: "data/scene_datasets/"
diff --git a/configs/test/habitat_mp3d_object_nav_test.yaml b/configs/test/habitat_mp3d_object_nav_test.yaml
new file mode 100644
index 000000000..80d254ed4
--- /dev/null
+++ b/configs/test/habitat_mp3d_object_nav_test.yaml
@@ -0,0 +1,48 @@
+ENVIRONMENT:
+  MAX_EPISODE_STEPS: 750
+SIMULATOR:
+  TURN_ANGLE: 30
+  TILT_ANGLE: 30
+  AGENT_0:
+    SENSORS: ['RGB_SENSOR', 'DEPTH_SENSOR']
+    HEIGHT: 0.88
+    RADIUS: 0.18
+  HABITAT_SIM_V0:
+    GPU_DEVICE_ID: 0
+  SEMANTIC_SENSOR:
+    WIDTH: 640
+    HEIGHT: 480
+    HFOV: 59
+    POSITION: [0, 0.88, 0]
+  RGB_SENSOR:
+    WIDTH: 640
+    HEIGHT: 480
+    HFOV: 59
+    POSITION: [0, 0.88, 0]
+  DEPTH_SENSOR:
+    WIDTH: 640
+    HEIGHT: 480
+    HFOV: 59
+    MIN_DEPTH: 0.5
+    MAX_DEPTH: 3.86
+    POSITION: [0, 0.88, 0]
+
+TASK:
+  TYPE: ObjectNav-v0
+  SUCCESS_DISTANCE: 0.1
+
+  SENSORS: ['OBJECTGOAL_SENSOR']
+  GOAL_SENSOR_UUID: objectgoal
+
+  MEASUREMENTS: ['SPL']
+  SPL:
+    TYPE: SPL
+    DISTANCE_TO: VIEW_POINTS
+    SUCCESS_DISTANCE: 0.2
+
+DATASET:
+  TYPE: ObjectNav-v1
+  SPLIT: mini_val
+  CONTENT_SCENES: []
+  DATA_PATH: "data/datasets/objectnav/mp3d/v1/{split}/{split}.json.gz"
+  SCENES_DIR: "data/scene_datasets/"
diff --git a/habitat/config/default.py b/habitat/config/default.py
index 8a8136855..c2de0fdf6 100644
--- a/habitat/config/default.py
+++ b/habitat/config/default.py
@@ -94,6 +94,13 @@ _C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.TYPE = (
     "PointGoalWithGPSCompassSensor"
 )
 # -----------------------------------------------------------------------------
+# OBJECTGOAL SENSOR
+# -----------------------------------------------------------------------------
+_C.TASK.OBJECTGOAL_SENSOR = CN()
+_C.TASK.OBJECTGOAL_SENSOR.TYPE = "ObjectGoalSensor"
+_C.TASK.OBJECTGOAL_SENSOR.GOAL_SPEC = "TASK_CATEGORY_ID"
+_C.TASK.OBJECTGOAL_SENSOR.GOAL_SPEC_MAX_VAL = 50
+# -----------------------------------------------------------------------------
 # HEADING SENSOR
 # -----------------------------------------------------------------------------
 _C.TASK.HEADING_SENSOR = CN()
@@ -121,6 +128,7 @@ _C.TASK.PROXIMITY_SENSOR.MAX_DETECTION_RADIUS = 2.0
 _C.TASK.SPL = CN()
 _C.TASK.SPL.TYPE = "SPL"
 _C.TASK.SPL.SUCCESS_DISTANCE = 0.2
+_C.TASK.SPL.DISTANCE_TO = "POINT"
 # -----------------------------------------------------------------------------
 # TopDownMap MEASUREMENT
 # -----------------------------------------------------------------------------
@@ -210,8 +218,8 @@ _C.SIMULATOR.RGB_SENSOR.TYPE = "HabitatSimRGBSensor"
 # -----------------------------------------------------------------------------
 _C.SIMULATOR.DEPTH_SENSOR = SIMULATOR_SENSOR.clone()
 _C.SIMULATOR.DEPTH_SENSOR.TYPE = "HabitatSimDepthSensor"
-_C.SIMULATOR.DEPTH_SENSOR.MIN_DEPTH = 0
-_C.SIMULATOR.DEPTH_SENSOR.MAX_DEPTH = 10
+_C.SIMULATOR.DEPTH_SENSOR.MIN_DEPTH = 0.0
+_C.SIMULATOR.DEPTH_SENSOR.MAX_DEPTH = 10.0
 _C.SIMULATOR.DEPTH_SENSOR.NORMALIZE_DEPTH = True
 # -----------------------------------------------------------------------------
 # SEMANTIC SENSOR
diff --git a/habitat/core/utils.py b/habitat/core/utils.py
index 7b4bb2994..e5b7f15bf 100644
--- a/habitat/core/utils.py
+++ b/habitat/core/utils.py
@@ -4,9 +4,22 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 from typing import List
 
 import numpy as np
+import quaternion
+
+from habitat.utils.geometry_utils import quaternion_to_list
+
+try:
+    from _json import encode_basestring_ascii
+except ImportError:
+    encode_basestring_ascii = None
+try:
+    from _json import encode_basestring
+except ImportError:
+    encode_basestring = None
 
 
 def tile_images(images: List[np.ndarray]) -> np.ndarray:
@@ -92,3 +105,72 @@ def center_crop(obs, new_shape):
     obs = obs[top_left[0] : bottom_right[0], top_left[1] : bottom_right[1], :]
 
     return obs
+
+
+class DatasetFloatJSONEncoder(json.JSONEncoder):
+    """JSON encoder that writes floats with five decimals to save space."""
+
+    # Version of JSON library that encoder is compatible with.
+    __version__ = "2.0.9"
+
+    def default(self, object):
+        """Turn arrays/quaternions into lists, other objects into dicts."""
+        if isinstance(object, np.ndarray):
+            return object.tolist()
+        if isinstance(object, quaternion.quaternion):
+            return quaternion_to_list(object)
+        return object.__dict__
+
+    def iterencode(self, o, _one_shot=False):
+        """Encode ``o`` like the base class, but format floats with ``.5f``.
+
+        Overridden only to inject a custom float ``_repr``; relies on the
+        private ``json.encoder._make_iterencode`` helper.
+        """
+        markers = {} if self.check_circular else None
+        # Fall back to the pure-Python escapers when the C ones are missing.
+        if self.ensure_ascii:
+            _encoder = (
+                encode_basestring_ascii
+                or json.encoder.py_encode_basestring_ascii
+            )
+        else:
+            _encoder = encode_basestring or json.encoder.py_encode_basestring
+
+        def floatstr(
+            o,
+            allow_nan=self.allow_nan,
+            _repr=lambda x: format(x, ".5f"),
+            _inf=float("inf"),
+            _neginf=-float("inf"),
+        ):
+            if o != o:
+                text = "NaN"
+            elif o == _inf:
+                text = "Infinity"
+            elif o == _neginf:
+                text = "-Infinity"
+            else:
+                return _repr(o)
+
+            if not allow_nan:
+                raise ValueError(
+                    "Out of range float values are not JSON compliant: "
+                    + repr(o)
+                )
+
+            return text
+
+        _iterencode = json.encoder._make_iterencode(
+            markers,
+            self.default,
+            _encoder,
+            self.indent,
+            floatstr,
+            self.key_separator,
+            self.item_separator,
+            self.sort_keys,
+            self.skipkeys,
+            _one_shot,
+        )
+        return _iterencode(o, 0)
diff --git a/habitat/datasets/eqa/mp3d_eqa_dataset.py b/habitat/datasets/eqa/mp3d_eqa_dataset.py
index 610aefe5f..bf6b92021 100644
--- a/habitat/datasets/eqa/mp3d_eqa_dataset.py
+++ b/habitat/datasets/eqa/mp3d_eqa_dataset.py
@@ -15,7 +15,8 @@ from habitat.core.registry import registry
 from habitat.core.simulator import AgentState
 from habitat.datasets.utils import VocabDict, VocabFromText
 from habitat.tasks.eqa.eqa import EQAEpisode, QuestionData
-from habitat.tasks.nav.nav import ObjectGoal, ShortestPathPoint
+from habitat.tasks.nav.nav import ShortestPathPoint
+from habitat.tasks.nav.object_nav_task import ObjectGoal
 
 EQA_MP3D_V1_VAL_EPISODE_COUNT = 1950
 DEFAULT_SCENE_PATH_PREFIX = "data/scene_datasets/"
diff --git a/habitat/datasets/object_nav/__init__.py b/habitat/datasets/object_nav/__init__.py
new file mode 100644
index 000000000..c37139061
--- /dev/null
+++ b/habitat/datasets/object_nav/__init__.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from habitat.core.dataset import Dataset
+from habitat.core.registry import registry
+
+
+# TODO(akadian): This is a result of moving SimulatorActions away from core
+# and into simulators specifically. As a result of that the connection points
+# for our tasks and datasets for actions is coming from inside habitat-sim
+# which makes it impossible for anyone to use habitat-api without having
+# habitat-sim installed. In a future PR we will implement a base simulator
+# action class which will be the connection point for tasks and datasets.
+# Post that PR we would no longer need try register blocks.
+def _try_register_objectnavdatasetv1():
+    r"""Make sure the ``ObjectNav-v1`` dataset name is registered.
+
+    If the real dataset cannot be imported, a stub that re-raises the
+    ``ImportError`` on construction is registered under the same name.
+    """
+    import_failure = None
+    try:
+        from habitat.datasets.object_nav.object_nav_dataset import (
+            ObjectNavDatasetV1,
+        )
+    except ImportError as err:
+        import_failure = err
+
+    if import_failure is None:
+        return
+
+    @registry.register_dataset(name="ObjectNav-v1")
+    class ObjectNavDatasetImportError(Dataset):
+        def __init__(self, *args, **kwargs):
+            raise import_failure
diff --git a/habitat/datasets/object_nav/object_nav_dataset.py b/habitat/datasets/object_nav/object_nav_dataset.py
new file mode 100644
index 000000000..bfdc9e765
--- /dev/null
+++ b/habitat/datasets/object_nav/object_nav_dataset.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from typing import Dict, List, Optional
+
+from habitat.core.registry import registry
+from habitat.core.simulator import AgentState, ShortestPathPoint
+from habitat.core.utils import DatasetFloatJSONEncoder
+from habitat.datasets.pointnav.pointnav_dataset import (
+    CONTENT_SCENES_PATH_FIELD,
+    DEFAULT_SCENE_PATH_PREFIX,
+    PointNavDatasetV1,
+)
+from habitat.tasks.nav.nav import NavigationEpisode
+from habitat.tasks.nav.object_nav_task import ObjectGoal, ObjectViewLocation
+
+
+@registry.register_dataset(name="ObjectNav-v1")
+class ObjectNavDatasetV1(PointNavDatasetV1):
+    r"""Class inherited from PointNavDataset that loads Object Navigation dataset.
+    """
+    category_to_task_category_id: Dict[str, int]
+    category_to_mp3d_category_id: Dict[str, int]
+    episodes: List[NavigationEpisode]
+    content_scenes_path: str = "{data_path}/content/{scene}.json.gz"
+
+    def to_json(self) -> str:
+        r"""Serialize the dataset to JSON with reduced float precision."""
+        return DatasetFloatJSONEncoder().encode(self)
+
+    def from_json(
+        self, json_str: str, scenes_dir: Optional[str] = None
+    ) -> None:
+        r"""Append the episodes encoded in ``json_str`` to this dataset.
+
+        Args:
+            json_str: serialized dataset; must contain an ``episodes`` list.
+            scenes_dir: if given, every ``scene_id`` is re-rooted under it.
+        """
+        deserialized = json.loads(json_str)
+        if CONTENT_SCENES_PATH_FIELD in deserialized:
+            self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD]
+
+        for field in (
+            "category_to_task_category_id",
+            "category_to_mp3d_category_id",
+        ):
+            if field in deserialized:
+                setattr(self, field, deserialized[field])
+
+        # Equal key sets also imply equal sizes, so one check is enough.
+        assert set(self.category_to_task_category_id) == set(
+            self.category_to_mp3d_category_id
+        ), "category_to_task and category_to_mp3d must have the same keys"
+
+        for episode in deserialized["episodes"]:
+            episode = NavigationEpisode(**episode)
+
+            if scenes_dir is not None:
+                if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX):
+                    episode.scene_id = episode.scene_id[
+                        len(DEFAULT_SCENE_PATH_PREFIX) :
+                    ]
+
+                episode.scene_id = os.path.join(scenes_dir, episode.scene_id)
+
+            for g_index, goal in enumerate(episode.goals):
+                goal.pop("best_iou", None)  # remove before release
+                goal = ObjectGoal(**goal)
+                episode.goals[g_index] = goal
+                # view_points is optional; treat a missing list as empty.
+                for vidx, view in enumerate(goal.view_points or []):
+                    view_location = ObjectViewLocation(**view)
+                    view_location.agent_state = AgentState(
+                        **view_location.agent_state
+                    )
+                    goal.view_points[vidx] = view_location
+
+            if episode.shortest_paths is not None:
+                for path in episode.shortest_paths:
+                    for p_index, point in enumerate(path):
+                        point = {
+                            "action": point,
+                            "rotation": None,
+                            "position": None,
+                        }
+                        path[p_index] = ShortestPathPoint(**point)
+
+            self.episodes.append(episode)
+
+        for i, ep in enumerate(self.episodes):
+            ep.episode_id = str(i)
diff --git a/habitat/datasets/registration.py b/habitat/datasets/registration.py
index a9f3add46..27e509e95 100644
--- a/habitat/datasets/registration.py
+++ b/habitat/datasets/registration.py
@@ -7,6 +7,7 @@
 from habitat.core.logging import logger
 from habitat.core.registry import registry
 from habitat.datasets.eqa import _try_register_mp3d_eqa_dataset
+from habitat.datasets.object_nav import _try_register_objectnavdatasetv1
 from habitat.datasets.pointnav import _try_register_pointnavdatasetv1
 from habitat.datasets.vln import _try_register_r2r_vln_dataset
 
@@ -19,6 +20,7 @@ def make_dataset(id_dataset, **kwargs):
     return _dataset(**kwargs)
 
 
+_try_register_objectnavdatasetv1()
 _try_register_mp3d_eqa_dataset()
 _try_register_pointnavdatasetv1()
 _try_register_r2r_vln_dataset()
diff --git a/habitat/tasks/nav/nav.py b/habitat/tasks/nav/nav.py
index c1bbf8dd6..31a829f94 100644
--- a/habitat/tasks/nav/nav.py
+++ b/habitat/tasks/nav/nav.py
@@ -10,6 +10,7 @@ import attr
 import numpy as np
 from gym import spaces
 
+import habitat_sim
 from habitat.config import Config
 from habitat.core.dataset import Dataset, Episode
 from habitat.core.embodied_task import (
@@ -70,20 +71,6 @@ class NavigationGoal:
     radius: Optional[float] = None
 
 
-@attr.s(auto_attribs=True, kw_only=True)
-class ObjectGoal(NavigationGoal):
-    r"""Object goal that can be specified by object_id or position or object
-    category.
-    """
-
-    object_id: str = attr.ib(default=None, validator=not_none_validator)
-    object_name: Optional[str] = None
-    object_category: Optional[str] = None
-    room_id: Optional[str] = None
-    room_name: Optional[str] = None
-    view_points: Optional[List[AgentState]] = None
-
-
 @attr.s(auto_attribs=True, kw_only=True)
 class RoomGoal(NavigationGoal):
     r"""Room goal that can be specified by room_id or position with radius.
@@ -429,6 +416,7 @@ class SPL(Measure):
         self._previous_position = None
         self._start_end_episode_distance = None
         self._agent_episode_distance = None
+        self._episode_view_points = None
         self._sim = sim
         self._config = config
 
@@ -441,7 +429,14 @@ class SPL(Measure):
         self._previous_position = self._sim.get_agent_state().position.tolist()
         self._start_end_episode_distance = episode.info["geodesic_distance"]
         self._agent_episode_distance = 0.0
+        self._previous_distance_to_target = 0.0  # remove after debug
         self._metric = None
+        if self._config.DISTANCE_TO == "VIEW_POINTS":
+            self._episode_view_points = [
+                view_point.agent_state.position
+                for goal in episode.goals
+                for view_point in goal.view_points
+            ]
 
     def _euclidean_distance(self, position_a, position_b):
         return np.linalg.norm(
@@ -454,9 +449,27 @@ class SPL(Measure):
         ep_success = 0
         current_position = self._sim.get_agent_state().position.tolist()
 
-        distance_to_target = self._sim.geodesic_distance(
-            current_position, episode.goals[0].position
-        )
+        if self._config.DISTANCE_TO == "POINT":
+            distance_to_target = self._sim.geodesic_distance(
+                current_position, episode.goals[0].position
+            )
+        elif self._config.DISTANCE_TO == "VIEW_POINTS":
+            multi_goal = habitat_sim.MultiGoalShortestPath()
+            multi_goal.requested_start = current_position
+            multi_goal.requested_ends = self._episode_view_points
+            self._sim._sim.pathfinder.find_path(multi_goal)
+            distance_to_target = multi_goal.geodesic_distance
+
+        # Debug-only sanity check (remove after debug): compare against the
+        # distance recorded on the previous call *before* overwriting it,
+        # otherwise the difference is always zero and the check can never
+        # fire.
+        if (
+            self._previous_distance_to_target - distance_to_target
+            > self._sim.config.FORWARD_STEP_SIZE
+        ):
+            print("!!! distance_to_target change more than forward step")
+        self._previous_distance_to_target = distance_to_target
 
         if (
             hasattr(task, "is_stop_called")
@@ -577,18 +590,71 @@ class TopDownMap(Measure):
             s_y - point_padding : s_y + point_padding + 1,
         ] = maps.MAP_SOURCE_POINT_INDICATOR
 
-        # mark target point
-        t_x, t_y = maps.to_grid(
-            episode.goals[0].position[0],
-            episode.goals[0].position[2],
-            self._coordinate_min,
-            self._coordinate_max,
-            self._map_resolution,
-        )
-        self._top_down_map[
-            t_x - point_padding : t_x + point_padding + 1,
-            t_y - point_padding : t_y + point_padding + 1,
-        ] = maps.MAP_TARGET_POINT_INDICATOR
+        for goal in episode.goals:
+            if goal.view_points is not None:
+                for view_point in goal.view_points:
+                    # mark view point
+                    t_x, t_y = maps.to_grid(
+                        view_point.agent_state.position[0],
+                        view_point.agent_state.position[2],
+                        self._coordinate_min,
+                        self._coordinate_max,
+                        self._map_resolution,
+                    )
+
+                    self._top_down_map[
+                        t_x - point_padding : t_x + point_padding + 1,
+                        t_y - point_padding : t_y + point_padding + 1,
+                    ] = maps.MAP_VIEW_POINT_INDICATOR
+
+        for goal in episode.goals:
+            # mark target point
+            t_x, t_y = maps.to_grid(
+                goal.position[0],
+                goal.position[2],
+                self._coordinate_min,
+                self._coordinate_max,
+                self._map_resolution,
+            )
+
+            self._top_down_map[
+                t_x - point_padding : t_x + point_padding + 1,
+                t_y - point_padding : t_y + point_padding + 1,
+            ] = maps.MAP_TARGET_POINT_INDICATOR
+
+            sem_scene = self._sim.semantic_annotations()
+            object_id = goal.object_id
+            assert int(sem_scene.objects[object_id].id.split("_")[-1]) == int(
+                goal.object_id
+            )
+
+            center = sem_scene.objects[object_id].aabb.center
+            import itertools
+
+            x_len, _, z_len = sem_scene.objects[object_id].aabb.sizes / 2.0
+            corners = [
+                center + np.array([x, 0, z])
+                for x, z in itertools.product([-x_len, x_len], [z_len, -z_len])
+            ]
+            corners[2], corners[3] = corners[3], corners[2]
+            corners.append(corners[0])
+            map_corners = [
+                maps.to_grid(
+                    p[0],
+                    p[2],
+                    self._coordinate_min,
+                    self._coordinate_max,
+                    self._map_resolution,
+                )[::-1]
+                for p in corners
+            ]
+
+            maps.draw_path(
+                self._top_down_map,
+                map_corners,
+                maps.MAP_TARGET_BOUNDING_BOX,
+                self.line_thickness,
+            )
 
     def reset_metric(self, *args: Any, episode, **kwargs: Any):
         self._step_count = 0
@@ -854,7 +920,7 @@ class TeleportAction(SimulatorTaskAction):
         *args: Any,
         position: List[float],
         rotation: List[float],
-        **kwargs: Any
+        **kwargs: Any,
     ):
         r"""Update ``_metric``, this method is called from ``Env`` on each
         ``step``.
diff --git a/habitat/tasks/nav/object_nav_task.py b/habitat/tasks/nav/object_nav_task.py
new file mode 100644
index 000000000..d7adfdb1b
--- /dev/null
+++ b/habitat/tasks/nav/object_nav_task.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Optional
+
+import attr
+import numpy as np
+from gym import spaces
+
+from habitat.config import Config
+from habitat.core.dataset import Dataset
+from habitat.core.logging import logger
+from habitat.core.registry import registry
+from habitat.core.simulator import AgentState, Sensor, SensorTypes
+from habitat.core.utils import not_none_validator
+from habitat.tasks.nav.nav import (
+    NavigationEpisode,
+    NavigationGoal,
+    NavigationTask,
+)
+
+
+@attr.s(auto_attribs=True)
+class ObjectViewLocation:
+    r"""ObjectViewLocation provides information about a position around an
+    object goal. Usually the position is navigable and the object is visible
+    from it with the specific agent configuration that the episode's dataset
+    was created with.
+
+    View points of an object goal are an important part of metrics
+    calculation: they describe the success area for the navigation.
+
+    Args:
+        agent_state: navigable AgentState with a position and a rotation where
+        the object is visible.
+        iou: an intersection over union of the object and a rectangle in the
+        center of view, used to evaluate how good the object view from the
+        current position is. Higher iou means a better view; iou equals 1.0
+        if the whole object is inside the rectangle and no pixel inside the
+        rectangle belongs to anything except the object.
+    """
+    agent_state: AgentState
+    iou: Optional[float]
+
+
+@attr.s(auto_attribs=True, kw_only=True)
+class ObjectGoal(NavigationGoal):
+    r"""Object goal provides information about an object that is the target
+    of navigation. It can specify object_id, position and object category.
+    An important part for metrics calculation are the view points, which
+    describe the success area for the navigation.
+
+    Args:
+        object_id: id that can be used to retrieve object from the semantic
+        scene annotation
+        object_name: name of the object
+        object_category: object category name usually similar to scene semantic
+        categories
+        room_id: id of a room where object is located, can be used to retrieve
+        room from the semantic scene annotation
+        room_name: name of the room, where object is located
+        view_points: navigable positions around the object with specified
+        proximity of the object surface used for navigation metrics calculation.
+        The object is visible from these positions.
+    """
+
+    object_id: str = attr.ib(default=None, validator=not_none_validator)
+    object_name: Optional[str] = None
+    object_category: Optional[str] = None
+    room_id: Optional[str] = None
+    room_name: Optional[str] = None
+    view_points: Optional[List[ObjectViewLocation]] = None
+
+
+@registry.register_sensor
+class ObjectGoalSensor(Sensor):
+    r"""A sensor for Object Goal specification as observations which is used in
+    ObjectGoal Navigation. The goal is expected to be specified by object_id or
+    semantic category id.
+    With GOAL_SPEC ``TASK_CATEGORY_ID`` the observation is the task category
+    id that the dataset maps the first goal's ``object_category`` to.
+    Args:
+        sim: a reference to the simulator for calculating task observations.
+        config: a config for the ObjectGoalSensor sensor. Can contain field
+            GOAL_SPEC that specifies which id use for goal specification,
+            GOAL_SPEC_MAX_VAL the maximum object_id possible used for
+            observation space definition.
+        dataset: an Object Goal navigation dataset that contains dictionaries
+        mapping category names to category ids.
+    """
+
+    def __init__(
+        self, sim, config: Config, dataset: Dataset, *args: Any, **kwargs: Any
+    ):
+        self._sim = sim
+        self._dataset = dataset
+        super().__init__(config=config)
+
+    def _get_uuid(self, *args: Any, **kwargs: Any):
+        return "objectgoal"
+
+    def _get_sensor_type(self, *args: Any, **kwargs: Any):
+        return SensorTypes.SEMANTIC
+
+    def _get_observation_space(self, *args: Any, **kwargs: Any):
+        sensor_shape = (1,)
+        # Scalar bound (not a 1-tuple): Box expands it over ``sensor_shape``.
+        max_value = self.config.GOAL_SPEC_MAX_VAL - 1
+        if self.config.GOAL_SPEC == "TASK_CATEGORY_ID":
+            max_value = len(self._dataset.category_to_task_category_id)
+        return spaces.Box(
+            low=0, high=max_value, shape=sensor_shape, dtype=np.int64
+        )
+
+    def get_observation(
+        self,
+        observations,
+        *args: Any,
+        episode: NavigationEpisode,
+        **kwargs: Any,
+    ) -> Optional[int]:
+        if self.config.GOAL_SPEC == "TASK_CATEGORY_ID":
+            if len(episode.goals) == 0:
+                logger.error(
+                    f"No goal specified for episode {episode.episode_id}."
+                )
+                return None
+            if not isinstance(episode.goals[0], ObjectGoal):
+                logger.error(
+                    f"First goal should be ObjectGoal, episode {episode.episode_id}."
+                )
+                return None
+            category_name = episode.goals[0].object_category
+            return self._dataset.category_to_task_category_id[category_name]
+        elif self.config.GOAL_SPEC == "OBJECT_ID":
+            # NOTE(review): ObjectGoal declares no ``object_name_id`` — confirm.
+            return np.array([episode.goals[0].object_name_id], dtype=np.int64)
+        return None
+
+
+@registry.register_task(name="ObjectNav-v0")
+class ObjectNavigationTask(NavigationTask):
+    r"""An Object Navigation Task class for task specific methods.
+        Used to explicitly state the type of the task in config.
+    """
+    pass
diff --git a/habitat/tasks/utils.py b/habitat/tasks/utils.py
index c32d4b4f2..f468377d5 100644
--- a/habitat/tasks/utils.py
+++ b/habitat/tasks/utils.py
@@ -64,3 +64,9 @@ def cartesian_to_polar(x, y):
     rho = np.sqrt(x ** 2 + y ** 2)
     phi = np.arctan2(y, x)
     return rho, phi
+
+
+def compute_pixel_coverage(instance_seg, object_id):
+    r"""Return the fraction of ``instance_seg`` entries equal to ``object_id``."""
+    matches = instance_seg == object_id
+    return matches.sum().astype(np.float64) / matches.size
diff --git a/test/test_object_nav_task.py b/test/test_object_nav_task.py
new file mode 100644
index 000000000..59629fd15
--- /dev/null
+++ b/test/test_object_nav_task.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import numpy as np
+import pytest
+
+import habitat
+from habitat.config.default import get_config
+from habitat.core.embodied_task import Episode
+from habitat.core.logging import logger
+from habitat.datasets import make_dataset
+from habitat.datasets.object_nav.object_nav_dataset import ObjectNavDatasetV1
+from habitat.tasks.eqa.eqa import AnswerAction
+from habitat.tasks.nav.nav import MoveForwardAction
+from habitat.utils.test_utils import sample_non_stop_action
+
+CFG_TEST = "configs/test/habitat_mp3d_object_nav_test.yaml"
+EPISODES_LIMIT = 6
+
+
+def check_json_serializaiton(dataset: habitat.Dataset):
+    r"""Assert that ``dataset`` survives a to_json/from_json round trip."""
+    started_at = time.time()
+    encoded = str(dataset.to_json())
+    elapsed = time.time() - started_at
+    logger.info("JSON conversion finished. {} sec".format(elapsed))
+    restored = dataset.__class__()
+    restored.from_json(encoded)
+    assert len(restored.episodes) > 0
+    first_episode = restored.episodes[0]
+    assert isinstance(first_episode, Episode)
+    assert (
+        restored.to_json() == encoded
+    ), "JSON dataset encoding/decoding isn't consistent"
+
+
+def test_mp3d_object_nav_dataset():
+    r"""Load the configured dataset and check its JSON round trip."""
+    ds_cfg = get_config(CFG_TEST).DATASET
+    if not ObjectNavDatasetV1.check_config_paths_exist(ds_cfg):
+        pytest.skip(
+            "Please download Matterport3D ObjectNav Dataset to data folder."
+        )
+
+    # The dataset class is selected by the TYPE named in the config.
+    dataset = habitat.make_dataset(id_dataset=ds_cfg.TYPE, config=ds_cfg)
+    assert dataset
+    check_json_serializaiton(dataset)
+
+
+def test_object_nav_task():
+    config = get_config(CFG_TEST)
+
+    if not ObjectNavDatasetV1.check_config_paths_exist(config.DATASET):
+        pytest.skip(
+            "Please download Matterport3D scene and ObjectNav Datasets to data folder."
+        )
+
+    dataset = make_dataset(
+        id_dataset=config.DATASET.TYPE, config=config.DATASET
+    )
+    env = habitat.Env(config=config, dataset=dataset)
+    try:
+        for _ in range(10):
+            env.reset()
+            while not env.episode_over:
+                action = env.action_space.sample()
+                habitat.logger.info(
+                    f"Action : {action['action']}, "
+                    f"args: {action['action_args']}."
+                )
+                env.step(action)
+            logger.info(env.get_metrics())
+
+        # Stepping after the episode is over must be rejected.
+        with pytest.raises(AssertionError):
+            env.step({"action": MoveForwardAction.name})
+    finally:
+        env.close()
-- 
GitLab