diff --git a/habitat_baselines/agents/ppo_agents.py b/habitat_baselines/agents/ppo_agents.py index 919298ffb6f89ee10e1fa83bd0dcb4b8926a02dc..551be4ff33c718595c98cd08cde31a37bf804ea9 100644 --- a/habitat_baselines/agents/ppo_agents.py +++ b/habitat_baselines/agents/ppo_agents.py @@ -17,7 +17,7 @@ from habitat.config import Config from habitat.config.default import get_config from habitat.core.agent import Agent from habitat_baselines.common.utils import batch_obs -from habitat_baselines.rl.ppo import Policy +from habitat_baselines.rl.ppo import PointNavBaselinePolicy def get_default_config(): @@ -75,7 +75,7 @@ class PPOAgent(Agent): if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True - self.actor_critic = Policy( + self.actor_critic = PointNavBaselinePolicy( observation_space=observation_spaces, action_space=action_spaces, hidden_size=self.hidden_size, @@ -88,7 +88,7 @@ class PPOAgent(Agent): # Filter only actor_critic weights self.actor_critic.load_state_dict( { - k.replace("actor_critic.", ""): v + k[len("actor_critic.") :]: v for k, v in ckpt["state_dict"].items() if "actor_critic" in k } @@ -101,12 +101,19 @@ class PPOAgent(Agent): self.test_recurrent_hidden_states = None self.not_done_masks = None + self.prev_actions = None def reset(self): self.test_recurrent_hidden_states = torch.zeros( - 1, self.hidden_size, device=self.device + self.actor_critic.net.num_recurrent_layers, + 1, + self.hidden_size, + device=self.device, ) self.not_done_masks = torch.zeros(1, 1, device=self.device) + self.prev_actions = torch.zeros( + 1, 1, dtype=torch.long, device=self.device + ) def act(self, observations): batch = batch_obs([observations]) @@ -117,11 +124,13 @@ class PPOAgent(Agent): _, actions, _, self.test_recurrent_hidden_states = self.actor_critic.act( batch, self.test_recurrent_hidden_states, + self.prev_actions, self.not_done_masks, deterministic=False, ) # Make masks not done till reset (end of episode) will be called self.not_done_masks = torch.ones(1, 1, device=self.device) + self.prev_actions.copy_(actions) return actions[0][0].item() diff --git a/habitat_baselines/common/rollout_storage.py b/habitat_baselines/common/rollout_storage.py index 3f9b5dd26269863b6b3ac62efde01f1483aef315..056c1b3cc79040c8d4f244103995ff1e9868ff42 100644 --- a/habitat_baselines/common/rollout_storage.py +++ b/habitat_baselines/common/rollout_storage.py @@ -21,6 +21,7 @@ class RolloutStorage: observation_space, action_space, recurrent_hidden_state_size, + num_recurrent_layers=1, ): self.observations = {} @@ -32,7 +33,10 @@ class RolloutStorage: ) self.recurrent_hidden_states = torch.zeros( - num_steps + 1, num_envs, recurrent_hidden_state_size + num_steps + 1, + num_recurrent_layers, + num_envs, + recurrent_hidden_state_size, ) self.rewards = torch.zeros(num_steps, num_envs, 1) @@ -46,8 +50,10 @@ class RolloutStorage: action_shape = action_space.shape[0] self.actions = torch.zeros(num_steps, num_envs, action_shape) + self.prev_actions = torch.zeros(num_steps + 1, num_envs, action_shape) if action_space.__class__.__name__ == "Discrete": self.actions = self.actions.long() + self.prev_actions = self.prev_actions.long() self.masks = torch.ones(num_steps + 1, num_envs, 1) @@ -64,6 +70,7 @@ class RolloutStorage: self.returns = self.returns.to(device) self.action_log_probs = self.action_log_probs.to(device) self.actions = self.actions.to(device) + self.prev_actions = self.prev_actions.to(device) self.masks = self.masks.to(device) def insert( @@ -84,6 +91,7 @@ class RolloutStorage: 
recurrent_hidden_states ) self.actions[self.step].copy_(actions) + self.prev_actions[self.step + 1].copy_(actions) self.action_log_probs[self.step].copy_(action_log_probs) self.value_preds[self.step].copy_(value_preds) self.rewards[self.step].copy_(rewards) @@ -97,6 +105,7 @@ class RolloutStorage: self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) self.masks[0].copy_(self.masks[-1]) + self.prev_actions[0].copy_(self.prev_actions[-1]) def compute_returns(self, next_value, use_gae, gamma, tau): if use_gae: @@ -132,6 +141,7 @@ class RolloutStorage: recurrent_hidden_states_batch = [] actions_batch = [] + prev_actions_batch = [] value_preds_batch = [] return_batch = [] masks_batch = [] @@ -147,10 +157,11 @@ class RolloutStorage: ) recurrent_hidden_states_batch.append( - self.recurrent_hidden_states[0:1, ind] + self.recurrent_hidden_states[0, :, ind] ) actions_batch.append(self.actions[:, ind]) + prev_actions_batch.append(self.prev_actions[:-1, ind]) value_preds_batch.append(self.value_preds[:-1, ind]) return_batch.append(self.returns[:-1, ind]) masks_batch.append(self.masks[:-1, ind]) @@ -169,6 +180,7 @@ class RolloutStorage: ) actions_batch = torch.stack(actions_batch, 1) + prev_actions_batch = torch.stack(prev_actions_batch, 1) value_preds_batch = torch.stack(value_preds_batch, 1) return_batch = torch.stack(return_batch, 1) masks_batch = torch.stack(masks_batch, 1) @@ -177,10 +189,10 @@ class RolloutStorage: ) adv_targ = torch.stack(adv_targ, 1) - # States is just a (N, -1) tensor + # States is just a (num_recurrent_layers, N, -1) tensor recurrent_hidden_states_batch = torch.stack( recurrent_hidden_states_batch, 1 - ).view(N, -1) + ) # Flatten the (T, N, ...) tensors to (T * N, ...) for sensor in observations_batch: @@ -189,6 +201,7 @@ class RolloutStorage: ) actions_batch = self._flatten_helper(T, N, actions_batch) + prev_actions_batch = self._flatten_helper(T, N, prev_actions_batch) value_preds_batch = self._flatten_helper(T, N, value_preds_batch) return_batch = self._flatten_helper(T, N, return_batch) masks_batch = self._flatten_helper(T, N, masks_batch) @@ -201,6 +214,7 @@ class RolloutStorage: observations_batch, recurrent_hidden_states_batch, actions_batch, + prev_actions_batch, value_preds_batch, return_batch, masks_batch, diff --git a/habitat_baselines/rl/models/__init__.py b/habitat_baselines/rl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/habitat_baselines/rl/models/rnn_state_encoder.py b/habitat_baselines/rl/models/rnn_state_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e256c2bba9a8f500fbcdab39fd2ad2420b628493 --- /dev/null +++ b/habitat_baselines/rl/models/rnn_state_encoder.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn + + +class RNNStateEncoder(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + rnn_type: str = "GRU", + ): + r"""An RNN for encoding the state in RL. + + Supports masking the hidden state during various timesteps in the forward lass + + Args: + input_size: The input size of the RNN + hidden_size: The hidden size + num_layers: The number of recurrent layers + rnn_type: The RNN cell type. 
Must be GRU or LSTM + """ + + super().__init__() + self._num_recurrent_layers = num_layers + self._rnn_type = rnn_type + + self.rnn = getattr(nn, rnn_type)( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + ) + + self.layer_init() + + def layer_init(self): + for name, param in self.rnn.named_parameters(): + if "weight" in name: + nn.init.orthogonal_(param) + elif "bias" in name: + nn.init.constant_(param, 0) + + @property + def num_recurrent_layers(self): + return self._num_recurrent_layers * ( + 2 if "LSTM" in self._rnn_type else 1 + ) + + def _pack_hidden(self, hidden_states): + if "LSTM" in self._rnn_type: + hidden_states = torch.cat( + [hidden_states[0], hidden_states[1]], dim=0 + ) + + return hidden_states + + def _unpack_hidden(self, hidden_states): + if "LSTM" in self._rnn_type: + hidden_states = ( + hidden_states[0 : self._num_recurrent_layers], + hidden_states[self._num_recurrent_layers :], + ) + + return hidden_states + + def _mask_hidden(self, hidden_states, masks): + if isinstance(hidden_states, tuple): + hidden_states = tuple(v * masks for v in hidden_states) + else: + hidden_states = masks * hidden_states + + return hidden_states + + def single_forward(self, x, hidden_states, masks): + r"""Forward for a non-sequence input + """ + hidden_states = self._unpack_hidden(hidden_states) + x, hidden_states = self.rnn( + x.unsqueeze(0), + self._mask_hidden(hidden_states, masks.unsqueeze(0)), + ) + x = x.squeeze(0) + hidden_states = self._pack_hidden(hidden_states) + return x, hidden_states + + def seq_forward(self, x, hidden_states, masks): + r"""Forward for a sequence of length T + + Args: + x: (T, N, -1) Tensor that has been flattened to (T * N, -1) + hidden_states: The starting hidden state. + masks: The masks to be applied to hidden state at every timestep. + A (T, N) tensor flatten to (T * N) + """ + # x is a (T, N, -1) tensor flattened to (T * N, -1) + n = hidden_states.size(1) + t = int(x.size(0) / n) + + # unflatten + x = x.view(t, n, x.size(1)) + masks = masks.view(t, n) + + # steps in sequence which have zero for any agent. Assume t=0 has + # a zero in it. 
+ has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu() + + # +1 to correct the masks[1:] + if has_zeros.dim() == 0: + has_zeros = [has_zeros.item() + 1] # handle scalar + else: + has_zeros = (has_zeros + 1).numpy().tolist() + + # add t=0 and t=T to the list + has_zeros = [0] + has_zeros + [t] + + hidden_states = self._unpack_hidden(hidden_states) + outputs = [] + for i in range(len(has_zeros) - 1): + # process steps that don't have any zeros in masks together + start_idx = has_zeros[i] + end_idx = has_zeros[i + 1] + + rnn_scores, hidden_states = self.rnn( + x[start_idx:end_idx], + self._mask_hidden( + hidden_states, masks[start_idx].view(1, -1, 1) + ), + ) + + outputs.append(rnn_scores) + + # x is a (T, N, -1) tensor + x = torch.cat(outputs, dim=0) + x = x.view(t * n, -1) # flatten + + hidden_states = self._pack_hidden(hidden_states) + return x, hidden_states + + def forward(self, x, hidden_states, masks): + if x.size(0) == hidden_states.size(1): + return self.single_forward(x, hidden_states, masks) + else: + return self.seq_forward(x, hidden_states, masks) diff --git a/habitat_baselines/rl/models/simple_cnn.py b/habitat_baselines/rl/models/simple_cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6a8caee4e84d3a9fa7bd971939e130ecc8ef53 --- /dev/null +++ b/habitat_baselines/rl/models/simple_cnn.py @@ -0,0 +1,147 @@ +import numpy as np +import torch +import torch.nn as nn + +from habitat_baselines.common.utils import Flatten + + +class SimpleCNN(nn.Module): + r"""A Simple 3-Conv CNN followed by a fully connected layer + + Takes in observations and produces an embedding of the rgb and/or depth components + + Args: + observation_space: The observation_space of the agent + output_size: The size of the embedding vector + """ + + def __init__(self, observation_space, output_size): + super().__init__() + if "rgb" in observation_space.spaces: + self._n_input_rgb = observation_space.spaces["rgb"].shape[2] + else: + self._n_input_rgb = 0 + + if "depth" in observation_space.spaces: + self._n_input_depth = observation_space.spaces["depth"].shape[2] + else: + self._n_input_depth = 0 + + # kernel size for different CNN layers + self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)] + + # strides for different CNN layers + self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)] + + if self._n_input_rgb > 0: + cnn_dims = np.array( + observation_space.spaces["rgb"].shape[:2], dtype=np.float32 + ) + elif self._n_input_depth > 0: + cnn_dims = np.array( + observation_space.spaces["depth"].shape[:2], dtype=np.float32 + ) + + if self.is_blind: + self.cnn = nn.Sequential() + else: + for kernel_size, stride in zip( + self._cnn_layers_kernel_size, self._cnn_layers_stride + ): + cnn_dims = self._conv_output_dim( + dimension=cnn_dims, + padding=np.array([0, 0], dtype=np.float32), + dilation=np.array([1, 1], dtype=np.float32), + kernel_size=np.array(kernel_size, dtype=np.float32), + stride=np.array(stride, dtype=np.float32), + ) + + self.cnn = nn.Sequential( + nn.Conv2d( + in_channels=self._n_input_rgb + self._n_input_depth, + out_channels=32, + kernel_size=self._cnn_layers_kernel_size[0], + stride=self._cnn_layers_stride[0], + ), + nn.ReLU(True), + nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=self._cnn_layers_kernel_size[1], + stride=self._cnn_layers_stride[1], + ), + nn.ReLU(True), + nn.Conv2d( + in_channels=64, + out_channels=32, + kernel_size=self._cnn_layers_kernel_size[2], + stride=self._cnn_layers_stride[2], + ), + # nn.ReLU(True), + Flatten(), + 
nn.Linear(32 * cnn_dims[0] * cnn_dims[1], output_size), + nn.ReLU(True), + ) + + self.layer_init() + + def _conv_output_dim( + self, dimension, padding, dilation, kernel_size, stride + ): + r"""Calculates the output height and width based on the input + height and width to the convolution layer. + + ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d + """ + assert len(dimension) == 2 + out_dimension = [] + for i in range(len(dimension)): + out_dimension.append( + int( + np.floor( + ( + ( + dimension[i] + + 2 * padding[i] + - dilation[i] * (kernel_size[i] - 1) + - 1 + ) + / stride[i] + ) + + 1 + ) + ) + ) + return tuple(out_dimension) + + def layer_init(self): + for layer in self.cnn: + if isinstance(layer, (nn.Conv2d, nn.Linear)): + nn.init.kaiming_normal_( + layer.weight, nn.init.calculate_gain("relu") + ) + if layer.bias is not None: + nn.init.constant_(layer.bias, val=0) + + @property + def is_blind(self): + return self._n_input_rgb + self._n_input_depth == 0 + + def forward(self, observations): + cnn_input = [] + if self._n_input_rgb > 0: + rgb_observations = observations["rgb"] + # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] + rgb_observations = rgb_observations.permute(0, 3, 1, 2) + rgb_observations = rgb_observations / 255.0 # normalize RGB + cnn_input.append(rgb_observations) + + if self._n_input_depth > 0: + depth_observations = observations["depth"] + # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] + depth_observations = depth_observations.permute(0, 3, 1, 2) + cnn_input.append(depth_observations) + + cnn_input = torch.cat(cnn_input, dim=1) + + return self.cnn(cnn_input) diff --git a/habitat_baselines/rl/ppo/__init__.py b/habitat_baselines/rl/ppo/__init__.py index 9c00af21530364e25ef4b36e9badeb2efdde4e8c..febc3fd73942a8a403503ff26bda8a7a4f30d260 100644 --- a/habitat_baselines/rl/ppo/__init__.py +++ b/habitat_baselines/rl/ppo/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat_baselines.rl.ppo.policy import Policy +from habitat_baselines.rl.ppo.policy import Net, PointNavBaselinePolicy, Policy from habitat_baselines.rl.ppo.ppo import PPO -__all__ = ["PPO", "Policy"] +__all__ = ["PPO", "Policy", "Net", "PointNavBaselinePolicy"] diff --git a/habitat_baselines/rl/ppo/policy.py b/habitat_baselines/rl/ppo/policy.py index 0be21e5afb86fb1207fe780fb4c91c40ac06c0c8..6aced23407545f1eeb1e76b23c225ef17b659ba5 100644 --- a/habitat_baselines/rl/ppo/policy.py +++ b/habitat_baselines/rl/ppo/policy.py @@ -3,43 +3,44 @@ # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import abc import numpy as np import torch import torch.nn as nn from habitat_baselines.common.utils import CategoricalNet, Flatten +from habitat_baselines.rl.models.rnn_state_encoder import RNNStateEncoder +from habitat_baselines.rl.models.simple_cnn import SimpleCNN class Policy(nn.Module): - def __init__( - self, - observation_space, - action_space, - goal_sensor_uuid, - hidden_size=512, - ): + def __init__(self, net, dim_actions): super().__init__() - self.dim_actions = action_space.n - self.goal_sensor_uuid = goal_sensor_uuid - self.net = Net( - observation_space=observation_space, - hidden_size=hidden_size, - goal_sensor_uuid=goal_sensor_uuid, - ) + self.net = net + self.dim_actions = dim_actions self.action_distribution = CategoricalNet( self.net.output_size, self.dim_actions ) + self.critic = CriticHead(self.net.output_size) def forward(self, *x): raise NotImplementedError - def act(self, observations, rnn_hidden_states, masks, deterministic=False): - value, actor_features, rnn_hidden_states = self.net( - observations, rnn_hidden_states, masks + def act( + self, + observations, + rnn_hidden_states, + prev_actions, + masks, + deterministic=False, + ): + features, rnn_hidden_states = self.net( + observations, rnn_hidden_states, prev_actions, masks ) - distribution = self.action_distribution(actor_features) + distribution = self.action_distribution(features) + value = self.critic(features) if deterministic: action = distribution.mode() @@ -50,15 +51,20 @@ class Policy(nn.Module): return value, action, action_log_probs, rnn_hidden_states - def get_value(self, observations, rnn_hidden_states, masks): - value, _, _ = self.net(observations, rnn_hidden_states, masks) - return value + def get_value(self, observations, rnn_hidden_states, prev_actions, masks): + features, _ = self.net( + observations, rnn_hidden_states, prev_actions, masks + ) + return self.critic(features) - def evaluate_actions(self, observations, rnn_hidden_states, masks, action): - value, actor_features, rnn_hidden_states = self.net( - observations, rnn_hidden_states, masks + def evaluate_actions( + self, observations, rnn_hidden_states, prev_actions, masks, action + ): + features, rnn_hidden_states = self.net( + observations, rnn_hidden_states, prev_actions, masks ) - distribution = self.action_distribution(actor_features) + distribution = self.action_distribution(features) + value = self.critic(features) action_log_probs = distribution.log_probs(action) distribution_entropy = distribution.entropy().mean() @@ -66,7 +72,57 @@ class Policy(nn.Module): return value, action_log_probs, distribution_entropy, rnn_hidden_states -class Net(nn.Module): +class CriticHead(nn.Module): + def __init__(self, input_size): + super().__init__() + self.fc = nn.Linear(input_size, 1) + nn.init.orthogonal_(self.fc.weight) + nn.init.constant_(self.fc.bias, 0) + + def forward(self, x): + return self.fc(x) + + +class PointNavBaselinePolicy(Policy): + def __init__( + self, + observation_space, + action_space, + goal_sensor_uuid, + hidden_size=512, + ): + super().__init__( + PointNavBaselineNet( + observation_space=observation_space, + hidden_size=hidden_size, + goal_sensor_uuid=goal_sensor_uuid, + ), + action_space.n, + ) + + +class Net(nn.Module, metaclass=abc.ABCMeta): + @abc.abstractmethod + def forward(self, observations, rnn_hidden_states, prev_actions, masks): + pass + + @property + @abc.abstractmethod + def output_size(self): + pass + + @property + @abc.abstractmethod + def num_recurrent_layers(self): + pass + + @property + 
@abc.abstractmethod + def is_blind(self): + pass + + +class PointNavBaselineNet(Net): r"""Network which passes the input image through CNN and concatenates goal vector with CNN's output and passes that through RNN. """ @@ -79,218 +135,39 @@ class Net(nn.Module): ].shape[0] self._hidden_size = hidden_size - self.cnn = self._init_perception_model(observation_space) - - if self.is_blind: - self.rnn = nn.GRU(self._n_input_goal, self._hidden_size) - else: - self.rnn = nn.GRU( - self.output_size + self._n_input_goal, self._hidden_size - ) + self.visual_encoder = SimpleCNN(observation_space, hidden_size) - self.critic_linear = nn.Linear(self._hidden_size, 1) + self.state_encoder = RNNStateEncoder( + (0 if self.is_blind else self._hidden_size) + self._n_input_goal, + self._hidden_size, + ) - self.layer_init() self.train() - def _init_perception_model(self, observation_space): - if "rgb" in observation_space.spaces: - self._n_input_rgb = observation_space.spaces["rgb"].shape[2] - else: - self._n_input_rgb = 0 - - if "depth" in observation_space.spaces: - self._n_input_depth = observation_space.spaces["depth"].shape[2] - else: - self._n_input_depth = 0 - - # kernel size for different CNN layers - self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)] - - # strides for different CNN layers - self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)] - - if self._n_input_rgb > 0: - cnn_dims = np.array( - observation_space.spaces["rgb"].shape[:2], dtype=np.float32 - ) - elif self._n_input_depth > 0: - cnn_dims = np.array( - observation_space.spaces["depth"].shape[:2], dtype=np.float32 - ) - - if self.is_blind: - return nn.Sequential() - else: - for kernel_size, stride in zip( - self._cnn_layers_kernel_size, self._cnn_layers_stride - ): - cnn_dims = self._conv_output_dim( - dimension=cnn_dims, - padding=np.array([0, 0], dtype=np.float32), - dilation=np.array([1, 1], dtype=np.float32), - kernel_size=np.array(kernel_size, dtype=np.float32), - stride=np.array(stride, dtype=np.float32), - ) - - return nn.Sequential( - nn.Conv2d( - in_channels=self._n_input_rgb + self._n_input_depth, - out_channels=32, - kernel_size=self._cnn_layers_kernel_size[0], - stride=self._cnn_layers_stride[0], - ), - nn.ReLU(), - nn.Conv2d( - in_channels=32, - out_channels=64, - kernel_size=self._cnn_layers_kernel_size[1], - stride=self._cnn_layers_stride[1], - ), - nn.ReLU(), - nn.Conv2d( - in_channels=64, - out_channels=32, - kernel_size=self._cnn_layers_kernel_size[2], - stride=self._cnn_layers_stride[2], - ), - Flatten(), - nn.Linear(32 * cnn_dims[0] * cnn_dims[1], self._hidden_size), - nn.ReLU(), - ) - - def _conv_output_dim( - self, dimension, padding, dilation, kernel_size, stride - ): - r"""Calculates the output height and width based on the input - height and width to the convolution layer. 
- - ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d - """ - assert len(dimension) == 2 - out_dimension = [] - for i in range(len(dimension)): - out_dimension.append( - int( - np.floor( - ( - ( - dimension[i] - + 2 * padding[i] - - dilation[i] * (kernel_size[i] - 1) - - 1 - ) - / stride[i] - ) - + 1 - ) - ) - ) - return tuple(out_dimension) - @property def output_size(self): return self._hidden_size - def layer_init(self): - for layer in self.cnn: - if isinstance(layer, (nn.Conv2d, nn.Linear)): - nn.init.orthogonal_( - layer.weight, nn.init.calculate_gain("relu") - ) - nn.init.constant_(layer.bias, val=0) - - for name, param in self.rnn.named_parameters(): - if "weight" in name: - nn.init.orthogonal_(param) - elif "bias" in name: - nn.init.constant_(param, 0) - - nn.init.orthogonal_(self.critic_linear.weight, gain=1) - nn.init.constant_(self.critic_linear.bias, val=0) - - def forward_rnn(self, x, hidden_states, masks): - if x.size(0) == hidden_states.size(0): - x, hidden_states = self.rnn( - x.unsqueeze(0), (hidden_states * masks).unsqueeze(0) - ) - x = x.squeeze(0) - hidden_states = hidden_states.squeeze(0) - else: - # x is a (T, N, -1) tensor flattened to (T * N, -1) - n = hidden_states.size(0) - t = int(x.size(0) / n) - - # unflatten - x = x.view(t, n, x.size(1)) - masks = masks.view(t, n) - - # steps in sequence which have zero for any agent. Assume t=0 has - # a zero in it. - has_zeros = ( - (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu() - ) - - # +1 to correct the masks[1:] - if has_zeros.dim() == 0: - has_zeros = [has_zeros.item() + 1] # handle scalar - else: - has_zeros = (has_zeros + 1).numpy().tolist() - - # add t=0 and t=T to the list - has_zeros = [0] + has_zeros + [t] - - hidden_states = hidden_states.unsqueeze(0) - outputs = [] - for i in range(len(has_zeros) - 1): - # process steps that don't have any zeros in masks together - start_idx = has_zeros[i] - end_idx = has_zeros[i + 1] - - rnn_scores, hidden_states = self.rnn( - x[start_idx:end_idx], - hidden_states * masks[start_idx].view(1, -1, 1), - ) - - outputs.append(rnn_scores) - - # x is a (T, N, -1) tensor - x = torch.cat(outputs, dim=0) - x = x.view(t * n, -1) # flatten - hidden_states = hidden_states.squeeze(0) - - return x, hidden_states - @property def is_blind(self): - return self._n_input_rgb + self._n_input_depth == 0 + return self.visual_encoder.is_blind - def forward_perception_model(self, observations): - cnn_input = [] - if self._n_input_rgb > 0: - rgb_observations = observations["rgb"] - # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] - rgb_observations = rgb_observations.permute(0, 3, 1, 2) - rgb_observations = rgb_observations / 255.0 # normalize RGB - cnn_input.append(rgb_observations) - - if self._n_input_depth > 0: - depth_observations = observations["depth"] - # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] - depth_observations = depth_observations.permute(0, 3, 1, 2) - cnn_input.append(depth_observations) - - cnn_input = torch.cat(cnn_input, dim=1) + @property + def num_recurrent_layers(self): + return self.state_encoder.num_recurrent_layers - return self.cnn(cnn_input) + def get_target_encoding(self, observations): + return observations[self.goal_sensor_uuid] - def forward(self, observations, rnn_hidden_states, masks): - x = observations[self.goal_sensor_uuid] + def forward(self, observations, rnn_hidden_states, prev_actions, masks): + target_encoding = self.get_target_encoding(observations) + x = [target_encoding] if not self.is_blind: - 
perception_embed = self.forward_perception_model(observations) - x = torch.cat([perception_embed, x], dim=1) + perception_embed = self.visual_encoder(observations) + x = [perception_embed] + x - x, rnn_hidden_states = self.forward_rnn(x, rnn_hidden_states, masks) + x = torch.cat(x, dim=1) + x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks) - return self.critic_linear(x), x, rnn_hidden_states + return x, rnn_hidden_states diff --git a/habitat_baselines/rl/ppo/ppo.py b/habitat_baselines/rl/ppo/ppo.py index 78f113f89d7f167b0bcd20e556d8c9ec0e6e9b71..85bfd9e621d62375ee58155068283c9914d95bb4 100644 --- a/habitat_baselines/rl/ppo/ppo.py +++ b/habitat_baselines/rl/ppo/ppo.py @@ -24,6 +24,7 @@ class PPO(nn.Module): eps=None, max_grad_norm=None, use_clipped_value_loss=True, + use_normalized_advantage=True, ): super().__init__() @@ -41,15 +42,21 @@ class PPO(nn.Module): self.use_clipped_value_loss = use_clipped_value_loss self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) + self.device = next(actor_critic.parameters()).device + self.use_normalized_advantage = use_normalized_advantage def forward(self, *x): raise NotImplementedError - def update(self, rollouts): + def get_advantages(self, rollouts): advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] - advantages = (advantages - advantages.mean()) / ( - advantages.std() + EPS_PPO - ) + if not self.use_normalized_advantage: + return advantages + + return (advantages - advantages.mean()) / (advantages.std() + EPS_PPO) + + def update(self, rollouts): + advantages = self.get_advantages(rollouts) value_loss_epoch = 0 action_loss_epoch = 0 @@ -65,6 +72,7 @@ class PPO(nn.Module): obs_batch, recurrent_hidden_states_batch, actions_batch, + prev_actions_batch, value_preds_batch, return_batch, masks_batch, @@ -81,6 +89,7 @@ class PPO(nn.Module): ) = self.actor_critic.evaluate_actions( obs_batch, recurrent_hidden_states_batch, + prev_actions_batch, masks_batch, actions_batch, ) @@ -113,15 +122,19 @@ class PPO(nn.Module): value_loss = 0.5 * (return_batch - values).pow(2).mean() self.optimizer.zero_grad() - ( + total_loss = ( value_loss * self.value_loss_coef + action_loss - dist_entropy * self.entropy_coef - ).backward() - nn.utils.clip_grad_norm_( - self.actor_critic.parameters(), self.max_grad_norm ) + + self.before_backward(total_loss) + total_loss.backward() + self.after_backward(total_loss) + + self.before_step() self.optimizer.step() + self.after_step() value_loss_epoch += value_loss.item() action_loss_epoch += action_loss.item() @@ -134,3 +147,17 @@ class PPO(nn.Module): dist_entropy_epoch /= num_updates return value_loss_epoch, action_loss_epoch, dist_entropy_epoch + + def before_backward(self, loss): + pass + + def after_backward(self, loss): + pass + + def before_step(self): + nn.utils.clip_grad_norm_( + self.actor_critic.parameters(), self.max_grad_norm + ) + + def after_step(self): + pass diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py index 602b22e82fc7d9a0a52f12174de064245a5817ff..dbfab622ef1af6fddc4e88abcececcfcf9a08010 100644 --- a/habitat_baselines/rl/ppo/ppo_trainer.py +++ b/habitat_baselines/rl/ppo/ppo_trainer.py @@ -29,7 +29,7 @@ from habitat_baselines.common.utils import ( poll_checkpoint_folder, update_linear_schedule, ) -from habitat_baselines.rl.ppo import PPO, Policy +from habitat_baselines.rl.ppo import PPO, PointNavBaselinePolicy @baseline_registry.register_trainer(name="ppo") @@ -60,7 +60,7 @@ class PPOTrainer(BaseRLTrainer): """ 
logger.add_filehandler(ppo_cfg.log_file) - self.actor_critic = Policy( + self.actor_critic = PointNavBaselinePolicy( observation_space=self.envs.observation_spaces[0], action_space=self.envs.action_spaces[0], hidden_size=512, @@ -113,6 +113,96 @@ class PPOTrainer(BaseRLTrainer): """ return torch.load(checkpoint_path, map_location=self.device) + def _collect_rollout_step( + self, rollouts, current_episode_reward, episode_rewards, episode_counts + ): + pth_time = 0.0 + env_time = 0.0 + + t_sample_action = time.time() + # sample actions + with torch.no_grad(): + step_observation = { + k: v[rollouts.step] for k, v in rollouts.observations.items() + } + + ( + values, + actions, + actions_log_probs, + recurrent_hidden_states, + ) = self.actor_critic.act( + step_observation, + rollouts.recurrent_hidden_states[rollouts.step], + rollouts.prev_actions[rollouts.step], + rollouts.masks[rollouts.step], + ) + + pth_time += time.time() - t_sample_action + + t_step_env = time.time() + + outputs = self.envs.step([a[0].item() for a in actions]) + observations, rewards, dones, infos = [list(x) for x in zip(*outputs)] + + env_time += time.time() - t_step_env + + t_update_stats = time.time() + batch = batch_obs(observations) + rewards = torch.tensor(rewards, dtype=torch.float) + rewards = rewards.unsqueeze(1) + + masks = torch.tensor( + [[0.0] if done else [1.0] for done in dones], dtype=torch.float + ) + + current_episode_reward += rewards + episode_rewards += (1 - masks) * current_episode_reward + episode_counts += 1 - masks + current_episode_reward *= masks + + rollouts.insert( + batch, + recurrent_hidden_states, + actions, + actions_log_probs, + values, + rewards, + masks, + ) + + pth_time += time.time() - t_update_stats + + return pth_time, env_time, self.envs.num_envs + + def _update_agent(self, ppo_cfg, rollouts): + t_update_model = time.time() + with torch.no_grad(): + last_observation = { + k: v[-1] for k, v in rollouts.observations.items() + } + next_value = self.actor_critic.get_value( + last_observation, + rollouts.recurrent_hidden_states[-1], + rollouts.prev_actions[-1], + rollouts.masks[-1], + ).detach() + + rollouts.compute_returns( + next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau + ) + + value_loss, action_loss, dist_entropy = self.agent.update(rollouts) + + rollouts.after_update() + + return ( + time.time() - t_update_model, + value_loss, + action_loss, + dist_entropy, + ) + def train(self) -> None: r"""Main method for training PPO. 
@@ -184,88 +274,24 @@ class PPOTrainer(BaseRLTrainer): ) for step in range(ppo_cfg.num_steps): - t_sample_action = time.time() - # sample actions - with torch.no_grad(): - step_observation = { - k: v[step] - for k, v in rollouts.observations.items() - } - - ( - values, - actions, - actions_log_probs, - recurrent_hidden_states, - ) = self.actor_critic.act( - step_observation, - rollouts.recurrent_hidden_states[step], - rollouts.masks[step], - ) - pth_time += time.time() - t_sample_action - - t_step_env = time.time() - - outputs = self.envs.step([a[0].item() for a in actions]) - observations, rewards, dones, infos = [ - list(x) for x in zip(*outputs) - ] - - env_time += time.time() - t_step_env - - t_update_stats = time.time() - batch = batch_obs(observations) - rewards = torch.tensor(rewards, dtype=torch.float) - rewards = rewards.unsqueeze(1) - - masks = torch.tensor( - [[0.0] if done else [1.0] for done in dones], - dtype=torch.float, - ) - - current_episode_reward += rewards - episode_rewards += (1 - masks) * current_episode_reward - episode_counts += 1 - masks - current_episode_reward *= masks - - rollouts.insert( - batch, - recurrent_hidden_states, - actions, - actions_log_probs, - values, - rewards, - masks, + delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step( + rollouts, + current_episode_reward, + episode_rewards, + episode_counts, ) + pth_time += delta_pth_time + env_time += delta_env_time + count_steps += delta_steps - count_steps += self.envs.num_envs - pth_time += time.time() - t_update_stats + delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent( + ppo_cfg, rollouts + ) + pth_time += delta_pth_time window_episode_reward.append(episode_rewards.clone()) window_episode_counts.append(episode_counts.clone()) - t_update_model = time.time() - with torch.no_grad(): - last_observation = { - k: v[-1] for k, v in rollouts.observations.items() - } - next_value = self.actor_critic.get_value( - last_observation, - rollouts.recurrent_hidden_states[-1], - rollouts.masks[-1], - ).detach() - - rollouts.compute_returns( - next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau - ) - - value_loss, action_loss, dist_entropy = self.agent.update( - rollouts - ) - - rollouts.after_update() - pth_time += time.time() - t_update_model - losses = [value_loss, action_loss] stats = zip( ["count", "reward"], @@ -434,7 +460,13 @@ class PPOTrainer(BaseRLTrainer): ) test_recurrent_hidden_states = torch.zeros( - ppo_cfg.num_processes, ppo_cfg.hidden_size, device=self.device + actor_critic.net.num_recurrent_layers, + ppo_cfg.num_processes, + ppo_cfg.hidden_size, + device=self.device, + ) + prev_actions = torch.zeros( + args.num_processes, 1, device=device, dtype=torch.long ) not_done_masks = torch.zeros( ppo_cfg.num_processes, 1, device=self.device @@ -457,10 +489,13 @@ class PPOTrainer(BaseRLTrainer): _, actions, _, test_recurrent_hidden_states = self.actor_critic.act( batch, test_recurrent_hidden_states, + prev_actions, not_done_masks, deterministic=False, ) + prev_actions.copy_(actions) + outputs = self.envs.step([a[0].item() for a in actions]) observations, rewards, dones, infos = [ @@ -533,6 +568,7 @@ class PPOTrainer(BaseRLTrainer): ] not_done_masks = not_done_masks[state_index] current_episode_reward = current_episode_reward[state_index] + prev_actions = prev_actions[state_index] for k, v in batch.items(): batch[k] = v[state_index]
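
Usage sketch (illustrative only, not part of the patch): after this change, recurrent hidden states carry a leading net.num_recurrent_layers dimension and prev_actions must be threaded through Policy.act(). The observation/action spaces, the 256x256 RGB resolution, and the "pointgoal" goal sensor uuid below are assumptions chosen for illustration, not values taken from this diff.

import numpy as np
import torch
from gym import spaces

from habitat_baselines.rl.ppo import PointNavBaselinePolicy

# Assumed spaces: a 256x256 RGB camera plus a 2-d pointgoal sensor.
observation_space = spaces.Dict(
    {
        "rgb": spaces.Box(low=0, high=255, shape=(256, 256, 3), dtype=np.uint8),
        "pointgoal": spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2,),
            dtype=np.float32,
        ),
    }
)
action_space = spaces.Discrete(4)

actor_critic = PointNavBaselinePolicy(
    observation_space=observation_space,
    action_space=action_space,
    goal_sensor_uuid="pointgoal",  # assumed sensor uuid
    hidden_size=512,
)

num_envs = 1
# Hidden state now has shape (num_recurrent_layers, num_envs, hidden_size).
hidden_states = torch.zeros(
    actor_critic.net.num_recurrent_layers, num_envs, 512
)
prev_actions = torch.zeros(num_envs, 1, dtype=torch.long)
not_done_masks = torch.zeros(num_envs, 1)  # zeros signal episode start

# A dummy observation batch in the (N, H, W, C) layout expected by SimpleCNN.
batch = {
    "rgb": torch.zeros(num_envs, 256, 256, 3),
    "pointgoal": torch.zeros(num_envs, 2),
}

with torch.no_grad():
    value, action, action_log_probs, hidden_states = actor_critic.act(
        batch, hidden_states, prev_actions, not_done_masks
    )
prev_actions.copy_(action)  # feed the chosen action back in on the next step

Note that if RNNStateEncoder is constructed with rnn_type="LSTM", _pack_hidden concatenates the hidden and cell states along dim 0, so num_recurrent_layers doubles and the same zero-initialized tensor shape above still applies.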