From 371a9f8bf64ca766a72be3ce75c29f346f61176e Mon Sep 17 00:00:00 2001
From: Daniel Gordon <xkcd@cs.washington.edu>
Date: Mon, 8 Apr 2019 10:45:13 -0700
Subject: [PATCH] Added Greedy Shortest Path Follower (#51)

---
 .gitignore                                  |   3 +-
 baselines/__init__.py                       |   0
 examples/shortest_path_follower_example.py  |  58 ++++++
 examples/visualization_examples.py          |  19 +-
 habitat/core/simulator.py                   |  41 ++++
 habitat/sims/habitat_simulator.py           |  23 ++-
 habitat/tasks/nav/nav_task.py               |  14 +-
 habitat/tasks/nav/shortest_path_follower.py | 202 ++++++++++++++++++++
 habitat/utils/__init__.py                   |   3 +-
 habitat/utils/geometry_utils.py             |  44 +++++
 habitat/utils/visualizations/__init__.py    |   3 +-
 habitat/utils/visualizations/maps.py        |  20 +-
 requirements.txt                            |   1 +
 test/test_habitat_example.py                |   9 +
 14 files changed, 415 insertions(+), 25 deletions(-)
 create mode 100644 baselines/__init__.py
 create mode 100644 examples/shortest_path_follower_example.py
 create mode 100644 habitat/tasks/nav/shortest_path_follower.py
 create mode 100644 habitat/utils/geometry_utils.py

diff --git a/.gitignore b/.gitignore
index 2dbb8ed72..e2a71935b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,6 +42,7 @@ htmlcov/
 nosetests.xml
 coverage.xml
 *,cover
+examples/images
 
 # Translations
 *.mo
@@ -76,7 +77,7 @@ target/
 .ipynb_checkpoints/
 
 # exclude data from source control by default
-/data/
+data
 
 # Mac OS-specific storage files
 .DS_Store
diff --git a/baselines/__init__.py b/baselines/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/shortest_path_follower_example.py b/examples/shortest_path_follower_example.py
new file mode 100644
index 000000000..706179073
--- /dev/null
+++ b/examples/shortest_path_follower_example.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import shutil
+
+import imageio
+
+import habitat
+from habitat.sims.habitat_simulator import SIM_NAME_TO_ACTION
+from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower
+
+IMAGE_DIR = os.path.join("examples", "images")
+if not os.path.exists(IMAGE_DIR):
+    os.makedirs(IMAGE_DIR)
+
+
+def shortest_path_example(mode):
+    config = habitat.get_config(config_file="tasks/pointnav.yaml")
+    env = habitat.Env(config=config)
+    goal_radius = env.episodes[0].goals[0].radius
+    if goal_radius is None:
+        goal_radius = config.SIMULATOR.FORWARD_STEP_SIZE
+    follower = ShortestPathFollower(env.sim, goal_radius, False)
+    follower.mode = mode
+
+    print("Environment creation successful")
+    for episode in range(3):
+        observations = env.reset()
+        dirname = os.path.join(
+            IMAGE_DIR, "shortest_path_example", mode, "%02d" % episode
+        )
+        if os.path.exists(dirname):
+            shutil.rmtree(dirname)
+        os.makedirs(dirname)
+        print("Agent stepping around inside environment.")
+        count_steps = 0
+        while not env.episode_over:
+            best_action = follower.get_next_action(
+                env.current_episode.goals[0].position
+            )
+            observations = env.step(SIM_NAME_TO_ACTION[best_action.value])
+            count_steps += 1
+            im = observations["rgb"]
+            imageio.imsave(os.path.join(dirname, "%03d.jpg" % count_steps), im)
+        print("Episode finished after {} steps.".format(count_steps))
+
+
+def main():
+    shortest_path_example("geodesic_path")
+    shortest_path_example("greedy")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/visualization_examples.py b/examples/visualization_examples.py
index ab5e4d281..e5be1bd29 100644
--- a/examples/visualization_examples.py
+++ b/examples/visualization_examples.py
@@ -5,13 +5,19 @@
 # LICENSE file in the root directory of this source tree.
 
 
-import numpy as np
+import os
+
 import imageio
+import numpy as np
 
 import habitat
 from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationGoal
 from habitat.utils.visualizations import maps
 
+IMAGE_DIR = os.path.join("examples", "images")
+if not os.path.exists(IMAGE_DIR):
+    os.makedirs(IMAGE_DIR)
+
 
 def example_pointnav_draw_target_birdseye_view():
     goal_radius = 0.5
@@ -34,7 +40,9 @@ def example_pointnav_draw_target_birdseye_view():
         agent_radius_px=25,
     )
 
