From 028fa12648bb1e6b2d86015fd258808257dd14de Mon Sep 17 00:00:00 2001
From: danielgordon10 <danielgordon10@gmail.com>
Date: Wed, 5 Jun 2019 19:06:47 -0700
Subject: [PATCH] Adds get_observations_at functionality (#80)

Utility method for getting the current observation. Sometimes you want the observation without having to make a step/reset call for whatever reason.
---
 habitat/core/simulator.py         | 24 +++++++++
 habitat/sims/habitat_simulator.py | 67 +++++++++++++++++++------
 test/test_sensors.py              | 81 ++++++++++++++++++++++++++++---
 3 files changed, 150 insertions(+), 22 deletions(-)

diff --git a/habitat/core/simulator.py b/habitat/core/simulator.py
index 69320bcc3..15f8c6d09 100644
--- a/habitat/core/simulator.py
+++ b/habitat/core/simulator.py
@@ -277,6 +277,30 @@ class Simulator:
         """
         raise NotImplementedError
 
+    def get_observations_at(
+        self,
+        position: List[float],
+        rotation: List[float],
+        keep_agent_at_new_pose: bool = False,
+    ) -> Optional[Observations]:
+        """Returns the observation.
+
+        Args:
+            position: list containing 3 entries for (x, y, z).
+            rotation: list with 4 entries for (x, y, z, w) elements of unit
+                quaternion (versor) representing agent 3D orientation,
+                (https://en.wikipedia.org/wiki/Versor)
+            keep_agent_at_new_pose: If true, the agent will stay at the
+                requested location. Otherwise it will return to where it
+                started.
+
+        Returns:
+            The observations or None if it was unable to get valid
+            observations.
+
+        """
+        raise NotImplementedError
+
     def sample_navigable_point(self) -> List[float]:
         """Samples a navigable point from the simulator. A point is defined as
         navigable if the agent can be initialized at that point.
diff --git a/habitat/sims/habitat_simulator.py b/habitat/sims/habitat_simulator.py
index 8050ca023..402e02f91 100644
--- a/habitat/sims/habitat_simulator.py
+++ b/habitat/sims/habitat_simulator.py
@@ -17,6 +17,7 @@ from habitat.core.logging import logger
 from habitat.core.simulator import (
     AgentState,
     DepthSensor,
+    Observations,
     RGBSensor,
     SemanticSensor,
     ShortestPathPoint,
@@ -407,7 +408,7 @@ class HabitatSim(habitat.Simulator):
         agent_config = getattr(self.config, agent_name)
         return agent_config
 
-    def get_agent_state(self, agent_id: int = 0):
+    def get_agent_state(self, agent_id: int = 0) -> habitat_sim.AgentState:
         assert agent_id == 0, "No support of multi agent in {} yet.".format(
             self.__class__.__name__
         )
@@ -415,43 +416,77 @@ class HabitatSim(habitat.Simulator):
 
     def set_agent_state(
         self,
-        position: List[float] = None,
-        rotation: List[float] = None,
+        position: List[float],
+        rotation: List[float],
         agent_id: int = 0,
         reset_sensors: bool = True,
-    ) -> None:
+    ) -> bool:
         """Sets agent state similar to initialize_agent, but without agents
-        creation.
+        creation. On failure to place the agent in the proper position, it is
+        moved back to its previous pose.
 
         Args:
-            position: numpy ndarray containing 3 entries for (x, y, z).
-            rotation: numpy ndarray with 4 entries for (x, y, z, w) elements
-            of unit quaternion (versor) representing agent 3D orientation,
-            (https://en.wikipedia.org/wiki/Versor)
+            position: list containing 3 entries for (x, y, z).
+            rotation: list with 4 entries for (x, y, z, w) elements of unit
+                quaternion (versor) representing agent 3D orientation,
+                (https://en.wikipedia.org/wiki/Versor)
             agent_id: int identification of agent from multiagent setup.
             reset_sensors: bool for if sensor changes (e.g. tilt) should be
                 reset).
+
+        Returns:
+            True if the set was successful else moves the agent back to its
+            original pose and returns false.
         """
         agent = self._sim.get_agent(agent_id)
-        state = self.get_agent_state(agent_id)
-        state.position = position
-        state.rotation = rotation
+        original_state = self.get_agent_state(agent_id)
+        new_state = self.get_agent_state(agent_id)
+        new_state.position = position
+        new_state.rotation = rotation
 
         # NB: The agent state also contains the sensor states in _absolute_
         # coordinates. In order to set the agent's body to a specific
         # location and have the sensors follow, we must not provide any
         # state for the sensors. This will cause them to follow the agent's
         # body
-        state.sensor_states = dict()
+        new_state.sensor_states = dict()
+
+        agent.set_state(new_state, reset_sensors)
+
+        if not self._check_agent_position(position, agent_id):
+            agent.set_state(original_state, reset_sensors)
+            return False
+        return True
+
+    def get_observations_at(
+        self,
+        position: List[float],
+        rotation: List[float],
+        keep_agent_at_new_pose: bool = False,
+    ) -> Optional[Observations]:
 
-        agent.set_state(state, reset_sensors)
+        current_state = self.get_agent_state()
 
-        self._check_agent_position(position, agent_id)
+        success = self.set_agent_state(position, rotation, reset_sensors=False)
+        if success:
+            sim_obs = self._sim.get_sensor_observations()
+            observations = self._sensor_suite.get_observations(sim_obs)
+            if not keep_agent_at_new_pose:
+                self.set_agent_state(
+                    current_state.position,
+                    current_state.rotation,
+                    reset_sensors=False,
+                )
+            return observations
+        else:
+            return None
 
     # TODO (maksymets): Remove check after simulator became stable
-    def _check_agent_position(self, position, agent_id=0):
+    def _check_agent_position(self, position, agent_id=0) -> bool:
         if not np.allclose(position, self.get_agent_state(agent_id).position):
             logger.info("Agent state diverges from configured start position.")
+            return False
+        return True
 
     def distance_to_closest_obstacle(self, position, max_search_radius=2.0):
         return self._sim.pathfinder.distance_to_closest_obstacle(
diff --git a/test/test_sensors.py b/test/test_sensors.py
index 771704089..86dc34d13 100644
--- a/test/test_sensors.py
+++ b/test/test_sensors.py
@@ -11,6 +11,8 @@ import numpy as np
 import pytest
 
 import habitat
+import numpy as np
+import pytest
 from habitat.config.default import get_config
 from habitat.sims.habitat_simulator import SimulatorActions
 from habitat.tasks.nav.nav_task import (
@@ -23,6 +25,12 @@ NON_STOP_ACTIONS = [
     v for v in range(len(SimulatorActions)) if v != SimulatorActions.STOP.value
 ]
 
+MOVEMENT_ACTIONS = [
+    SimulatorActions.MOVE_FORWARD.value,
+    SimulatorActions.TURN_LEFT.value,
+    SimulatorActions.TURN_RIGHT.value,
+]
+
 
 def _random_episode(env, config):
     random_location = env._sim.sample_navigable_point()
@@ -187,11 +195,72 @@ def test_static_pointgoal_sensor():
         )
     ]
 
-    obs = env.reset()
-    for _ in range(5):
-        env.step(np.random.choice(NON_STOP_ACTIONS))
-    static_pointgoal = obs["static_pointgoal"]
+    env.reset()
+    for _ in range(100):
+        obs = env.step(np.random.choice(NON_STOP_ACTIONS))
+        static_pointgoal = obs["static_pointgoal"]
+        # check to see if taking non-stop actions will affect static point_goal
+        assert np.allclose(static_pointgoal, expected_static_pointgoal)
+
+    env.close()
+
+
+def test_get_observations_at():
+    config = get_config()
+    if not os.path.exists(config.SIMULATOR.SCENE):
+        pytest.skip("Please download Habitat test data to data folder.")
+    config.defrost()
+    config.TASK.SENSORS = []
+    config.SIMULATOR.AGENT_0.SENSORS = ["RGB_SENSOR", "DEPTH_SENSOR"]
+    config.freeze()
+    env = habitat.Env(config=config, dataset=None)
+
+    # start position is checked for validity for the specific test scene
+    valid_start_position = [-1.3731, 0.08431, 8.60692]
+    expected_static_pointgoal = [0.1, 0.2, 0.3]
+    goal_position = np.add(valid_start_position, expected_static_pointgoal)
+
+    # starting quaternion is rotated 180 degree along z-axis, which
+    # corresponds to simulator using z-negative as forward action
+    start_rotation = [0, 0, 0, 1]
+
+    env.episodes = [
+        NavigationEpisode(
+            episode_id="0",
+            scene_id=config.SIMULATOR.SCENE,
+            start_position=valid_start_position,
+            start_rotation=start_rotation,
+            goals=[NavigationGoal(position=goal_position)],
+        )
+    ]
 
-    # check to see if taking non-stop actions will affect static point_goal
-    assert np.allclose(static_pointgoal, expected_static_pointgoal)
+    obs = env.reset()
+    start_state = env.sim.get_agent_state()
+    for _ in range(100):
+        # Note, this test will not currently work for camera change actions
+        # (look up/down), only for movement actions.
+        new_obs = env.step(np.random.choice(MOVEMENT_ACTIONS))
+        for key, val in new_obs.items():
+            agent_state = env.sim.get_agent_state()
+            if not (
+                np.allclose(agent_state.position, start_state.position)
+                and np.allclose(agent_state.rotation, start_state.rotation)
+            ):
+                assert not np.allclose(val, obs[key])
+        obs_at_point = env.sim.get_observations_at(
+            start_state.position,
+            start_state.rotation,
+            keep_agent_at_new_pose=False,
+        )
+        for key, val in obs_at_point.items():
+            assert np.allclose(val, obs[key])
+
+    obs_at_point = env.sim.get_observations_at(
+        start_state.position, start_state.rotation, keep_agent_at_new_pose=True
+    )
+    for key, val in obs_at_point.items():
+        assert np.allclose(val, obs[key])
+    agent_state = env.sim.get_agent_state()
+    assert np.allclose(agent_state.position, start_state.position)
+    assert np.allclose(agent_state.rotation, start_state.rotation)
     env.close()
-- 
GitLab