From 8f4766f67b3eabc74750dd9eb0c25361bc81c340 Mon Sep 17 00:00:00 2001
From: Erik Wijmans <etw@gatech.edu>
Date: Mon, 17 Feb 2020 20:05:10 -0500
Subject: [PATCH] Convert observations to float32 on the GPU (#297)

---
 habitat_baselines/common/utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py
index d9b976765..6fe565b1e 100644
--- a/habitat_baselines/common/utils.py
+++ b/habitat_baselines/common/utils.py
@@ -96,8 +96,10 @@ def batch_obs(
             batch[sensor].append(_to_tensor(obs[sensor]))
 
     for sensor in batch:
-        batch[sensor] = torch.stack(batch[sensor], dim=0).to(
-            device=device, dtype=torch.float
+        batch[sensor] = (
+            torch.stack(batch[sensor], dim=0)
+            .to(device=device)
+            .to(dtype=torch.float)
         )
 
     return batch
-- 
GitLab