Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 9 Sep 2024 06:31:08 +0000 (08:31 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 9 Sep 2024 06:31:08 +0000 (08:31 +0200)
main.py

diff --git a/main.py b/main.py
index 301c4f8..d914113 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1060,11 +1060,11 @@ def save_badness_statistics(
 ######################################################################
 
 
-def quiz_validation_1(models, c_quizzes, local_device):
-    nb_have_to_be_correct = args.nb_models // 2
+def quiz_validation_paris(models, c_quizzes, local_device):
+    nb_have_to_be_correct = 3
     nb_have_to_be_wrong = 1
 
-    nb_runs = 1
+    nb_runs = 3
     nb_mistakes_to_be_wrong = 5
 
     record_wrong = []
@@ -1093,47 +1093,7 @@ def quiz_validation_1(models, c_quizzes, local_device):
     return to_keep, wrong
 
 
-def quiz_validation_2(models, c_quizzes, local_device):
-    nb_have_to_be_correct = 3
-    nb_have_to_be_wrong = 1
-    nb_runs = 3
-
-    record_wrong = []
-    nb_correct, nb_wrong = 0, 0
-
-    for i, model in enumerate(models):
-        assert i == model.id  # a bit of paranoia
-        model = copy.deepcopy(model).to(local_device).eval()
-        log_probas_max, log_probas_min = None, None
-        for _ in range(nb_runs):
-            log_probas = model_ae_proba_solutions(
-                model, c_quizzes, log_probas=True, reduce=False
-            )
-            log_probas_max = (
-                log_probas
-                if log_probas_max is None
-                else log_probas.maximum(log_probas_max)
-            )
-            log_probas_min = (
-                log_probas
-                if log_probas_min is None
-                else log_probas.minimum(log_probas_min)
-            )
-        probas = log_probas.sum(dim=1).exp()
-        correct = (log_probas_min.exp() <= 0.75).long().sum(dim=1) == 0
-        wrong = (log_probas_min.exp() <= 0.1).long().sum(dim=1) >= 3
-        record_wrong.append(wrong[:, None])
-        nb_correct += correct.long()
-        nb_wrong += wrong.long()
-
-    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
-    wrong = torch.cat(record_wrong, dim=1)
-
-    return to_keep, wrong
-
-
-def quiz_validation(models, c_quizzes, local_device):
+def quiz_validation_berne(models, c_quizzes, local_device):
     nb_have_to_be_correct = 3
     nb_have_to_be_wrong = 1
     nb_runs = 3
@@ -1201,7 +1161,9 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             c_quizzes = c_quizzes[to_keep]
 
             if c_quizzes.size(0) > 0:
-                to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
+                to_keep, record_wrong = quiz_validation_berne(
+                    models, c_quizzes, local_device
+                )
                 q = c_quizzes[to_keep]
 
                 if q.size(0) > 0: