diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py
index d9b976765786f934f36494a9db9ff58ee8ca3cc2..6fe565b1eccb17005b8d9d96eda3ef2b14f698bb 100644
--- a/habitat_baselines/common/utils.py
+++ b/habitat_baselines/common/utils.py
@@ -96,8 +96,10 @@ def batch_obs(
             batch[sensor].append(_to_tensor(obs[sensor]))
 
     for sensor in batch:
-        batch[sensor] = torch.stack(batch[sensor], dim=0).to(
-            device=device, dtype=torch.float
+        batch[sensor] = (
+            torch.stack(batch[sensor], dim=0)
+            .to(device=device)
+            .to(dtype=torch.float)
         )
 
     return batch