From b4ad1cace5019dd2a51353534bac9747aa5071d7 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 17 Jun 2021 23:49:34 +0800
Subject: [PATCH] load weights to cpu instead of gpu

---
 generate_samples.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/generate_samples.py b/generate_samples.py
index c6876a5..b357164 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -56,7 +56,7 @@ def setup_model(args):
             iteration, release, success = get_checkpoint_iteration(args)
             path = os.path.join(args.load, str(iteration), "mp_rank_00_model_states.pt")
             print('current device:', torch.cuda.current_device())
-            checkpoint = torch.load(path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))
+            checkpoint = torch.load(path, map_location=torch.device('cpu'))
             model.load_state_dict(checkpoint["module"])
             print(f"Load model file {path}")
         else:
-- 
GitLab