diff --git a/habitat_baselines/common/utils.py b/habitat_baselines/common/utils.py index d9b976765786f934f36494a9db9ff58ee8ca3cc2..6fe565b1eccb17005b8d9d96eda3ef2b14f698bb 100644 --- a/habitat_baselines/common/utils.py +++ b/habitat_baselines/common/utils.py @@ -96,8 +96,10 @@ def batch_obs( batch[sensor].append(_to_tensor(obs[sensor])) for sensor in batch: - batch[sensor] = torch.stack(batch[sensor], dim=0).to( - device=device, dtype=torch.float + batch[sensor] = ( + torch.stack(batch[sensor], dim=0) + .to(device=device) + .to(dtype=torch.float) ) return batch