Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 07:05:29 +0000 (09:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 07:05:29 +0000 (09:05 +0200)
main.py

diff --git a/main.py b/main.py
index 97d37ce..b7050df 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1230,7 +1230,9 @@ if args.resume:
         filename = f"ae_{model.id:03d}.pth"
 
         try:
-            d = torch.load(os.path.join(args.result_dir, filename))
+            d = torch.load(
+                os.path.join(args.result_dir, filename), map_location=main_device
+            )
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
             model.test_accuracy = d["test_accuracy"]
@@ -1378,9 +1380,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         else:
             records.append(
-                generate_ae_c_quizzes(
-                    models, nb_c_quizzes_to_generate, records, gpus[0]
-                )
+                generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0])
             )
 
         time_c_quizzes = int(time.perf_counter() - start_time)