Skip to content
Snippets Groups Projects
Commit eb817df2 authored by Erik Wijmans's avatar Erik Wijmans Committed by Oleksandr
Browse files

Use attrs to define structs (#102)

use attrs for defining things like Episodes and Goals as inheritance works beautifully and classes get a lot of helpful methods for free.
parent 6a694539
No related branches found
No related tags found
No related merge requests found
...@@ -4,13 +4,17 @@ ...@@ -4,13 +4,17 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import attr
import copy import copy
import json import json
from typing import Dict, List, Type, TypeVar, Generic, Optional, Callable from typing import Dict, List, Type, TypeVar, Generic, Optional, Callable
from habitat.core.utils import not_none_validator
import numpy as np import numpy as np
@attr.s(auto_attribs=True, kw_only=True)
class Episode: class Episode:
"""Base class for episode specification that includes initial position and """Base class for episode specification that includes initial position and
rotation of agent, scene id, episode. This information is provided by rotation of agent, scene id, episode. This information is provided by
...@@ -28,29 +32,16 @@ class Episode: ...@@ -28,29 +32,16 @@ class Episode:
axes. axes.
""" """
episode_id: str episode_id: str = attr.ib(default=None, validator=not_none_validator)
scene_id: str scene_id: str = attr.ib(default=None, validator=not_none_validator)
start_position: List[float] start_position: List[float] = attr.ib(
start_rotation: List[float] default=None, validator=not_none_validator
)
start_rotation: List[float] = attr.ib(
default=None, validator=not_none_validator
)
info: Optional[Dict[str, str]] = None info: Optional[Dict[str, str]] = None
def __init__(
self,
episode_id: str,
scene_id: str,
start_position: List[float],
start_rotation: List[float],
info: Optional[Dict[str, str]] = None,
) -> None:
self.episode_id = episode_id
self.scene_id = scene_id
self.start_position = start_position
self.start_rotation = start_rotation
self.info = info
def __str__(self):
return str(self.__dict__)
T = TypeVar("T", Episode, Type[Episode]) T = TypeVar("T", Episode, Type[Episode])
......
...@@ -40,3 +40,8 @@ def tile_images(images: List[np.ndarray]) -> np.ndarray: ...@@ -40,3 +40,8 @@ def tile_images(images: List[np.ndarray]) -> np.ndarray:
new_height * height, new_width * width, n_channels new_height * height, new_width * width, n_channels
) )
return out_image return out_image
def not_none_validator(self, attribute, value):
if value is None:
raise ValueError(f"Argument '{attribute.name}' must be set")
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
from typing import Dict, Optional from typing import Dict, Optional
import attr
import numpy as np import numpy as np
from gym import spaces from gym import spaces
from habitat.core.simulator import ( from habitat.core.simulator import (
...@@ -14,25 +16,18 @@ from habitat.core.simulator import ( ...@@ -14,25 +16,18 @@ from habitat.core.simulator import (
SensorSuite, SensorSuite,
Observations, Observations,
) )
from habitat.core.utils import not_none_validator
from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationTask from habitat.tasks.nav.nav_task import NavigationEpisode, NavigationTask
@attr.s(auto_attribs=True)
class QuestionData: class QuestionData:
question_text: str question_text: str
answer_text: Optional[str] answer_text: str
question_type: Optional[str] question_type: Optional[str] = None
def __init__(
self,
question_text: str,
question_type: str,
answer_text: Optional[str] = None,
) -> None:
self.question_text = question_text
self.answer_text = answer_text
self.question_type = question_type
@attr.s(auto_attribs=True, kw_only=True)
class EQAEpisode(NavigationEpisode): class EQAEpisode(NavigationEpisode):
"""Specification of episode that includes initial position and rotation of """Specification of episode that includes initial position and rotation of
agent, goal, question specifications and optional shortest paths. agent, goal, question specifications and optional shortest paths.
...@@ -47,11 +42,9 @@ class EQAEpisode(NavigationEpisode): ...@@ -47,11 +42,9 @@ class EQAEpisode(NavigationEpisode):
question: question related to goal object. question: question related to goal object.
""" """
question: QuestionData question: QuestionData = attr.ib(
default=None, validator=not_none_validator
def __init__(self, question: QuestionData, **kwargs) -> None: )
super().__init__(**kwargs)
self.question = question
class QuestionSensor(Sensor): class QuestionSensor(Sensor):
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
from typing import Any, List, Optional, Type from typing import Any, List, Optional, Type
import cv2 import cv2
import attr
import numpy as np import numpy as np
from gym import spaces from gym import spaces
...@@ -20,6 +21,7 @@ from habitat.core.simulator import ( ...@@ -20,6 +21,7 @@ from habitat.core.simulator import (
SensorTypes, SensorTypes,
SensorSuite, SensorSuite,
) )
from habitat.core.utils import not_none_validator
from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector from habitat.tasks.utils import cartesian_to_polar, quaternion_rotate_vector
from habitat.utils.visualizations import maps from habitat.utils.visualizations import maps
...@@ -47,63 +49,38 @@ def merge_sim_episode_config( ...@@ -47,63 +49,38 @@ def merge_sim_episode_config(
return sim_config return sim_config
@attr.s(auto_attribs=True, kw_only=True)
class NavigationGoal: class NavigationGoal:
"""Base class for a goal specification hierarchy. """Base class for a goal specification hierarchy.
""" """
position: List[float] position: List[float] = attr.ib(default=None, validator=not_none_validator)
radius: Optional[float] radius: Optional[float] = None
def __init__(
self, position: List[float], radius: Optional[float] = None, **kwargs
) -> None:
self.position = position
self.radius = radius
@attr.s(auto_attribs=True, kw_only=True)
class ObjectGoal(NavigationGoal): class ObjectGoal(NavigationGoal):
"""Object goal that can be specified by object_id or position or object """Object goal that can be specified by object_id or position or object
category. category.
""" """
object_id: str object_id: str = attr.ib(default=None, validator=not_none_validator)
object_name: Optional[str] object_name: Optional[str] = None
object_category: Optional[str] object_category: Optional[str] = None
room_id: Optional[str] room_id: Optional[str] = None
room_name: Optional[str] room_name: Optional[str] = None
def __init__(
self,
object_id: str,
room_id: Optional[str] = None,
object_name: Optional[str] = None,
object_category: Optional[str] = None,
room_name: Optional[str] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
self.object_id = object_id
self.object_name = object_name
self.object_category = object_category
self.room_id = room_id
self.room_name = room_name
@attr.s(auto_attribs=True, kw_only=True)
class RoomGoal(NavigationGoal): class RoomGoal(NavigationGoal):
"""Room goal that can be specified by room_id or position with radius. """Room goal that can be specified by room_id or position with radius.
""" """
room_id: str room_id: str = attr.ib(default=None, validator=not_none_validator)
room_name: Optional[str] room_name: Optional[str] = None
def __init__(
self, room_id: str, room_name: Optional[str] = None, **kwargs
) -> None:
super().__init__(**kwargs) # type: ignore
self.room_id = room_id
self.room_name = room_name
@attr.s(auto_attribs=True, kw_only=True)
class NavigationEpisode(Episode): class NavigationEpisode(Episode):
"""Class for episode specification that includes initial position and """Class for episode specification that includes initial position and
rotation of agent, scene name, goal and optional shortest paths. An rotation of agent, scene name, goal and optional shortest paths. An
...@@ -121,21 +98,11 @@ class NavigationEpisode(Episode): ...@@ -121,21 +98,11 @@ class NavigationEpisode(Episode):
shortest_paths: list containing shortest paths to goals shortest_paths: list containing shortest paths to goals
""" """
goals: List[NavigationGoal] goals: List[NavigationGoal] = attr.ib(
start_room: Optional[str] default=None, validator=not_none_validator
shortest_paths: Optional[List[ShortestPathPoint]] )
start_room: Optional[str] = None
def __init__( shortest_paths: Optional[List[ShortestPathPoint]] = None
self,
goals: List[NavigationGoal],
start_room: Optional[str] = None,
shortest_paths: Optional[List[ShortestPathPoint]] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
self.goals = goals
self.shortest_paths = shortest_paths
self.start_room = start_room
class PointGoalSensor(habitat.Sensor): class PointGoalSensor(habitat.Sensor):
......
...@@ -2,6 +2,7 @@ gym==0.10.9 ...@@ -2,6 +2,7 @@ gym==0.10.9
numpy>=1.16.1 numpy>=1.16.1
yacs>=0.1.5 yacs>=0.1.5
numpy-quaternion>=2019.3.18.14.33.20 numpy-quaternion>=2019.3.18.14.33.20
attrs>=19.1.0
opencv-python>=3.3.0 opencv-python>=3.3.0
# visualization optional dependencies # visualization optional dependencies
imageio>=2.2.0 imageio>=2.2.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment