From 0db83482b02a0700dd27db2f30b849afcb188008 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 26 Aug 2024 14:56:55 +0200
Subject: [PATCH] Update.

---
 main.py | 84 ++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 60 insertions(+), 24 deletions(-)

diff --git a/main.py b/main.py
index 3374a5b..b9ec213 100755
--- a/main.py
+++ b/main.py
@@ -888,18 +888,32 @@ def NTC_channel_cat(*x):
     return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
 
 
-def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
+def deterministic(mask_generate):
+    return mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+
+
+def ae_generate(
+    model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
+):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
     input = (1 - mask_generate) * input + mask_generate * noise
 
+    proba_erased = noise_proba
+
+    d = deterministic(mask_generate)[:, None].long()
     changed = True
     for it in range(nb_iterations_max):
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        update = (1 - mask_generate) * input + mask_generate * dist.sample()
+
+        r = torch.rand(mask_generate.size(), device=mask_generate.device)
+        mask_erased = mask_generate * (r <= proba_erased).long()
+        mask_to_change = d * mask_generate + (1 - d) * mask_erased
+
+        update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
         if update.equal(input):
             break
         else:
@@ -909,7 +923,7 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     return input
 
 
-def degrade_input(input, mask_generate, nb_iterations, noise_proba=0.35):
+def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -953,7 +967,10 @@ def test_ae(local_device=main_device):
     model.to(local_device).train()
     optimizer_to(model.optimizer, local_device)
 
-    nb_iterations = 10
+    nb_iterations = 25
+    probs_iterations = torch.arange(nb_iterations, device=main_device)
+    probs_iterations = 0.1 ** (probs_iterations / nb_iterations)
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
 
     for n_epoch in range(args.nb_epochs):
         # ----------------------
@@ -962,6 +979,8 @@ def test_ae(local_device=main_device):
         model.train()
         nb_train_samples, acc_train_loss = 0, 0.0
 
+        noise_proba = 0.05
+
         for input, mask_generate, mask_loss in ae_batches(
             quiz_machine,
             args.nb_train_samples,
@@ -972,19 +991,30 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            deterministic = (
-                mask_generate.sum(dim=1) < mask_generate.size(1) // 4
-            ).long()
-
-            N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+            d = deterministic(mask_generate).long()
+            p = probs_iterations.expand(input.size(0), -1)
+            dist = torch.distributions.categorical.Categorical(probs=p)
+            N0 = dist.sample()
             N1 = N0 + 1
+            N0 = (1 - d) * N0
+            N1 = (1 - d) * N1 + d * nb_iterations
 
-            N0 = (1 - deterministic) * N0
-            N1 = deterministic * nb_iterations + (1 - deterministic) * N1
+            targets, input = degrade_input(
+                input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+            )
 
-            # print(f"{N0.size()=} {N1.size()=} {deterministic.size()=}")
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+            # for n in ["input", "targets"]:
+            #     filename = f"{n}.png"
+            #     quiz_machine.problem.save_quizzes_as_image(
+            #         args.result_dir,
+            #         filename,
+            #         quizzes=locals()[n],
+            #     )
+            #     log_string(f"wrote {filename}")
+            #     time.sleep(1000)
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-            targets, input = degrade_input(input, mask_generate, (N0, N1))
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1015,17 +1045,16 @@ def test_ae(local_device=main_device):
             local_device,
             "test",
         ):
-            deterministic = (
-                mask_generate.sum(dim=1) < mask_generate.size(1) // 4
-            ).long()
-
-            N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+            d = deterministic(mask_generate).long()
+            p = probs_iterations.expand(input.size(0), -1)
+            dist = torch.distributions.categorical.Categorical(probs=p)
+            N0 = dist.sample()
             N1 = N0 + 1
-
-            N0 = (1 - deterministic) * N0
-            N1 = deterministic * nb_iterations + (1 - deterministic) * N1
-
-            targets, input = degrade_input(input, mask_generate, (N0, N1))
+            N0 = (1 - d) * N0
+            N1 = (1 - d) * N1 + d * nb_iterations
+            targets, input = degrade_input(
+                input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+            )
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1045,7 +1074,14 @@ def test_ae(local_device=main_device):
         )
 
         targets = input.clone()
-        input = ae_generate(model, input, mask_generate, n_epoch, nb_iterations)
+        input = ae_generate(
+            model,
+            input,
+            mask_generate,
+            n_epoch,
+            noise_proba=noise_proba,
+            nb_iterations_max=nb_iterations,
+        )
         correct = (input == targets).min(dim=1).values.long()
         predicted_parts = torch.tensor(quad_generate, device=input.device)
         predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
-- 
2.39.5
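
For reference, the iteration-count schedule this patch introduces in test_ae()
can be illustrated standalone as below. This is a sketch only, not part of
main.py: batch_size is an arbitrary value chosen for the example, and the
remaining lines mirror the patched code.

    import torch

    nb_iterations = 25  # value set in test_ae() by this patch

    # probs_iterations[k] is proportional to 0.1 ** (k / nb_iterations),
    # so the largest iteration count is roughly 10x less likely to be
    # drawn than the smallest, biasing training toward shallow degradations.
    probs_iterations = torch.arange(nb_iterations)
    probs_iterations = 0.1 ** (probs_iterations / nb_iterations)
    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()

    batch_size = 4  # arbitrary, for illustration only
    p = probs_iterations.expand(batch_size, -1)
    dist = torch.distributions.categorical.Categorical(probs=p)
    N0 = dist.sample()  # one iteration count per sample, biased toward 0
    N1 = N0 + 1         # one degradation step deeper than N0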