diff --git a/habitat_baselines/agents/ppo_agents.py b/habitat_baselines/agents/ppo_agents.py
index 919298ffb6f89ee10e1fa83bd0dcb4b8926a02dc..551be4ff33c718595c98cd08cde31a37bf804ea9 100644
--- a/habitat_baselines/agents/ppo_agents.py
+++ b/habitat_baselines/agents/ppo_agents.py
@@ -17,7 +17,7 @@ from habitat.config import Config
 from habitat.config.default import get_config
 from habitat.core.agent import Agent
 from habitat_baselines.common.utils import batch_obs
-from habitat_baselines.rl.ppo import Policy
+from habitat_baselines.rl.ppo import PointNavBaselinePolicy
 
 
 def get_default_config():
@@ -75,7 +75,7 @@ class PPOAgent(Agent):
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
 
-        self.actor_critic = Policy(
+        self.actor_critic = PointNavBaselinePolicy(
             observation_space=observation_spaces,
             action_space=action_spaces,
             hidden_size=self.hidden_size,
@@ -88,7 +88,7 @@ class PPOAgent(Agent):
             #  Filter only actor_critic weights
             self.actor_critic.load_state_dict(
                 {
-                    k.replace("actor_critic.", ""): v
+                    k[len("actor_critic.") :]: v
                     for k, v in ckpt["state_dict"].items()
                     if "actor_critic" in k
                 }
@@ -101,12 +101,19 @@ class PPOAgent(Agent):
 
         self.test_recurrent_hidden_states = None
         self.not_done_masks = None
+        self.prev_actions = None
 
     def reset(self):
         self.test_recurrent_hidden_states = torch.zeros(
-            1, self.hidden_size, device=self.device
+            self.actor_critic.net.num_recurrent_layers,
+            1,
+            self.hidden_size,
+            device=self.device,
         )
         self.not_done_masks = torch.zeros(1, 1, device=self.device)
+        self.prev_actions = torch.zeros(
+            1, 1, dtype=torch.long, device=self.device
+        )
 
     def act(self, observations):
         batch = batch_obs([observations])
@@ -117,11 +124,13 @@ class PPOAgent(Agent):
             _, actions, _, self.test_recurrent_hidden_states = self.actor_critic.act(
                 batch,
                 self.test_recurrent_hidden_states,
+                self.prev_actions,
                 self.not_done_masks,
                 deterministic=False,
             )
             #  Make masks not done till reset (end of episode) will be called
             self.not_done_masks = torch.ones(1, 1, device=self.device)
+            self.prev_actions.copy_(actions)
 
         return actions[0][0].item()
 
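
The agent changes above thread a previous-action buffer and a `(num_recurrent_layers, 1, hidden_size)` hidden state through `reset()`/`act()`. A minimal evaluation sketch, assuming this patch is applied; the checkpoint path and task config are placeholders, and `MODEL_PATH` is assumed to be one of the fields exposed by `get_default_config()`:

```python
import habitat
from habitat_baselines.agents.ppo_agents import PPOAgent, get_default_config

config = get_default_config()
config.MODEL_PATH = "data/checkpoints/ckpt.0.pth"  # hypothetical checkpoint

agent = PPOAgent(config)
agent.reset()  # now also zeroes the multi-layer hidden state and prev_actions

benchmark = habitat.Benchmark("configs/tasks/pointnav.yaml")  # placeholder task config
print(benchmark.evaluate(agent, num_episodes=10))
```
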
diff --git a/habitat_baselines/common/rollout_storage.py b/habitat_baselines/common/rollout_storage.py
index 3f9b5dd26269863b6b3ac62efde01f1483aef315..056c1b3cc79040c8d4f244103995ff1e9868ff42 100644
--- a/habitat_baselines/common/rollout_storage.py
+++ b/habitat_baselines/common/rollout_storage.py
@@ -21,6 +21,7 @@ class RolloutStorage:
         observation_space,
         action_space,
         recurrent_hidden_state_size,
+        num_recurrent_layers=1,
     ):
         self.observations = {}
 
@@ -32,7 +33,10 @@ class RolloutStorage:
             )
 
         self.recurrent_hidden_states = torch.zeros(
-            num_steps + 1, num_envs, recurrent_hidden_state_size
+            num_steps + 1,
+            num_recurrent_layers,
+            num_envs,
+            recurrent_hidden_state_size,
         )
 
         self.rewards = torch.zeros(num_steps, num_envs, 1)
@@ -46,8 +50,10 @@ class RolloutStorage:
             action_shape = action_space.shape[0]
 
         self.actions = torch.zeros(num_steps, num_envs, action_shape)
+        self.prev_actions = torch.zeros(num_steps + 1, num_envs, action_shape)
         if action_space.__class__.__name__ == "Discrete":
             self.actions = self.actions.long()
+            self.prev_actions = self.prev_actions.long()
 
         self.masks = torch.ones(num_steps + 1, num_envs, 1)
 
@@ -64,6 +70,7 @@ class RolloutStorage:
         self.returns = self.returns.to(device)
         self.action_log_probs = self.action_log_probs.to(device)
         self.actions = self.actions.to(device)
+        self.prev_actions = self.prev_actions.to(device)
         self.masks = self.masks.to(device)
 
     def insert(
@@ -84,6 +91,7 @@ class RolloutStorage:
             recurrent_hidden_states
         )
         self.actions[self.step].copy_(actions)
+        self.prev_actions[self.step + 1].copy_(actions)
         self.action_log_probs[self.step].copy_(action_log_probs)
         self.value_preds[self.step].copy_(value_preds)
         self.rewards[self.step].copy_(rewards)
@@ -97,6 +105,7 @@ class RolloutStorage:
 
         self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1])
         self.masks[0].copy_(self.masks[-1])
+        self.prev_actions[0].copy_(self.prev_actions[-1])
 
     def compute_returns(self, next_value, use_gae, gamma, tau):
         if use_gae:
@@ -132,6 +141,7 @@ class RolloutStorage:
 
             recurrent_hidden_states_batch = []
             actions_batch = []
+            prev_actions_batch = []
             value_preds_batch = []
             return_batch = []
             masks_batch = []
@@ -147,10 +157,11 @@ class RolloutStorage:
                     )
 
                 recurrent_hidden_states_batch.append(
-                    self.recurrent_hidden_states[0:1, ind]
+                    self.recurrent_hidden_states[0, :, ind]
                 )
 
                 actions_batch.append(self.actions[:, ind])
+                prev_actions_batch.append(self.prev_actions[:-1, ind])
                 value_preds_batch.append(self.value_preds[:-1, ind])
                 return_batch.append(self.returns[:-1, ind])
                 masks_batch.append(self.masks[:-1, ind])
@@ -169,6 +180,7 @@ class RolloutStorage:
                 )
 
             actions_batch = torch.stack(actions_batch, 1)
+            prev_actions_batch = torch.stack(prev_actions_batch, 1)
             value_preds_batch = torch.stack(value_preds_batch, 1)
             return_batch = torch.stack(return_batch, 1)
             masks_batch = torch.stack(masks_batch, 1)
@@ -177,10 +189,10 @@ class RolloutStorage:
             )
             adv_targ = torch.stack(adv_targ, 1)
 
-            # States is just a (N, -1) tensor
+            # States is just a (num_recurrent_layers, N, -1) tensor
             recurrent_hidden_states_batch = torch.stack(
                 recurrent_hidden_states_batch, 1
-            ).view(N, -1)
+            )
 
             # Flatten the (T, N, ...) tensors to (T * N, ...)
             for sensor in observations_batch:
@@ -189,6 +201,7 @@ class RolloutStorage:
                 )
 
             actions_batch = self._flatten_helper(T, N, actions_batch)
+            prev_actions_batch = self._flatten_helper(T, N, prev_actions_batch)
             value_preds_batch = self._flatten_helper(T, N, value_preds_batch)
             return_batch = self._flatten_helper(T, N, return_batch)
             masks_batch = self._flatten_helper(T, N, masks_batch)
@@ -201,6 +214,7 @@ class RolloutStorage:
                 observations_batch,
                 recurrent_hidden_states_batch,
                 actions_batch,
+                prev_actions_batch,
                 value_preds_batch,
                 return_batch,
                 masks_batch,
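
A sketch of constructing the storage with the new `num_recurrent_layers` argument and of the `prev_actions` bookkeeping it adds. The spaces below are stand-ins for `envs.observation_spaces[0]` / `envs.action_spaces[0]`, and the leading `num_steps`/`num_envs` arguments are assumed to keep their existing names:

```python
import numpy as np
from gym import spaces

from habitat_baselines.common.rollout_storage import RolloutStorage

obs_space = spaces.Dict(
    {"pointgoal": spaces.Box(low=-np.pi, high=np.pi, shape=(2,), dtype=np.float32)}
)
action_space = spaces.Discrete(4)

rollouts = RolloutStorage(
    num_steps=128,
    num_envs=4,
    observation_space=obs_space,
    action_space=action_space,
    recurrent_hidden_state_size=512,
    num_recurrent_layers=2,  # e.g. what an LSTM state encoder would report
)

# insert() writes the action taken at `step` into prev_actions[step + 1], so
# prev_actions[t] holds the action that led into timestep t; after_update()
# rolls prev_actions[-1] back to index 0 for the next rollout.
assert rollouts.prev_actions.shape == (128 + 1, 4, 1)
assert rollouts.recurrent_hidden_states.shape == (128 + 1, 2, 4, 512)
```
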
diff --git a/habitat_baselines/rl/models/__init__.py b/habitat_baselines/rl/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/habitat_baselines/rl/models/rnn_state_encoder.py b/habitat_baselines/rl/models/rnn_state_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e256c2bba9a8f500fbcdab39fd2ad2420b628493
--- /dev/null
+++ b/habitat_baselines/rl/models/rnn_state_encoder.py
@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+
+
+class RNNStateEncoder(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int = 1,
+        rnn_type: str = "GRU",
+    ):
+        r"""An RNN for encoding the state in RL.
+
+        Supports masking the hidden state during various timesteps in the forward pass
+
+        Args:
+            input_size: The input size of the RNN
+            hidden_size: The hidden size
+            num_layers: The number of recurrent layers
+            rnn_type: The RNN cell type.  Must be GRU or LSTM
+        """
+
+        super().__init__()
+        self._num_recurrent_layers = num_layers
+        self._rnn_type = rnn_type
+
+        self.rnn = getattr(nn, rnn_type)(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+        )
+
+        self.layer_init()
+
+    def layer_init(self):
+        for name, param in self.rnn.named_parameters():
+            if "weight" in name:
+                nn.init.orthogonal_(param)
+            elif "bias" in name:
+                nn.init.constant_(param, 0)
+
+    @property
+    def num_recurrent_layers(self):
+        return self._num_recurrent_layers * (
+            2 if "LSTM" in self._rnn_type else 1
+        )
+
+    def _pack_hidden(self, hidden_states):
+        if "LSTM" in self._rnn_type:
+            hidden_states = torch.cat(
+                [hidden_states[0], hidden_states[1]], dim=0
+            )
+
+        return hidden_states
+
+    def _unpack_hidden(self, hidden_states):
+        if "LSTM" in self._rnn_type:
+            hidden_states = (
+                hidden_states[0 : self._num_recurrent_layers],
+                hidden_states[self._num_recurrent_layers :],
+            )
+
+        return hidden_states
+
+    def _mask_hidden(self, hidden_states, masks):
+        if isinstance(hidden_states, tuple):
+            hidden_states = tuple(v * masks for v in hidden_states)
+        else:
+            hidden_states = masks * hidden_states
+
+        return hidden_states
+
+    def single_forward(self, x, hidden_states, masks):
+        r"""Forward for a non-sequence input
+        """
+        hidden_states = self._unpack_hidden(hidden_states)
+        x, hidden_states = self.rnn(
+            x.unsqueeze(0),
+            self._mask_hidden(hidden_states, masks.unsqueeze(0)),
+        )
+        x = x.squeeze(0)
+        hidden_states = self._pack_hidden(hidden_states)
+        return x, hidden_states
+
+    def seq_forward(self, x, hidden_states, masks):
+        r"""Forward for a sequence of length T
+
+        Args:
+            x: (T, N, -1) Tensor that has been flattened to (T * N, -1)
+            hidden_states: The starting hidden state.
+            masks: The masks to be applied to hidden state at every timestep.
+                A (T, N) tensor flattened to (T * N)
+        """
+        # x is a (T, N, -1) tensor flattened to (T * N, -1)
+        n = hidden_states.size(1)
+        t = int(x.size(0) / n)
+
+        # unflatten
+        x = x.view(t, n, x.size(1))
+        masks = masks.view(t, n)
+
+        # steps in sequence which have zero for any agent. Assume t=0 has
+        # a zero in it.
+        has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu()
+
+        # +1 to correct the masks[1:]
+        if has_zeros.dim() == 0:
+            has_zeros = [has_zeros.item() + 1]  # handle scalar
+        else:
+            has_zeros = (has_zeros + 1).numpy().tolist()
+
+        # add t=0 and t=T to the list
+        has_zeros = [0] + has_zeros + [t]
+
+        hidden_states = self._unpack_hidden(hidden_states)
+        outputs = []
+        for i in range(len(has_zeros) - 1):
+            # process steps that don't have any zeros in masks together
+            start_idx = has_zeros[i]
+            end_idx = has_zeros[i + 1]
+
+            rnn_scores, hidden_states = self.rnn(
+                x[start_idx:end_idx],
+                self._mask_hidden(
+                    hidden_states, masks[start_idx].view(1, -1, 1)
+                ),
+            )
+
+            outputs.append(rnn_scores)
+
+        # x is a (T, N, -1) tensor
+        x = torch.cat(outputs, dim=0)
+        x = x.view(t * n, -1)  # flatten
+
+        hidden_states = self._pack_hidden(hidden_states)
+        return x, hidden_states
+
+    def forward(self, x, hidden_states, masks):
+        if x.size(0) == hidden_states.size(1):
+            return self.single_forward(x, hidden_states, masks)
+        else:
+            return self.seq_forward(x, hidden_states, masks)
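
A quick shape check for the new encoder (a sketch, not a test from the repo). For an LSTM, `h` and `c` are packed along dim 0, so callers only ever carry a single hidden-state tensor whose leading dimension is `num_recurrent_layers`:

```python
import torch

from habitat_baselines.rl.models.rnn_state_encoder import RNNStateEncoder

encoder = RNNStateEncoder(
    input_size=32, hidden_size=64, num_layers=2, rnn_type="LSTM"
)
assert encoder.num_recurrent_layers == 4  # 2 layers * (h, c)

n_envs = 3
hidden = torch.zeros(encoder.num_recurrent_layers, n_envs, 64)
masks = torch.ones(n_envs, 1)  # zero entries reset the hidden state

x = torch.randn(n_envs, 32)  # one row per environment -> single_forward path
out, hidden = encoder(x, hidden, masks)
assert out.shape == (n_envs, 64) and hidden.shape == (4, n_envs, 64)
```
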
diff --git a/habitat_baselines/rl/models/simple_cnn.py b/habitat_baselines/rl/models/simple_cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b6a8caee4e84d3a9fa7bd971939e130ecc8ef53
--- /dev/null
+++ b/habitat_baselines/rl/models/simple_cnn.py
@@ -0,0 +1,147 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from habitat_baselines.common.utils import Flatten
+
+
+class SimpleCNN(nn.Module):
+    r"""A Simple 3-Conv CNN followed by a fully connected layer
+
+    Takes in observations and produces an embedding of the rgb and/or depth components
+
+    Args:
+        observation_space: The observation_space of the agent
+        output_size: The size of the embedding vector
+    """
+
+    def __init__(self, observation_space, output_size):
+        super().__init__()
+        if "rgb" in observation_space.spaces:
+            self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
+        else:
+            self._n_input_rgb = 0
+
+        if "depth" in observation_space.spaces:
+            self._n_input_depth = observation_space.spaces["depth"].shape[2]
+        else:
+            self._n_input_depth = 0
+
+        # kernel size for different CNN layers
+        self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)]
+
+        # strides for different CNN layers
+        self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)]
+
+        if self._n_input_rgb > 0:
+            cnn_dims = np.array(
+                observation_space.spaces["rgb"].shape[:2], dtype=np.float32
+            )
+        elif self._n_input_depth > 0:
+            cnn_dims = np.array(
+                observation_space.spaces["depth"].shape[:2], dtype=np.float32
+            )
+
+        if self.is_blind:
+            self.cnn = nn.Sequential()
+        else:
+            for kernel_size, stride in zip(
+                self._cnn_layers_kernel_size, self._cnn_layers_stride
+            ):
+                cnn_dims = self._conv_output_dim(
+                    dimension=cnn_dims,
+                    padding=np.array([0, 0], dtype=np.float32),
+                    dilation=np.array([1, 1], dtype=np.float32),
+                    kernel_size=np.array(kernel_size, dtype=np.float32),
+                    stride=np.array(stride, dtype=np.float32),
+                )
+
+            self.cnn = nn.Sequential(
+                nn.Conv2d(
+                    in_channels=self._n_input_rgb + self._n_input_depth,
+                    out_channels=32,
+                    kernel_size=self._cnn_layers_kernel_size[0],
+                    stride=self._cnn_layers_stride[0],
+                ),
+                nn.ReLU(True),
+                nn.Conv2d(
+                    in_channels=32,
+                    out_channels=64,
+                    kernel_size=self._cnn_layers_kernel_size[1],
+                    stride=self._cnn_layers_stride[1],
+                ),
+                nn.ReLU(True),
+                nn.Conv2d(
+                    in_channels=64,
+                    out_channels=32,
+                    kernel_size=self._cnn_layers_kernel_size[2],
+                    stride=self._cnn_layers_stride[2],
+                ),
+                #  nn.ReLU(True),
+                Flatten(),
+                nn.Linear(32 * cnn_dims[0] * cnn_dims[1], output_size),
+                nn.ReLU(True),
+            )
+
+        self.layer_init()
+
+    def _conv_output_dim(
+        self, dimension, padding, dilation, kernel_size, stride
+    ):
+        r"""Calculates the output height and width based on the input
+        height and width to the convolution layer.
+
+        ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d
+        """
+        assert len(dimension) == 2
+        out_dimension = []
+        for i in range(len(dimension)):
+            out_dimension.append(
+                int(
+                    np.floor(
+                        (
+                            (
+                                dimension[i]
+                                + 2 * padding[i]
+                                - dilation[i] * (kernel_size[i] - 1)
+                                - 1
+                            )
+                            / stride[i]
+                        )
+                        + 1
+                    )
+                )
+            )
+        return tuple(out_dimension)
+
+    def layer_init(self):
+        for layer in self.cnn:
+            if isinstance(layer, (nn.Conv2d, nn.Linear)):
+                nn.init.kaiming_normal_(
+                    layer.weight, nn.init.calculate_gain("relu")
+                )
+                if layer.bias is not None:
+                    nn.init.constant_(layer.bias, val=0)
+
+    @property
+    def is_blind(self):
+        return self._n_input_rgb + self._n_input_depth == 0
+
+    def forward(self, observations):
+        cnn_input = []
+        if self._n_input_rgb > 0:
+            rgb_observations = observations["rgb"]
+            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
+            rgb_observations = rgb_observations.permute(0, 3, 1, 2)
+            rgb_observations = rgb_observations / 255.0  # normalize RGB
+            cnn_input.append(rgb_observations)
+
+        if self._n_input_depth > 0:
+            depth_observations = observations["depth"]
+            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
+            depth_observations = depth_observations.permute(0, 3, 1, 2)
+            cnn_input.append(depth_observations)
+
+        cnn_input = torch.cat(cnn_input, dim=1)
+
+        return self.cnn(cnn_input)
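
A minimal forward-pass sketch for the extracted CNN; the 256x256 resolution and the observation space below are placeholders, not values taken from the repo configs:

```python
import numpy as np
import torch
from gym import spaces

from habitat_baselines.rl.models.simple_cnn import SimpleCNN

obs_space = spaces.Dict(
    {
        "rgb": spaces.Box(low=0, high=255, shape=(256, 256, 3), dtype=np.uint8),
        "depth": spaces.Box(low=0.0, high=1.0, shape=(256, 256, 1), dtype=np.float32),
    }
)
cnn = SimpleCNN(observation_space=obs_space, output_size=512)

batch = {
    "rgb": torch.rand(2, 256, 256, 3) * 255.0,  # NHWC, as batch_obs produces
    "depth": torch.rand(2, 256, 256, 1),
}
embedding = cnn(batch)
assert embedding.shape == (2, 512)
```
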
diff --git a/habitat_baselines/rl/ppo/__init__.py b/habitat_baselines/rl/ppo/__init__.py
index 9c00af21530364e25ef4b36e9badeb2efdde4e8c..febc3fd73942a8a403503ff26bda8a7a4f30d260 100644
--- a/habitat_baselines/rl/ppo/__init__.py
+++ b/habitat_baselines/rl/ppo/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from habitat_baselines.rl.ppo.policy import Policy
+from habitat_baselines.rl.ppo.policy import Net, PointNavBaselinePolicy, Policy
 from habitat_baselines.rl.ppo.ppo import PPO
 
-__all__ = ["PPO", "Policy"]
+__all__ = ["PPO", "Policy", "Net", "PointNavBaselinePolicy"]
diff --git a/habitat_baselines/rl/ppo/policy.py b/habitat_baselines/rl/ppo/policy.py
index 0be21e5afb86fb1207fe780fb4c91c40ac06c0c8..6aced23407545f1eeb1e76b23c225ef17b659ba5 100644
--- a/habitat_baselines/rl/ppo/policy.py
+++ b/habitat_baselines/rl/ppo/policy.py
@@ -3,43 +3,44 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import abc
 
 import numpy as np
 import torch
 import torch.nn as nn
 
 from habitat_baselines.common.utils import CategoricalNet, Flatten
+from habitat_baselines.rl.models.rnn_state_encoder import RNNStateEncoder
+from habitat_baselines.rl.models.simple_cnn import SimpleCNN
 
 
 class Policy(nn.Module):
-    def __init__(
-        self,
-        observation_space,
-        action_space,
-        goal_sensor_uuid,
-        hidden_size=512,
-    ):
+    def __init__(self, net, dim_actions):
         super().__init__()
-        self.dim_actions = action_space.n
-        self.goal_sensor_uuid = goal_sensor_uuid
-        self.net = Net(
-            observation_space=observation_space,
-            hidden_size=hidden_size,
-            goal_sensor_uuid=goal_sensor_uuid,
-        )
+        self.net = net
+        self.dim_actions = dim_actions
 
         self.action_distribution = CategoricalNet(
             self.net.output_size, self.dim_actions
         )
+        self.critic = CriticHead(self.net.output_size)
 
     def forward(self, *x):
         raise NotImplementedError
 
-    def act(self, observations, rnn_hidden_states, masks, deterministic=False):
-        value, actor_features, rnn_hidden_states = self.net(
-            observations, rnn_hidden_states, masks
+    def act(
+        self,
+        observations,
+        rnn_hidden_states,
+        prev_actions,
+        masks,
+        deterministic=False,
+    ):
+        features, rnn_hidden_states = self.net(
+            observations, rnn_hidden_states, prev_actions, masks
         )
-        distribution = self.action_distribution(actor_features)
+        distribution = self.action_distribution(features)
+        value = self.critic(features)
 
         if deterministic:
             action = distribution.mode()
@@ -50,15 +51,20 @@ class Policy(nn.Module):
 
         return value, action, action_log_probs, rnn_hidden_states
 
-    def get_value(self, observations, rnn_hidden_states, masks):
-        value, _, _ = self.net(observations, rnn_hidden_states, masks)
-        return value
+    def get_value(self, observations, rnn_hidden_states, prev_actions, masks):
+        features, _ = self.net(
+            observations, rnn_hidden_states, prev_actions, masks
+        )
+        return self.critic(features)
 
-    def evaluate_actions(self, observations, rnn_hidden_states, masks, action):
-        value, actor_features, rnn_hidden_states = self.net(
-            observations, rnn_hidden_states, masks
+    def evaluate_actions(
+        self, observations, rnn_hidden_states, prev_actions, masks, action
+    ):
+        features, rnn_hidden_states = self.net(
+            observations, rnn_hidden_states, prev_actions, masks
         )
-        distribution = self.action_distribution(actor_features)
+        distribution = self.action_distribution(features)
+        value = self.critic(features)
 
         action_log_probs = distribution.log_probs(action)
         distribution_entropy = distribution.entropy().mean()
@@ -66,7 +72,57 @@ class Policy(nn.Module):
         return value, action_log_probs, distribution_entropy, rnn_hidden_states
 
 
-class Net(nn.Module):
+class CriticHead(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.fc = nn.Linear(input_size, 1)
+        nn.init.orthogonal_(self.fc.weight)
+        nn.init.constant_(self.fc.bias, 0)
+
+    def forward(self, x):
+        return self.fc(x)
+
+
+class PointNavBaselinePolicy(Policy):
+    def __init__(
+        self,
+        observation_space,
+        action_space,
+        goal_sensor_uuid,
+        hidden_size=512,
+    ):
+        super().__init__(
+            PointNavBaselineNet(
+                observation_space=observation_space,
+                hidden_size=hidden_size,
+                goal_sensor_uuid=goal_sensor_uuid,
+            ),
+            action_space.n,
+        )
+
+
+class Net(nn.Module, metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def output_size(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def num_recurrent_layers(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def is_blind(self):
+        pass
+
+
+class PointNavBaselineNet(Net):
     r"""Network which passes the input image through CNN and concatenates
     goal vector with CNN's output and passes that through RNN.
     """
@@ -79,218 +135,39 @@ class Net(nn.Module):
         ].shape[0]
         self._hidden_size = hidden_size
 
-        self.cnn = self._init_perception_model(observation_space)
-
-        if self.is_blind:
-            self.rnn = nn.GRU(self._n_input_goal, self._hidden_size)
-        else:
-            self.rnn = nn.GRU(
-                self.output_size + self._n_input_goal, self._hidden_size
-            )
+        self.visual_encoder = SimpleCNN(observation_space, hidden_size)
 
-        self.critic_linear = nn.Linear(self._hidden_size, 1)
+        self.state_encoder = RNNStateEncoder(
+            (0 if self.is_blind else self._hidden_size) + self._n_input_goal,
+            self._hidden_size,
+        )
 
-        self.layer_init()
         self.train()
 
-    def _init_perception_model(self, observation_space):
-        if "rgb" in observation_space.spaces:
-            self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
-        else:
-            self._n_input_rgb = 0
-
-        if "depth" in observation_space.spaces:
-            self._n_input_depth = observation_space.spaces["depth"].shape[2]
-        else:
-            self._n_input_depth = 0
-
-        # kernel size for different CNN layers
-        self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)]
-
-        # strides for different CNN layers
-        self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)]
-
-        if self._n_input_rgb > 0:
-            cnn_dims = np.array(
-                observation_space.spaces["rgb"].shape[:2], dtype=np.float32
-            )
-        elif self._n_input_depth > 0:
-            cnn_dims = np.array(
-                observation_space.spaces["depth"].shape[:2], dtype=np.float32
-            )
-
-        if self.is_blind:
-            return nn.Sequential()
-        else:
-            for kernel_size, stride in zip(
-                self._cnn_layers_kernel_size, self._cnn_layers_stride
-            ):
-                cnn_dims = self._conv_output_dim(
-                    dimension=cnn_dims,
-                    padding=np.array([0, 0], dtype=np.float32),
-                    dilation=np.array([1, 1], dtype=np.float32),
-                    kernel_size=np.array(kernel_size, dtype=np.float32),
-                    stride=np.array(stride, dtype=np.float32),
-                )
-
-            return nn.Sequential(
-                nn.Conv2d(
-                    in_channels=self._n_input_rgb + self._n_input_depth,
-                    out_channels=32,
-                    kernel_size=self._cnn_layers_kernel_size[0],
-                    stride=self._cnn_layers_stride[0],
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=32,
-                    out_channels=64,
-                    kernel_size=self._cnn_layers_kernel_size[1],
-                    stride=self._cnn_layers_stride[1],
-                ),
-                nn.ReLU(),
-                nn.Conv2d(
-                    in_channels=64,
-                    out_channels=32,
-                    kernel_size=self._cnn_layers_kernel_size[2],
-                    stride=self._cnn_layers_stride[2],
-                ),
-                Flatten(),
-                nn.Linear(32 * cnn_dims[0] * cnn_dims[1], self._hidden_size),
-                nn.ReLU(),
-            )
-
-    def _conv_output_dim(
-        self, dimension, padding, dilation, kernel_size, stride
-    ):
-        r"""Calculates the output height and width based on the input
-        height and width to the convolution layer.
-
-        ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d
-        """
-        assert len(dimension) == 2
-        out_dimension = []
-        for i in range(len(dimension)):
-            out_dimension.append(
-                int(
-                    np.floor(
-                        (
-                            (
-                                dimension[i]
-                                + 2 * padding[i]
-                                - dilation[i] * (kernel_size[i] - 1)
-                                - 1
-                            )
-                            / stride[i]
-                        )
-                        + 1
-                    )
-                )
-            )
-        return tuple(out_dimension)
-
     @property
     def output_size(self):
         return self._hidden_size
 
-    def layer_init(self):
-        for layer in self.cnn:
-            if isinstance(layer, (nn.Conv2d, nn.Linear)):
-                nn.init.orthogonal_(
-                    layer.weight, nn.init.calculate_gain("relu")
-                )
-                nn.init.constant_(layer.bias, val=0)
-
-        for name, param in self.rnn.named_parameters():
-            if "weight" in name:
-                nn.init.orthogonal_(param)
-            elif "bias" in name:
-                nn.init.constant_(param, 0)
-
-        nn.init.orthogonal_(self.critic_linear.weight, gain=1)
-        nn.init.constant_(self.critic_linear.bias, val=0)
-
-    def forward_rnn(self, x, hidden_states, masks):
-        if x.size(0) == hidden_states.size(0):
-            x, hidden_states = self.rnn(
-                x.unsqueeze(0), (hidden_states * masks).unsqueeze(0)
-            )
-            x = x.squeeze(0)
-            hidden_states = hidden_states.squeeze(0)
-        else:
-            # x is a (T, N, -1) tensor flattened to (T * N, -1)
-            n = hidden_states.size(0)
-            t = int(x.size(0) / n)
-
-            # unflatten
-            x = x.view(t, n, x.size(1))
-            masks = masks.view(t, n)
-
-            # steps in sequence which have zero for any agent. Assume t=0 has
-            # a zero in it.
-            has_zeros = (
-                (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu()
-            )
-
-            # +1 to correct the masks[1:]
-            if has_zeros.dim() == 0:
-                has_zeros = [has_zeros.item() + 1]  # handle scalar
-            else:
-                has_zeros = (has_zeros + 1).numpy().tolist()
-
-            # add t=0 and t=T to the list
-            has_zeros = [0] + has_zeros + [t]
-
-            hidden_states = hidden_states.unsqueeze(0)
-            outputs = []
-            for i in range(len(has_zeros) - 1):
-                # process steps that don't have any zeros in masks together
-                start_idx = has_zeros[i]
-                end_idx = has_zeros[i + 1]
-
-                rnn_scores, hidden_states = self.rnn(
-                    x[start_idx:end_idx],
-                    hidden_states * masks[start_idx].view(1, -1, 1),
-                )
-
-                outputs.append(rnn_scores)
-
-            # x is a (T, N, -1) tensor
-            x = torch.cat(outputs, dim=0)
-            x = x.view(t * n, -1)  # flatten
-            hidden_states = hidden_states.squeeze(0)
-
-        return x, hidden_states
-
     @property
     def is_blind(self):
-        return self._n_input_rgb + self._n_input_depth == 0
+        return self.visual_encoder.is_blind
 
-    def forward_perception_model(self, observations):
-        cnn_input = []
-        if self._n_input_rgb > 0:
-            rgb_observations = observations["rgb"]
-            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
-            rgb_observations = rgb_observations.permute(0, 3, 1, 2)
-            rgb_observations = rgb_observations / 255.0  # normalize RGB
-            cnn_input.append(rgb_observations)
-
-        if self._n_input_depth > 0:
-            depth_observations = observations["depth"]
-            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
-            depth_observations = depth_observations.permute(0, 3, 1, 2)
-            cnn_input.append(depth_observations)
-
-        cnn_input = torch.cat(cnn_input, dim=1)
+    @property
+    def num_recurrent_layers(self):
+        return self.state_encoder.num_recurrent_layers
 
-        return self.cnn(cnn_input)
+    def get_target_encoding(self, observations):
+        return observations[self.goal_sensor_uuid]
 
-    def forward(self, observations, rnn_hidden_states, masks):
-        x = observations[self.goal_sensor_uuid]
+    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
+        target_encoding = self.get_target_encoding(observations)
+        x = [target_encoding]
 
         if not self.is_blind:
-            perception_embed = self.forward_perception_model(observations)
-            x = torch.cat([perception_embed, x], dim=1)
+            perception_embed = self.visual_encoder(observations)
+            x = [perception_embed] + x
 
-        x, rnn_hidden_states = self.forward_rnn(x, rnn_hidden_states, masks)
+        x = torch.cat(x, dim=1)
+        x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks)
 
-        return self.critic_linear(x), x, rnn_hidden_states
+        return x, rnn_hidden_states
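
This is the extension point the refactor creates: a custom `Net` only has to implement `forward()` plus the three properties, and a `Policy` subclass wraps it with the shared actor/critic heads. The "blind" net below is a hypothetical example, not part of the patch:

```python
from habitat_baselines.rl.models.rnn_state_encoder import RNNStateEncoder
from habitat_baselines.rl.ppo import Net, Policy


class BlindPointNavNet(Net):
    r"""Hypothetical Net that ignores vision and only encodes the goal vector."""

    def __init__(self, observation_space, hidden_size, goal_sensor_uuid):
        super().__init__()
        self.goal_sensor_uuid = goal_sensor_uuid
        n_input_goal = observation_space.spaces[goal_sensor_uuid].shape[0]
        self._hidden_size = hidden_size
        self.state_encoder = RNNStateEncoder(n_input_goal, hidden_size)
        self.train()

    @property
    def output_size(self):
        return self._hidden_size

    @property
    def is_blind(self):
        return True

    @property
    def num_recurrent_layers(self):
        return self.state_encoder.num_recurrent_layers

    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        x = observations[self.goal_sensor_uuid]
        return self.state_encoder(x, rnn_hidden_states, masks)


class BlindPointNavPolicy(Policy):
    def __init__(
        self, observation_space, action_space, goal_sensor_uuid, hidden_size=512
    ):
        super().__init__(
            BlindPointNavNet(observation_space, hidden_size, goal_sensor_uuid),
            action_space.n,
        )
```
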
diff --git a/habitat_baselines/rl/ppo/ppo.py b/habitat_baselines/rl/ppo/ppo.py
index 78f113f89d7f167b0bcd20e556d8c9ec0e6e9b71..85bfd9e621d62375ee58155068283c9914d95bb4 100644
--- a/habitat_baselines/rl/ppo/ppo.py
+++ b/habitat_baselines/rl/ppo/ppo.py
@@ -24,6 +24,7 @@ class PPO(nn.Module):
         eps=None,
         max_grad_norm=None,
         use_clipped_value_loss=True,
+        use_normalized_advantage=True,
     ):
 
         super().__init__()
@@ -41,15 +42,21 @@ class PPO(nn.Module):
         self.use_clipped_value_loss = use_clipped_value_loss
 
         self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps)
+        self.device = next(actor_critic.parameters()).device
+        self.use_normalized_advantage = use_normalized_advantage
 
     def forward(self, *x):
         raise NotImplementedError
 
-    def update(self, rollouts):
+    def get_advantages(self, rollouts):
         advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
-        advantages = (advantages - advantages.mean()) / (
-            advantages.std() + EPS_PPO
-        )
+        if not self.use_normalized_advantage:
+            return advantages
+
+        return (advantages - advantages.mean()) / (advantages.std() + EPS_PPO)
+
+    def update(self, rollouts):
+        advantages = self.get_advantages(rollouts)
 
         value_loss_epoch = 0
         action_loss_epoch = 0
@@ -65,6 +72,7 @@ class PPO(nn.Module):
                     obs_batch,
                     recurrent_hidden_states_batch,
                     actions_batch,
+                    prev_actions_batch,
                     value_preds_batch,
                     return_batch,
                     masks_batch,
@@ -81,6 +89,7 @@ class PPO(nn.Module):
                 ) = self.actor_critic.evaluate_actions(
                     obs_batch,
                     recurrent_hidden_states_batch,
+                    prev_actions_batch,
                     masks_batch,
                     actions_batch,
                 )
@@ -113,15 +122,19 @@ class PPO(nn.Module):
                     value_loss = 0.5 * (return_batch - values).pow(2).mean()
 
                 self.optimizer.zero_grad()
-                (
+                total_loss = (
                     value_loss * self.value_loss_coef
                     + action_loss
                     - dist_entropy * self.entropy_coef
-                ).backward()
-                nn.utils.clip_grad_norm_(
-                    self.actor_critic.parameters(), self.max_grad_norm
                 )
+
+                self.before_backward(total_loss)
+                total_loss.backward()
+                self.after_backward(total_loss)
+
+                self.before_step()
                 self.optimizer.step()
+                self.after_step()
 
                 value_loss_epoch += value_loss.item()
                 action_loss_epoch += action_loss.item()
@@ -134,3 +147,17 @@ class PPO(nn.Module):
         dist_entropy_epoch /= num_updates
 
         return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
+
+    def before_backward(self, loss):
+        pass
+
+    def after_backward(self, loss):
+        pass
+
+    def before_step(self):
+        nn.utils.clip_grad_norm_(
+            self.actor_critic.parameters(), self.max_grad_norm
+        )
+
+    def after_step(self):
+        pass
diff --git a/habitat_baselines/rl/ppo/ppo_trainer.py b/habitat_baselines/rl/ppo/ppo_trainer.py
index 602b22e82fc7d9a0a52f12174de064245a5817ff..dbfab622ef1af6fddc4e88abcececcfcf9a08010 100644
--- a/habitat_baselines/rl/ppo/ppo_trainer.py
+++ b/habitat_baselines/rl/ppo/ppo_trainer.py
@@ -29,7 +29,7 @@ from habitat_baselines.common.utils import (
     poll_checkpoint_folder,
     update_linear_schedule,
 )
-from habitat_baselines.rl.ppo import PPO, Policy
+from habitat_baselines.rl.ppo import PPO, PointNavBaselinePolicy
 
 
 @baseline_registry.register_trainer(name="ppo")
@@ -60,7 +60,7 @@ class PPOTrainer(BaseRLTrainer):
         """
         logger.add_filehandler(ppo_cfg.log_file)
 
-        self.actor_critic = Policy(
+        self.actor_critic = PointNavBaselinePolicy(
             observation_space=self.envs.observation_spaces[0],
             action_space=self.envs.action_spaces[0],
             hidden_size=512,
@@ -113,6 +113,96 @@ class PPOTrainer(BaseRLTrainer):
         """
         return torch.load(checkpoint_path, map_location=self.device)
 
+    def _collect_rollout_step(
+        self, rollouts, current_episode_reward, episode_rewards, episode_counts
+    ):
+        pth_time = 0.0
+        env_time = 0.0
+
+        t_sample_action = time.time()
+        # sample actions
+        with torch.no_grad():
+            step_observation = {
+                k: v[rollouts.step] for k, v in rollouts.observations.items()
+            }
+
+            (
+                values,
+                actions,
+                actions_log_probs,
+                recurrent_hidden_states,
+            ) = self.actor_critic.act(
+                step_observation,
+                rollouts.recurrent_hidden_states[rollouts.step],
+                rollouts.prev_actions[rollouts.step],
+                rollouts.masks[rollouts.step],
+            )
+
+        pth_time += time.time() - t_sample_action
+
+        t_step_env = time.time()
+
+        outputs = self.envs.step([a[0].item() for a in actions])
+        observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
+
+        env_time += time.time() - t_step_env
+
+        t_update_stats = time.time()
+        batch = batch_obs(observations)
+        rewards = torch.tensor(rewards, dtype=torch.float)
+        rewards = rewards.unsqueeze(1)
+
+        masks = torch.tensor(
+            [[0.0] if done else [1.0] for done in dones], dtype=torch.float
+        )
+
+        current_episode_reward += rewards
+        episode_rewards += (1 - masks) * current_episode_reward
+        episode_counts += 1 - masks
+        current_episode_reward *= masks
+
+        rollouts.insert(
+            batch,
+            recurrent_hidden_states,
+            actions,
+            actions_log_probs,
+            values,
+            rewards,
+            masks,
+        )
+
+        pth_time += time.time() - t_update_stats
+
+        return pth_time, env_time, self.envs.num_envs
+
+    def _update_agent(self, ppo_cfg, rollouts):
+        t_update_model = time.time()
+        with torch.no_grad():
+            last_observation = {
+                k: v[-1] for k, v in rollouts.observations.items()
+            }
+            next_value = self.actor_critic.get_value(
+                last_observation,
+                rollouts.recurrent_hidden_states[-1],
+                rollouts.prev_actions[-1],
+                rollouts.masks[-1],
+            ).detach()
+
+        rollouts.compute_returns(
+            next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau
+        )
+
+        value_loss, action_loss, dist_entropy = self.agent.update(rollouts)
+
+        rollouts.after_update()
+
+        return (
+            time.time() - t_update_model,
+            value_loss,
+            action_loss,
+            dist_entropy,
+        )
+
     def train(self) -> None:
         r"""Main method for training PPO.
 
@@ -184,88 +274,24 @@ class PPOTrainer(BaseRLTrainer):
                     )
 
                 for step in range(ppo_cfg.num_steps):
-                    t_sample_action = time.time()
-                    # sample actions
-                    with torch.no_grad():
-                        step_observation = {
-                            k: v[step]
-                            for k, v in rollouts.observations.items()
-                        }
-
-                        (
-                            values,
-                            actions,
-                            actions_log_probs,
-                            recurrent_hidden_states,
-                        ) = self.actor_critic.act(
-                            step_observation,
-                            rollouts.recurrent_hidden_states[step],
-                            rollouts.masks[step],
-                        )
-                    pth_time += time.time() - t_sample_action
-
-                    t_step_env = time.time()
-
-                    outputs = self.envs.step([a[0].item() for a in actions])
-                    observations, rewards, dones, infos = [
-                        list(x) for x in zip(*outputs)
-                    ]
-
-                    env_time += time.time() - t_step_env
-
-                    t_update_stats = time.time()
-                    batch = batch_obs(observations)
-                    rewards = torch.tensor(rewards, dtype=torch.float)
-                    rewards = rewards.unsqueeze(1)
-
-                    masks = torch.tensor(
-                        [[0.0] if done else [1.0] for done in dones],
-                        dtype=torch.float,
-                    )
-
-                    current_episode_reward += rewards
-                    episode_rewards += (1 - masks) * current_episode_reward
-                    episode_counts += 1 - masks
-                    current_episode_reward *= masks
-
-                    rollouts.insert(
-                        batch,
-                        recurrent_hidden_states,
-                        actions,
-                        actions_log_probs,
-                        values,
-                        rewards,
-                        masks,
+                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
+                        rollouts,
+                        current_episode_reward,
+                        episode_rewards,
+                        episode_counts,
                     )
+                    pth_time += delta_pth_time
+                    env_time += delta_env_time
+                    count_steps += delta_steps
 
-                    count_steps += self.envs.num_envs
-                    pth_time += time.time() - t_update_stats
+                delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
+                    ppo_cfg, rollouts
+                )
+                pth_time += delta_pth_time
 
                 window_episode_reward.append(episode_rewards.clone())
                 window_episode_counts.append(episode_counts.clone())
 
-                t_update_model = time.time()
-                with torch.no_grad():
-                    last_observation = {
-                        k: v[-1] for k, v in rollouts.observations.items()
-                    }
-                    next_value = self.actor_critic.get_value(
-                        last_observation,
-                        rollouts.recurrent_hidden_states[-1],
-                        rollouts.masks[-1],
-                    ).detach()
-
-                rollouts.compute_returns(
-                    next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau
-                )
-
-                value_loss, action_loss, dist_entropy = self.agent.update(
-                    rollouts
-                )
-
-                rollouts.after_update()
-                pth_time += time.time() - t_update_model
-
                 losses = [value_loss, action_loss]
                 stats = zip(
                     ["count", "reward"],
@@ -434,7 +460,13 @@ class PPOTrainer(BaseRLTrainer):
         )
 
         test_recurrent_hidden_states = torch.zeros(
-            ppo_cfg.num_processes, ppo_cfg.hidden_size, device=self.device
+            self.actor_critic.net.num_recurrent_layers,
+            ppo_cfg.num_processes,
+            ppo_cfg.hidden_size,
+            device=self.device,
+        )
+        prev_actions = torch.zeros(
+            ppo_cfg.num_processes, 1, device=self.device, dtype=torch.long
         )
         not_done_masks = torch.zeros(
             ppo_cfg.num_processes, 1, device=self.device
@@ -457,10 +489,13 @@ class PPOTrainer(BaseRLTrainer):
                 _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
                     batch,
                     test_recurrent_hidden_states,
+                    prev_actions,
                     not_done_masks,
                     deterministic=False,
                 )
 
+                prev_actions.copy_(actions)
+
             outputs = self.envs.step([a[0].item() for a in actions])
 
             observations, rewards, dones, infos = [
@@ -533,6 +568,7 @@ class PPOTrainer(BaseRLTrainer):
                 ]
                 not_done_masks = not_done_masks[state_index]
                 current_episode_reward = current_episode_reward[state_index]
+                prev_actions = prev_actions[state_index]
 
                 for k, v in batch.items():
                     batch[k] = v[state_index]
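
A standalone illustration (not repo code) of the per-environment buffers the evaluation loop above maintains after this patch. Every tensor indexed by environment must be filtered with the same `state_index` when finished environments are dropped; note the hidden states are now pruned along dim 1 because of the new leading layers dimension:

```python
import torch

num_envs, hidden_size, num_recurrent_layers = 4, 512, 1

test_recurrent_hidden_states = torch.zeros(
    num_recurrent_layers, num_envs, hidden_size
)
prev_actions = torch.zeros(num_envs, 1, dtype=torch.long)
not_done_masks = torch.zeros(num_envs, 1)

state_index = torch.tensor([0, 2, 3])  # environments that keep running
test_recurrent_hidden_states = test_recurrent_hidden_states[:, state_index]
prev_actions = prev_actions[state_index]
not_done_masks = not_done_masks[state_index]
```
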