-    imageio.imsave("pointnav_target_image.png", target_image)
+    imageio.imsave(
+        os.path.join(IMAGE_DIR, "pointnav_target_image.png"), target_image
+    )
 
 
 def example_pointnav_draw_target_birdseye_view_agent_on_border():
@@ -64,7 +72,10 @@ def example_pointnav_draw_target_birdseye_view_agent_on_border():
                 agent_radius_px=25,
             )
             imageio.imsave(
-                "pointnav_target_image_edge_%d.png" % ii, target_image
+                os.path.join(
+                    IMAGE_DIR, "pointnav_target_image_edge_%d.png" % ii
+                ),
+                target_image,
             )
 
 
@@ -94,7 +105,7 @@ def example_get_topdown_map():
         range_x[0] : range_x[1], range_y[0] : range_y[1]
     ]
     top_down_map = recolor_map[top_down_map]
-    imageio.imsave("top_down_map.png", top_down_map)
+    imageio.imsave(os.path.join(IMAGE_DIR, "top_down_map.png"), top_down_map)
 
 
 def main():
diff --git a/habitat/core/simulator.py b/habitat/core/simulator.py
index b2acb18f9..fddd6bfdd 100644
--- a/habitat/core/simulator.py
+++ b/habitat/core/simulator.py
@@ -285,6 +285,14 @@ class Simulator:
         """
         raise NotImplementedError
 
+    def is_navigable(self, point: List[float]) -> bool:
+        """Return true if the agent can stand at the specified point.
+
+        Args:
+            point: The point to check.
+        """
+        raise NotImplementedError
+
     def action_space_shortest_path(
         self, source: AgentState, targets: List[AgentState], agent_id: int = 0
     ) -> List[ShortestPathPoint]:
@@ -301,6 +309,39 @@ class Simulator:
         """
         raise NotImplementedError
 
+    def get_straight_shortest_path_points(
+        self, position_a: List[float], position_b: List[float]
+    ) -> List[List[float]]:
+        """Returns points along the geodesic (shortest) path between two points
+         irrespective of the angles between the waypoints.
+
+         Args:
+            position_a: The start point. This will be the first point in the
+                returned list.
+            position_b: The end point. This will be the last point in the
+                returned list.
+        Returns:
+            A list of waypoints (x, y, z) on the geodesic path between the two
+            points.
+         """
+
+        raise NotImplementedError
+
+    @property
+    def up_vector(self):
+        """The vector representing the direction upward (perpendicular to the
+        floor) from the global coordinate frame.
+        """
+        raise NotImplementedError
+
+    @property
+    def forward_vector(self):
+        """The forward direction in the global coordinate frame i.e. the
+        direction of forward movement for an agent with 0 degrees rotation in
+        the ground plane.
+        """
+        raise NotImplementedError
+
     def render(self, mode: str = "rgb") -> Any:
         raise NotImplementedError
 
diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py
index 97385c5de..188ed2e45 100644
--- a/habitat/sims/habitat_simulator.py
+++ b/habitat/sims/habitat_simulator.py
@@ -357,9 +357,27 @@ class HabitatSim(habitat.Simulator):
 
         return shortest_path
 
+    @property
+    def up_vector(self):
+        return np.array([0.0, 1.0, 0.0])
+
+    @property
+    def forward_vector(self):
+        return -np.array([0.0, 0.0, 1.0])
+
+    def get_straight_shortest_path_points(self, position_a, position_b):
+        path = habitat_sim.ShortestPath()
+        path.requested_start = position_a
+        path.requested_end = position_b
+        self._sim.pathfinder.find_path(path)
+        return path.points
+
     def sample_navigable_point(self):
         return self._sim.pathfinder.get_random_navigable_point().tolist()
 
