From 54d7a0361253824542be95b7b47c88ce3056d8ce Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 11 Sep 2024 09:22:43 +0200 Subject: [PATCH] Update. --- main.py | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index ed83a5c..9e1726a 100755 --- a/main.py +++ b/main.py @@ -117,6 +117,8 @@ parser.add_argument("--prompt_noise", type=float, default=0.05) parser.add_argument("--nb_hints", type=int, default=5) +parser.add_argument("--nb_runs", type=int, default=5) + parser.add_argument("--dirty_debug", action="store_true", default=False) parser.add_argument("--test", type=str, default=None) @@ -1050,7 +1052,6 @@ def quiz_validation( c_quizzes, local_device, nb_have_to_be_correct=3, - nb_have_to_be_not_correct=0, nb_have_to_be_wrong=1, nb_mistakes_to_be_wrong=5, nb_hints=0, @@ -1069,6 +1070,8 @@ def quiz_validation( quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad, ) + + sub_correct, sub_wrong = True, True for _ in range(nb_runs): if nb_hints == 0: mask_hints = None @@ -1089,8 +1092,11 @@ def quiz_validation( ) nb_mistakes = (result != c_quizzes).long().sum(dim=1) - correct = correct & (nb_mistakes == 0) - wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong) + sub_correct = sub_correct | (nb_mistakes == 0) + sub_wrong = sub_wrong & (nb_mistakes >= nb_mistakes_to_be_wrong) + + correct = correct & sub_correct + wrong = wrong | sub_wrong record_wrong.append(wrong[:, None]) nb_correct += correct.long() @@ -1143,7 +1149,11 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): if c_quizzes.size(0) > 0: to_keep, record_wrong = quiz_validation( - models, c_quizzes, local_device, nb_hints=args.nb_hints + models, + c_quizzes, + local_device, + nb_hints=args.nb_hints, + nb_runs=args.nb_runs, ) q = c_quizzes[to_keep] @@ -1193,22 +1203,22 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False): l = [] - if solvable_only: - to_keep, _ = quiz_validation( - models, - c_quizzes, - main_device, - nb_have_to_be_correct=1, - nb_have_to_be_wrong=0, - nb_hints=0, - ) - c_quizzes = c_quizzes[to_keep] + with torch.autograd.no_grad(): + if solvable_only: + to_keep, _ = quiz_validation( + models, + c_quizzes, + main_device, + nb_have_to_be_correct=1, + nb_have_to_be_wrong=0, + nb_hints=0, + ) + c_quizzes = c_quizzes[to_keep] - c_quizzes = c_quizzes[ - torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb] - ] + c_quizzes = c_quizzes[ + torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb] + ] - with torch.autograd.no_grad(): for model in models: model = copy.deepcopy(model).to(main_device).eval() l.append(model_ae_proba_solutions(model, c_quizzes)) -- 2.39.5