Update.
author	François Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 11:20:45 +0000 (13:20 +0200)
committer	François Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 11:20:45 +0000 (13:20 +0200)
main.py

diff --git a/main.py b/main.py
index fed8abc..c1ef5bc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1080,11 +1080,7 @@ def quiz_validation(models, c_quizzes, local_device):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            result = ae_generate(
-                model,
-                (1 - mask_generate) * c_quizzes,
-                mask_generate,
-            )
+            result = ae_generate(model, (1 - mask_generate) * c_quizzes, mask_generate)
             nb_mistakes = (result != c_quizzes).long().sum(dim=1)
             correct = correct & (nb_mistakes == 0)
             wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
@@ -1201,9 +1197,11 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 
 def save_c_quizzes_with_scores(models, c_quizzes, filename):
     l = []
-    for model in models:
-        model.eval().to(main_device)
-        l.append(model_ae_proba_solutions(model, c_quizzes))
+
+    with torch.autograd.no_grad():
+        for model in models:
+            model = copy.deepcopy(model).to(main_device).eval()
+            l.append(model_ae_proba_solutions(model, c_quizzes))
 
     probas = torch.cat([x[:, None] for x in l], dim=1)
 
@@ -1215,7 +1213,7 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename):
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         filename,
-        quizzes=subset_c_quizzes,
+        quizzes=c_quizzes,
         comments=comments,
         delta=True,
         nrow=8,
@@ -1233,9 +1231,7 @@ if args.resume:
         filename = f"ae_{model.id:03d}.pth"
 
         try:
-            d = torch.load(
-                os.path.join(args.result_dir, filename), map_location=main_device
-            )
+            d = torch.load(os.path.join(args.result_dir, filename), map_location="cpu")
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
             model.test_accuracy = d["test_accuracy"]