From ac2c5648402fa72e1a54ea610bc687cddaf1e095 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 26 Aug 2024 17:12:16 +0200
Subject: [PATCH] Update.

---
 main.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index 27817e8..3c8b4b9 100755
--- a/main.py
+++ b/main.py
@@ -889,7 +889,7 @@ def NTC_channel_cat(*x):
 
 
 def deterministic(mask_generate):
-    return mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
 def ae_generate(
@@ -902,8 +902,9 @@ def ae_generate(
 
     proba_erased = noise_proba
 
-    d = deterministic(mask_generate)
+    d = deterministic(mask_generate)[:, None]
     changed = True
+
     for it in range(nb_iterations_max):
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
@@ -914,6 +915,7 @@ def ae_generate(
         mask_to_change = d * mask_generate + (1 - d) * mask_erased
 
         update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
+
         if update.equal(input):
             log_string(f"converged at iteration {it}")
             break
@@ -992,7 +994,7 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            d = deterministic = mask_generate
+            d = deterministic(mask_generate)
             p = probs_iterations.expand(input.size(0), -1)
             dist = torch.distributions.categorical.Categorical(probs=p)
             N0 = dist.sample()
@@ -1046,7 +1048,7 @@ def test_ae(local_device=main_device):
             local_device,
             "test",
         ):
-            d = deterministic = mask_generate
+            d = deterministic(mask_generate)
             p = probs_iterations.expand(input.size(0), -1)
             dist = torch.distributions.categorical.Categorical(probs=p)
             N0 = dist.sample()
@@ -1080,9 +1082,9 @@ def test_ae(local_device=main_device):
             input,
             mask_generate,
             n_epoch,
-            nb_iterations,
             noise_proba=noise_proba,
         )
+
         correct = (input == targets).min(dim=1).values.long()
         predicted_parts = torch.tensor(quad_generate, device=input.device)
         predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
-- 
2.39.5
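
Note (not part of the patch): a minimal sketch of why deterministic() now returns .long() and why d is unsqueezed with [:, None] in ae_generate. PyTorch refuses integer arithmetic such as `1 - d` on a bool tensor, and the per-sample flag has to broadcast over the sequence dimension in the blend `mask_to_change = d * mask_generate + (1 - d) * mask_erased`. The toy tensors below are illustrative only, not taken from main.py.

import torch

def deterministic(mask_generate):
    # flag samples whose generated region covers less than half the sequence;
    # cast the boolean comparison to long so it can enter 0/1 arithmetic
    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()

# hypothetical toy masks of shape (N=2, T=4)
mask_generate = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]])
mask_erased = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 0]])

d = deterministic(mask_generate)[:, None]  # shape (2, 1), values 0 or 1, broadcasts over T
mask_to_change = d * mask_generate + (1 - d) * mask_erased
print(mask_to_change)
# row 0 (d=0) keeps mask_erased, row 1 (d=1) keeps mask_generate:
# tensor([[1, 0, 1, 0],
#         [1, 0, 0, 0]])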