+    def is_navigable(self, point: List[float]):
+        return self._sim.pathfinder.is_navigable(point)
+
     def semantic_annotations(self):
         """
         Returns:
@@ -423,6 +441,7 @@ class HabitatSim(habitat.Simulator):
         position: List[float] = None,
         rotation: List[float] = None,
         agent_id: int = 0,
+        reset_sensors: bool = True,
     ) -> None:
         """Sets agent state similar to initialize_agent, but without agents
         creation.
@@ -433,12 +452,14 @@ class HabitatSim(habitat.Simulator):
             of unit quaternion (versor) representing agent 3D orientation,
             (https://en.wikipedia.org/wiki/Versor)
             agent_id: int identification of agent from multiagent setup.
+            reset_sensors: bool for if sensor changes (e.g. tilt) should be
+                reset).
         """
         agent = self._sim.get_agent(agent_id)
         state = self.get_agent_state(agent_id)
         state.position = position
         state.rotation = rotation
-        agent.set_state(state)
+        agent.set_state(state, reset_sensors)
 
         self._check_agent_position(position, agent_id)
 
diff --git a/habitat/tasks/nav/nav_task.py b/habitat/tasks/nav/nav_task.py
index 3f703a7df..abd33840e 100644
--- a/habitat/tasks/nav/nav_task.py
+++ b/habitat/tasks/nav/nav_task.py
@@ -377,16 +377,16 @@ class NavigationTask(habitat.EmbodiedTask):
         for measurement_name in task_config.MEASUREMENTS:
             measurement_cfg = getattr(task_config, measurement_name)
             is_valid_measurement = hasattr(
-                habitat.tasks.nav.nav_task,
-                measurement_cfg.TYPE,  # type: ignore
+                habitat.tasks.nav.nav_task,  # type: ignore
+                measurement_cfg.TYPE,
             )
             assert is_valid_measurement, "invalid measurement type {}".format(
                 measurement_cfg.TYPE
             )
             task_measurements.append(
                 getattr(
-                    habitat.tasks.nav.nav_task,
-                    measurement_cfg.TYPE,  # type: ignore
+                    habitat.tasks.nav.nav_task,  # type: ignore
+                    measurement_cfg.TYPE,
                 )(sim, measurement_cfg)
             )
         self.measurements = Measurements(task_measurements)
@@ -402,10 +402,8 @@ class NavigationTask(habitat.EmbodiedTask):
             )
             task_sensors.append(
                 getattr(
-                    habitat.tasks.nav.nav_task, sensor_cfg.TYPE
-                )(  # type: ignore
-                    sim, sensor_cfg
-                )
+                    habitat.tasks.nav.nav_task, sensor_cfg.TYPE  # type: ignore
+                )(sim, sensor_cfg)
             )
 
         self.sensor_suite = SensorSuite(task_sensors)
diff --git a/habitat/tasks/nav/shortest_path_follower.py b/habitat/tasks/nav/shortest_path_follower.py
new file mode 100644
index 000000000..b1b3b217d
--- /dev/null
+++ b/habitat/tasks/nav/shortest_path_follower.py
@@ -0,0 +1,202 @@
+from typing import Union
+
+import habitat_sim
+import numpy as np
+
+from habitat.sims.habitat_simulator import HabitatSim
+from habitat.sims.habitat_simulator import SIM_NAME_TO_ACTION, SimulatorActions
+from habitat.utils.geometry_utils import (
+    angle_between_quaternions,
+    quaternion_from_two_vectors,
+    quaternion_xyzw_to_wxyz,
+)
+
+
+EPSILON = 1e-6
+
+
+def action_to_one_hot(action: int) -> np.array:
+    one_hot = np.zeros(len(SIM_NAME_TO_ACTION), dtype=np.float32)
+    one_hot[action] = 1
+    return one_hot
+
+
+class ShortestPathFollower:
+    """Utility class for extracting the action on the shortest path to the
+        goal.
+    Args:
+        sim: HabitatSim instance.
+        goal_radius: Distance between the agent and the goal for it to be
+            considered successful.
+        return_one_hot: If true, returns a one-hot encoding of the action
+            (useful for training ML agents). If false, returns the
+            SimulatorAction.
+    """
+
+    def __init__(
+        self, sim: HabitatSim, goal_radius: float, return_one_hot: bool = True
+    ):
+        assert (
+            getattr(sim, "geodesic_distance", None) is not None
+        ), "{} must have a method called geodesic_distance".format(
+            type(sim).__name__
+        )
+
+        self._sim = sim
+        self._max_delta = self._sim.config.FORWARD_STEP_SIZE - EPSILON
+        self._goal_radius = goal_radius
+        self._step_size = self._sim.config.FORWARD_STEP_SIZE
+
+        self._mode = (
+            "geodesic_path"
+            if getattr(sim, "get_straight_shortest_path_points", None)
+            is not None
+            else "greedy"
+        )
+        self._return_one_hot = return_one_hot
+
+    def _get_return_value(
+        self, action: SimulatorActions
+    ) -> Union[SimulatorActions, np.array]:
+        if self._return_one_hot:
+            return action_to_one_hot(SIM_NAME_TO_ACTION[action.value])
+        else:
+            return action
+
+    def get_next_action(
+        self, goal_pos: np.array
+    ) -> Union[SimulatorActions, np.array]:
+        """Returns the next action along the shortest path."""
+        if (
+            np.linalg.norm(goal_pos - self._sim.get_agent_state().position)
+            <= self._goal_radius
+        ):
+            return self._get_return_value(SimulatorActions.STOP)
+
+        max_grad_dir = self._est_max_grad_dir(goal_pos)
+        if max_grad_dir is None:
+            return self._get_return_value(SimulatorActions.FORWARD)
+        return self._step_along_grad(max_grad_dir)
+
+    def _step_along_grad(
+        self, grad_dir: np.quaternion
+    ) -> Union[SimulatorActions, np.array]:
+        current_state = self._sim.get_agent_state()
+        alpha = angle_between_quaternions(
+            grad_dir, quaternion_xyzw_to_wxyz(current_state.rotation)
+        )
+        if alpha <= np.deg2rad(self._sim.config.TURN_ANGLE) + EPSILON:
+            return self._get_return_value(SimulatorActions.FORWARD)
+        else:
+            sim_action = SIM_NAME_TO_ACTION[SimulatorActions.LEFT.value]
+            self._sim.step(sim_action)
+            best_turn = (
+                SimulatorActions.LEFT
+                if (
+                    angle_between_quaternions(
+                        grad_dir,
+                        quaternion_xyzw_to_wxyz(
+                            self._sim.get_agent_state().rotation
+                        ),
+                    )
+                    < alpha
+                )
+                else SimulatorActions.RIGHT
+            )
+            self._reset_agent_state(current_state)
+            return self._get_return_value(best_turn)
+
+    def _reset_agent_state(self, state: habitat_sim.AgentState) -> None:
+        self._sim.set_agent_state(
+            state.position, state.rotation, reset_sensors=False
+        )
+
+    def _geo_dist(self, goal_pos: np.array) -> float:
+        return self._sim.geodesic_distance(
+            self._sim.get_agent_state().position, goal_pos
+        )
+
+    def _est_max_grad_dir(self, goal_pos: np.array) -> np.array:
+
+        current_state = self._sim.get_agent_state()
+        current_pos = current_state.position
+
+        if self.mode == "geodesic_path":
+            points = self._sim.get_straight_shortest_path_points(
+                self._sim.get_agent_state().position, goal_pos
+            )
+            # Add a little offset as things get weird if
+            # points[1] - points[0] is anti-parallel with forward
+            if len(points) < 2:
+                return None
+            max_grad_dir = quaternion_from_two_vectors(
+                self._sim.forward_vector,
+                points[1]
+                - points[0]
+                + EPSILON
+                * np.cross(self._sim.up_vector, self._sim.forward_vector),
+            )
+            max_grad_dir.x = 0
+            max_grad_dir = np.normalized(max_grad_dir)
+        else:
+            current_rotation = self._sim.get_agent_state().rotation
+            current_dist = self._geo_dist(goal_pos)
+
+            best_geodesic_delta = -2 * self._max_delta
+            best_rotation = current_rotation
+            for _ in range(0, 360, self._sim.config.TURN_ANGLE):
+                sim_action = SIM_NAME_TO_ACTION[SimulatorActions.FORWARD.value]
+                self._sim.step(sim_action)
+                new_delta = current_dist - self._geo_dist(goal_pos)
+
+                if new_delta > best_geodesic_delta:
+                    best_rotation = self._sim.get_agent_state().rotation
+                    best_geodesic_delta = new_delta
+
+                # If the best delta is within (1 - cos(TURN_ANGLE))% of the
+                # best delta (the step size), then we almost certainly have
+                # found the max grad dir and should just exit
+                if np.isclose(
+                    best_geodesic_delta,
+                    self._max_delta,
+                    rtol=1 - np.cos(np.deg2rad(self._sim.config.TURN_ANGLE)),
+                ):
+                    break
+
+                self._sim.set_agent_state(
+                    current_pos,
+                    self._sim.get_agent_state().rotation,
+                    reset_sensors=False,
+                )
+
+                sim_action = SIM_NAME_TO_ACTION[SimulatorActions.LEFT.value]
+                self._sim.step(sim_action)
+
+            self._reset_agent_state(current_state)
+
+            max_grad_dir = quaternion_xyzw_to_wxyz(best_rotation)
+
+        return max_grad_dir
+
+    @property
+    def mode(self):
+        return self._mode
+
+    @mode.setter
+    def mode(self, new_mode: str):
+        """Sets the mode for how the greedy follower determines the best next
+            step.
+        Args:
+            new_mode: geodesic_path indicates using the simulator's shortest
+                path algorithm to find points on the map to navigate between.
+                greedy indicates trying to move forward at all possible
+                orientations and selecting the one which reduces the geodesic
+                distance the most.
+        """
+        assert new_mode in {"geodesic_path", "greedy"}
+        if new_mode == "geodesic_path":
+            assert (
+                getattr(self._sim, "get_straight_shortest_path_points", None)
+                is not None
+            )
+        self._mode = new_mode
diff --git a/habitat/utils/__init__.py b/habitat/utils/__init__.py
index d1f5f3fbc..49c8ba90f 100644
--- a/habitat/utils/__init__.py
+++ b/habitat/utils/__init__.py
@@ -3,5 +3,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from habitat.utils import geometry_utils
 
-__all__ = ["visualizations"]
+__all__ = ["visualizations", "geometry_utils"]
diff --git a/habitat/utils/geometry_utils.py b/habitat/utils/geometry_utils.py
new file mode 100644
index 000000000..e290f1d14
--- /dev/null
+++ b/habitat/utils/geometry_utils.py
@@ -0,0 +1,44 @@
+import numpy as np
+import quaternion
+
+
+EPSILON = 1e-8
+
+
+def angle_between_quaternions(q1: np.quaternion, q2: np.quaternion) -> float:
+    """Returns the angle (in radians) between two quaternions. This angle will
+    always be positive.
+    """
+    q1_inv = np.conjugate(q1)
+    dq = quaternion.as_float_array(q1_inv * q2)
+
+    return 2 * np.arctan2(np.linalg.norm(dq[1:]), np.abs(dq[0]))
+
+
+def quaternion_from_two_vectors(v0: np.array, v1: np.array) -> np.quaternion:
+    """Computes the quaternion representation of v1 using v0 as the origin."""
+    v0 = v0 / np.linalg.norm(v0)
+    v1 = v1 / np.linalg.norm(v1)
+    c = v0.dot(v1)
+    # Epsilon prevents issues at poles.
+    if c < (-1 + EPSILON):
+        c = max(c, -1)
+        m = np.stack([v0, v1], 0)
+        _, _, vh = np.linalg.svd(m, full_matrices=True)
+        axis = vh.T[:, 2]
+        w2 = (1 + c) * 0.5
+        w = np.sqrt(w2)
+        axis = axis * np.sqrt(1 - w2)
+        return np.quaternion(w, *axis)
+
+    axis = np.cross(v0, v1)
+    s = np.sqrt((1 + c) * 2)
+    return np.quaternion(s * 0.5, *(axis / s))
+
+
+def quaternion_xyzw_to_wxyz(v: np.array):
+    return np.quaternion(v[3], *v[0:3])
+
+
+def quaternion_wxyz_to_xyzw(v: np.array):
+    return np.quaternion(*v[1:4], v[0])
diff --git a/habitat/utils/visualizations/__init__.py b/habitat/utils/visualizations/__init__.py
index 5f27292dc..74b7b453c 100644
--- a/habitat/utils/visualizations/__init__.py
+++ b/habitat/utils/visualizations/__init__.py
@@ -5,5 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from habitat.utils.visualizations import maps
+from habitat.utils.visualizations import utils
 
-__all__ = ["maps"]
+__all__ = ["maps", "utils"]
diff --git a/habitat/utils/visualizations/maps.py b/habitat/utils/visualizations/maps.py
index b8f226df9..1790f06f8 100644
--- a/habitat/utils/visualizations/maps.py
+++ b/habitat/utils/visualizations/maps.py
@@ -47,8 +47,8 @@ def draw_agent(
     rotated_agent = scipy.ndimage.interpolation.rotate(
         AGENT_SPRITE, agent_rotation * -180 / np.pi
     )
-    # Rescale because rotation may result in larger image than original, but the
-    # agent sprite size should stay the same.
+    # Rescale because rotation may result in larger image than original, but
+    # the agent sprite size should stay the same.
     initial_agent_size = AGENT_SPRITE.shape[0]
     new_size = rotated_agent.shape[0]
     agent_size_px = max(
@@ -186,10 +186,10 @@ def _from_grid(
     coordinate_max: float,
     grid_resolution: Tuple[int, int],
 ) -> Tuple[float, float]:
-    """Inverse of _to_grid function. Return real world coordinate from gridworld
-    assuming top-left corner is the origin. The real world coordinates of lower
-    left corner are (coordinate_min, coordinate_min) and of top right corner
-    are (coordinate_max, coordinate_max)
+    """Inverse of _to_grid function. Return real world coordinate from
+    gridworld assuming top-left corner is the origin. The real world
+    coordinates of lower left corner are (coordinate_min, coordinate_min) and
+    of top right corner are (coordinate_max, coordinate_max)
     """
     grid_size = (
         (coordinate_max - coordinate_min) / grid_resolution[0],
@@ -233,8 +233,10 @@ def get_topdown_map(
 
     Args:
         sim: The simulator.
-        map_resolution: The resolution of map which will be computed and returned.
-        num_samples: The number of random navigable points which will be initially
+        map_resolution: The resolution of map which will be computed and
+            returned.
+        num_samples: The number of random navigable points which will be
+            initially
             sampled. For large environments it may need to be increased.
         draw_border: Whether to outline the border of the occupied spaces.
 
@@ -280,7 +282,7 @@ def get_topdown_map(
             realworld_x, realworld_y = _from_grid(
                 ii, jj, COORDINATE_MIN, COORDINATE_MAX, map_resolution
             )
-            valid_point = sim._sim.pathfinder.is_navigable(
+            valid_point = sim.is_navigable(
                 [realworld_x, start_height, realworld_y]
             )
             top_down_map[ii, jj] = 1 if valid_point else 0
diff --git a/requirements.txt b/requirements.txt
index 0232814bc..82f4441ee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 gym==0.10.9
 numpy>=1.16.1
 yacs>=0.1.5
+numpy-quaternion>=2019.3.18.14.33.20
 # visualization optional dependencies
 imageio>=2.2.0
 opencv-python>=3.3.0
diff --git a/test/test_habitat_example.py b/test/test_habitat_example.py
index 227014d60..39131be24 100644
--- a/test/test_habitat_example.py
+++ b/test/test_habitat_example.py
@@ -10,6 +10,7 @@ import habitat
 from examples.example import example
 from habitat.datasets.pointnav.pointnav_dataset import PointNavDatasetV1
 from examples import visualization_examples
+from examples import shortest_path_follower_example
 
 
 def test_readme_example():
@@ -26,3 +27,11 @@ def test_visualizations_example():
     ):
         pytest.skip("Please download Habitat test data to data folder.")
     visualization_examples.main()
+
+
+def test_shortest_path_follower_example():
+    if not PointNavDatasetV1.check_config_paths_exist(
+        config=habitat.get_config().DATASET
+    ):
+        pytest.skip("Please download Habitat test data to data folder.")
+    shortest_path_follower_example.main()
-- 
GitLab