From e929b179c58768c1e36d2791a21520bda23e18a8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 11 Sep 2024 15:56:39 +0200 Subject: [PATCH] Update. --- main.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index b83cabd..e01e57a 100755 --- a/main.py +++ b/main.py @@ -725,14 +725,13 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) - if mask_hints is None: - x_t = (1 - mask_generate) * x_0 + mask_generate * noise - else: - mask = mask_generate * (1 - mask_hints) - x_t = (1 - mask) * x_0 + mask * noise - one_iteration_prediction = deterministic(mask_generate)[:, None] + if mask_hints is not None: + mask_generate = mask_generate * (1 - mask_hints) + + x_t = (1 - mask_generate) * x_0 + mask_generate * noise + changed = True for it in range(nb_iterations_max): @@ -1058,14 +1057,14 @@ def quiz_validation( for q in c_quizzes.split(args.inference_batch_size): record.append( quiz_validation( - models, - q, - local_device, - nb_have_to_be_correct, - nb_have_to_be_wrong, - nb_mistakes_to_be_wrong, - nb_hints=0, - nb_runs=1, + models=models, + c_quizzes=q, + local_device=local_device, + nb_have_to_be_correct=nb_have_to_be_correct, + nb_have_to_be_wrong=nb_have_to_be_wrong, + nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong, + nb_hints=nb_hints, + nb_runs=nb_runs, ) ) @@ -1087,7 +1086,7 @@ def quiz_validation( quad_mask=quad, ) - sub_correct, sub_wrong = True, True + sub_correct, sub_wrong = False, True for _ in range(nb_runs): if nb_hints == 0: mask_hints = None @@ -1118,6 +1117,9 @@ def quiz_validation( nb_correct += correct.long() nb_wrong += wrong.long() + # log_string(f"{nb_hints=} {nb_correct=}") + # log_string(f"{nb_hints=} {nb_wrong=}") + to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong) wrong = torch.cat(record_wrong, dim=1) @@ -1225,16 +1227,12 @@ def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=Fa models, c_quizzes, main_device, - nb_have_to_be_correct=1, + nb_have_to_be_correct=2, nb_have_to_be_wrong=0, nb_hints=0, ) c_quizzes = c_quizzes[to_keep] - c_quizzes = c_quizzes[ - torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb] - ] - for model in models: model = copy.deepcopy(model).to(main_device).eval() l.append(model_ae_proba_solutions(model, c_quizzes)) -- 2.39.5