From add68cf0dfba3c50ae90c6e60265028f0d0e5eb8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 1 Sep 2024 09:30:45 +0200 Subject: [PATCH] Update. --- main.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index 4b39b28..58d6287 100755 --- a/main.py +++ b/main.py @@ -97,6 +97,8 @@ parser.add_argument("--nb_models", type=int, default=5) parser.add_argument("--nb_diffusion_iterations", type=int, default=25) +parser.add_argument("--diffusion_noise_proba", type=float, default=0.05) + parser.add_argument("--min_succeed_to_validate", type=int, default=2) parser.add_argument("--max_fail_to_validate", type=int, default=3) @@ -1024,7 +1026,7 @@ def prioritized_rand(low): return y -def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50): +def ae_generate(model, input, mask_generate, nb_iterations_max=50): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -1043,7 +1045,7 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50): r = prioritized_rand(final != input) - mask_erased = mask_generate * (r <= noise_proba).long() + mask_erased = mask_generate * (r <= args.diffusion_noise_proba).long() mask_to_change = d * mask_generate + (1 - d) * mask_erased @@ -1090,7 +1092,7 @@ def model_ae_proba_solutions(model, input): nb_diffusion_iterations = 25 -def degrade_input(input, mask_generate, nb_iterations, noise_proba): +def degrade_input(input, mask_generate, nb_iterations): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -1100,7 +1102,7 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba): result = [] for n in nb_iterations: - proba_erased = 1 - (1 - noise_proba) ** n + proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n mask_erased = mask_generate * (r <= proba_erased[:, None]).long() x = (1 - mask_erased) * input + mask_erased * noise result.append(x) @@ -1121,9 +1123,7 @@ def targets_and_prediction(model, input, mask_generate): N0 = (1 - d) * N0 N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations - targets, input = degrade_input( - input, mask_generate, (0 * N1, N1), noise_proba=noise_proba - ) + targets, input = degrade_input(input, mask_generate, (0 * N1, N1)) input_with_mask = NTC_channel_cat(input, mask_generate) logits = model(input_with_mask) @@ -1168,7 +1168,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): ): targets = input.clone() result = ae_generate( - model, (1 - mask_generate) * input, mask_generate, noise_proba + model, + (1 - mask_generate) * input, + mask_generate, ) correct = (result == targets).min(dim=1).values.long() predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[ @@ -1233,7 +1235,7 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device): # def change_theta(theta_A, theta_B): # theta # result = ae_generate( - # model, (1 - mask_generate) * input, mask_generate, noise_proba + # model, (1 - mask_generate) * input, mask_generate # ) @@ -1282,8 +1284,6 @@ def one_ae_epoch( ###################################################################### -noise_proba = 0.05 - models = [] for i in range(args.nb_models): @@ -1318,6 +1318,11 @@ def c_quiz_criterion_diff(probas): return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5 +def c_quiz_criterion_diff2(probas): + v = probas.sort(dim=1).values + return (v[:, -2] - v[:, 0]) >= 0.5 + + def c_quiz_criterion_two_certains(probas): return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5) @@ -1331,9 +1336,10 @@ def c_quiz_criterion_some(probas): def generate_ae_c_quizzes(models, local_device=main_device): criteria = [ c_quiz_criterion_one_good_one_bad, - # c_quiz_criterion_diff, - # c_quiz_criterion_two_certains, - # c_quiz_criterion_some, + c_quiz_criterion_diff, + # c_quiz_criterion_diff2, + c_quiz_criterion_two_certains, + c_quiz_criterion_some, ] for m in models: @@ -1351,7 +1357,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): duration_max = 4 * 3600 - wanted_nb = 10000 + wanted_nb = 128 # 0000 nb_to_save = 128 with torch.autograd.no_grad(): @@ -1367,7 +1373,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): log_string(f"bag_len {bl}") model = models[torch.randint(len(models), (1,)).item()] - result = ae_generate(model, template, mask_generate, noise_proba) + result = ae_generate(model, template, mask_generate) to_keep = quiz_machine.problem.trivial(result) == False result = result[to_keep] -- 2.39.5