Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 4 Sep 2024 07:27:24 +0000 (09:27 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 4 Sep 2024 07:27:24 +0000 (09:27 +0200)
main.py

diff --git a/main.py b/main.py
index 61fc090..02c9fc6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -957,27 +957,30 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_
 
         model.test_accuracy = nb_correct / nb_total
 
-        # for f, record in [("prediction", record_d), ("generation", record_nd)]:
-        # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+        # Save some images
 
-        # result, predicted_parts, correct_parts = bag_to_tensors(record)
+        for f, record in [("prediction", record_d), ("generation", record_nd)]:
+            filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
-        # l = [model_ae_proba_solutions(model, result) for model in other_models]
-        # probas = torch.cat([x[:, None] for x in l], dim=1)
-        # comments = []
+            result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-        # for l in probas:
-        # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+            # l = [model_ae_proba_solutions(model, result) for model in other_models]
+            # probas = torch.cat([x[:, None] for x in l], dim=1)
+            # comments = []
 
-        # quiz_machine.problem.save_quizzes_as_image(
-        # args.result_dir,
-        # filename,
-        # quizzes=result,
-        # predicted_parts=predicted_parts,
-        # correct_parts=correct_parts,
-        # comments=comments,
-        # )
-        # log_string(f"wrote {filename}")
+            # for l in probas:
+            # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                filename,
+                quizzes=result[:128],
+                predicted_parts=predicted_parts[:128],
+                correct_parts=correct_parts[:128],
+                # comments=comments,
+            )
+
+            log_string(f"wrote {filename}")
 
         # Prediction with functional perturbations
 
@@ -1046,7 +1049,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
+    run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device)
 
 
 ######################################################################