Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 13:56:39 +0000 (15:56 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 13:56:39 +0000 (15:56 +0200)
main.py

diff --git a/main.py b/main.py
index b83cabd..e01e57a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -725,14 +725,13 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    if mask_hints is None:
-        x_t = (1 - mask_generate) * x_0 + mask_generate * noise
-    else:
-        mask = mask_generate * (1 - mask_hints)
-        x_t = (1 - mask) * x_0 + mask * noise
-
     one_iteration_prediction = deterministic(mask_generate)[:, None]
 
+    if mask_hints is not None:
+        mask_generate = mask_generate * (1 - mask_hints)
+
+    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+
     changed = True
 
     for it in range(nb_iterations_max):
@@ -1058,14 +1057,14 @@ def quiz_validation(
         for q in c_quizzes.split(args.inference_batch_size):
             record.append(
                 quiz_validation(
-                    models,
-                    q,
-                    local_device,
-                    nb_have_to_be_correct,
-                    nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong,
-                    nb_hints=0,
-                    nb_runs=1,
+                    models=models,
+                    c_quizzes=q,
+                    local_device=local_device,
+                    nb_have_to_be_correct=nb_have_to_be_correct,
+                    nb_have_to_be_wrong=nb_have_to_be_wrong,
+                    nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
+                    nb_hints=nb_hints,
+                    nb_runs=nb_runs,
                 )
             )
 
@@ -1087,7 +1086,7 @@ def quiz_validation(
                 quad_mask=quad,
             )
 
-            sub_correct, sub_wrong = True, True
+            sub_correct, sub_wrong = False, True
             for _ in range(nb_runs):
                 if nb_hints == 0:
                     mask_hints = None
@@ -1118,6 +1117,9 @@ def quiz_validation(
         nb_correct += correct.long()
         nb_wrong += wrong.long()
 
+    # log_string(f"{nb_hints=} {nb_correct=}")
+    # log_string(f"{nb_hints=} {nb_wrong=}")
+
     to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
 
     wrong = torch.cat(record_wrong, dim=1)
@@ -1225,16 +1227,12 @@ def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=Fa
                 models,
                 c_quizzes,
                 main_device,
-                nb_have_to_be_correct=1,
+                nb_have_to_be_correct=2,
                 nb_have_to_be_wrong=0,
                 nb_hints=0,
             )
             c_quizzes = c_quizzes[to_keep]
 
-        c_quizzes = c_quizzes[
-            torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
-        ]
-
         for model in models:
             model = copy.deepcopy(model).to(main_device).eval()
             l.append(model_ae_proba_solutions(model, c_quizzes))