Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 11:07:23 +0000 (13:07 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 11:07:23 +0000 (13:07 +0200)
main.py

diff --git a/main.py b/main.py
index 49799e4..1461ab1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -59,7 +59,7 @@ parser.add_argument("--physical_batch_size", type=int, default=None)
 
 parser.add_argument("--inference_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=100000)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
@@ -740,7 +740,7 @@ def quiz_validation(
 
     wrong = torch.cat(record_wrong, dim=1)
 
-    return to_keep, wrong
+    return to_keep, nb_correct, nb_wrong, wrong
 
 
 ######################################################################
@@ -782,7 +782,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             c_quizzes = c_quizzes[to_keep]
 
             if c_quizzes.size(0) > 0:
-                to_keep, record_wrong = quiz_validation(
+                to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
                     models,
                     c_quizzes,
                     local_device,
@@ -840,7 +840,7 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False)
 
     with torch.autograd.no_grad():
         if solvable_only:
-            to_keep, _ = quiz_validation(
+            to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
                 models,
                 c_quizzes,
                 main_device,
@@ -850,16 +850,10 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False)
             )
             c_quizzes = c_quizzes[to_keep]
 
-        for model in models:
-            model = copy.deepcopy(model).to(main_device).eval()
-            l.append(model_ae_proba_solutions(model, c_quizzes))
-
-    probas = torch.cat([x[:, None] for x in l], dim=1)
-
     comments = []
 
-    for l in probas:
-        comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+    for c, w in zip(nb_correct, nb_wrong):
+        comments.append("nb_correct {c} nb_wrong {w}")
 
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,