Skip to content
Snippets Groups Projects
Unverified Commit 8f4766f6 authored by Erik Wijmans's avatar Erik Wijmans Committed by GitHub
Browse files

Convert observations to float32 on the GPU (#297)

parent f5e29c69
Branches
Tags
No related merge requests found
......@@ -96,8 +96,10 @@ def batch_obs(
batch[sensor].append(_to_tensor(obs[sensor]))
for sensor in batch:
batch[sensor] = torch.stack(batch[sensor], dim=0).to(
device=device, dtype=torch.float
batch[sensor] = (
torch.stack(batch[sensor], dim=0)
.to(device=device)
.to(dtype=torch.float)
)
return batch
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment