From 8f4766f67b3eabc74750dd9eb0c25361bc81c340 Mon Sep 17 00:00:00 2001 From: Erik Wijmans <etw@gatech.edu> Date: Mon, 17 Feb 2020 20:05:10 -0500 Subject: [PATCH] Convert observations to float32 on the GPU (#297) --- habitat_baselines/common/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py index d9b976765..6fe565b1e 100644 --- a/habitat_baselines/common/utils.py +++ b/habitat_baselines/common/utils.py @@ -96,8 +96,10 @@ def batch_obs( batch[sensor].append(_to_tensor(obs[sensor])) for sensor in batch: - batch[sensor] = torch.stack(batch[sensor], dim=0).to( - device=device, dtype=torch.float + batch[sensor] = ( + torch.stack(batch[sensor], dim=0) + .to(device=device) + .to(dtype=torch.float) ) return batch -- GitLab