Skip to content
Snippets Groups Projects
Commit 4c25f0b3 authored by Arjun Majumdar's avatar Arjun Majumdar Committed by Oleksandr
Browse files

Fix contains method in space classes (#272)

The contains methods in the three habitat.core.spaces classes (EmptySpace, ActionSpace, and ListSpace) have errors that are described in #271. This pull request fixes those errors, corrects the documentation for the ActionSpace class, and adds a __repr__ to the EmptySpace class.
closes #271
parent 1d374383
No related branches found
No related tags found
No related merge requests found
......@@ -21,8 +21,13 @@ class EmptySpace(Space):
return None
def contains(self, x):
if x is None:
return True
return False
def __repr__(self):
return "EmptySpace()"
class ActionSpace(gym.spaces.Dict):
"""
......@@ -30,14 +35,13 @@ class ActionSpace(gym.spaces.Dict):
.. code:: py
self.observation_space = spaces.ActionSpace(
self.observation_space = spaces.ActionSpace({
"move": spaces.Dict({
"position": spaces.Discrete(2),
"velocity": spaces.Discrete(3)
},
"move_forward": EmptySpace,
)
)
}),
"move_forward": EmptySpace(),
})
"""
def __init__(self, spaces):
......@@ -59,9 +63,11 @@ class ActionSpace(gym.spaces.Dict):
}
def contains(self, x):
if not isinstance(x, dict) and {"action", "action_args"} not in x:
if not isinstance(x, dict) or "action" not in x:
return False
if x["action"] not in self.spaces:
return False
if not self.spaces[x["action"]].contains(x["action_args"]):
if not self.spaces[x["action"]].contains(x.get("action_args", None)):
return False
return True
......@@ -100,7 +106,7 @@ class ListSpace(Space):
if not isinstance(x, Sized):
return False
if self.min_seq_length <= len(x) <= self.max_seq_length:
if not (self.min_seq_length <= len(x) <= self.max_seq_length):
return False
return all([self.space.contains(el) for el in x])
......
......@@ -13,8 +13,8 @@ import habitat
from habitat.utils.test_utils import sample_non_stop_action
CFG_TEST = "configs/test/habitat_all_sensors_test.yaml"
TELEPORT_POSITION = [-3.2890449, 0.15067159, 11.124366]
TELEPORT_ROTATION = [0.92035, 0, -0.39109465, 0]
TELEPORT_POSITION = np.array([-3.2890449, 0.15067159, 11.124366])
TELEPORT_ROTATION = np.array([0.92035, 0, -0.39109465, 0])
def test_task_actions():
......@@ -25,15 +25,15 @@ def test_task_actions():
env = habitat.Env(config=config)
env.reset()
env.step(
action={
"action": "TELEPORT",
"action_args": {
"position": TELEPORT_POSITION,
"rotation": TELEPORT_ROTATION,
},
}
)
action = {
"action": "TELEPORT",
"action_args": {
"position": TELEPORT_POSITION,
"rotation": TELEPORT_ROTATION,
},
}
assert env.action_space.contains(action)
env.step(action)
agent_state = env.sim.get_agent_state()
assert np.allclose(
np.array(TELEPORT_POSITION, dtype=np.float32), agent_state.position
......@@ -56,6 +56,7 @@ def test_task_actions_sampling_for_teleport():
env.reset()
while not env.episode_over:
action = sample_non_stop_action(env.action_space)
assert env.action_space.contains(action)
habitat.logger.info(
f"Action : "
f"{action['action']}, "
......@@ -89,6 +90,7 @@ def test_task_actions_sampling(config_file):
env.reset()
while not env.episode_over:
action = sample_non_stop_action(env.action_space)
assert env.action_space.contains(action)
habitat.logger.info(
f"Action : "
f"{action['action']}, "
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import gym
import pytest
from habitat.core.spaces import ActionSpace, EmptySpace, ListSpace
def test_empty_space():
space = EmptySpace()
assert space.contains(space.sample())
assert space.contains(None)
assert not space.contains(0)
def test_action_space():
space = ActionSpace(
{
"move": gym.spaces.Dict(
{
"position": gym.spaces.Discrete(2),
"velocity": gym.spaces.Discrete(3),
}
),
"move_forward": EmptySpace(),
}
)
assert space.contains(space.sample())
assert space.contains(
{"action": "move", "action_args": {"position": 0, "velocity": 1}}
)
assert space.contains({"action": "move_forward"})
assert not space.contains([0, 1, 2])
assert not space.contains({"zero": None})
assert not space.contains({"action": "bad"})
assert not space.contains({"action": "move"})
assert not space.contains(
{"action": "move", "action_args": {"position": 0}}
)
assert not space.contains(
{"action": "move_forward", "action_args": {"position": 0}}
)
def test_list_space():
space = ListSpace(gym.spaces.Discrete(2), 5, 10)
assert space.contains(space.sample())
assert not space.contains(0)
assert not space.contains([0] * 4)
assert not space.contains([2] * 5)
assert not space.contains([1] * 11)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment