From 02960377669d1994c4b69af1e164827664a890d8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 9 Sep 2024 08:31:08 +0200 Subject: [PATCH] Update. --- main.py | 52 +++++++--------------------------------------------- 1 file changed, 7 insertions(+), 45 deletions(-) diff --git a/main.py b/main.py index 301c4f8..d914113 100755 --- a/main.py +++ b/main.py @@ -1060,11 +1060,11 @@ def save_badness_statistics( ###################################################################### -def quiz_validation_1(models, c_quizzes, local_device): - nb_have_to_be_correct = args.nb_models // 2 +def quiz_validation_paris(models, c_quizzes, local_device): + nb_have_to_be_correct = 3 nb_have_to_be_wrong = 1 - nb_runs = 1 + nb_runs = 3 nb_mistakes_to_be_wrong = 5 record_wrong = [] @@ -1093,47 +1093,7 @@ def quiz_validation_1(models, c_quizzes, local_device): return to_keep, wrong -def quiz_validation_2(models, c_quizzes, local_device): - nb_have_to_be_correct = 3 - nb_have_to_be_wrong = 1 - nb_runs = 3 - - record_wrong = [] - nb_correct, nb_wrong = 0, 0 - - for i, model in enumerate(models): - assert i == model.id # a bit of paranoia - model = copy.deepcopy(model).to(local_device).eval() - log_probas_max, log_probas_min = None, None - for _ in range(nb_runs): - log_probas = model_ae_proba_solutions( - model, c_quizzes, log_probas=True, reduce=False - ) - log_probas_max = ( - log_probas - if log_probas_max is None - else log_probas.maximum(log_probas_max) - ) - log_probas_min = ( - log_probas - if log_probas_min is None - else log_probas.minimum(log_probas_min) - ) - probas = log_probas.sum(dim=1).exp() - correct = (log_probas_min.exp() <= 0.75).long().sum(dim=1) == 0 - wrong = (log_probas_min.exp() <= 0.1).long().sum(dim=1) >= 3 - record_wrong.append(wrong[:, None]) - nb_correct += correct.long() - nb_wrong += wrong.long() - - to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong) - - wrong = torch.cat(record_wrong, dim=1) - - return to_keep, wrong - - -def quiz_validation(models, c_quizzes, local_device): +def quiz_validation_berne(models, c_quizzes, local_device): nb_have_to_be_correct = 3 nb_have_to_be_wrong = 1 nb_runs = 3 @@ -1201,7 +1161,9 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: - to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device) + to_keep, record_wrong = quiz_validation_berne( + models, c_quizzes, local_device + ) q = c_quizzes[to_keep] if q.size(0) > 0: -- 2.39.5