From 617f63649d2bb1d6e975e98c74f242cd8a66ac91 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 24 Aug 2024 18:21:31 +0200 Subject: [PATCH] Update. --- main.py | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index 9fe01ab..7bd25cf 100755 --- a/main.py +++ b/main.py @@ -888,18 +888,20 @@ def NTC_channel_cat(*x): return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2) -def ae_generate(model, input, mask_generate, n_epoch, nb_iterations): +def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50): noise = torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) input = (1 - mask_generate) * input + mask_generate * noise - for it in range(nb_iterations): - rho = input.new_full((input.size(0),), nb_iterations - 1 - it) - input_with_mask = NTC_channel_cat(input, mask_generate, rho[:, None]) + for it in range(nb_iterations_max): + input_with_mask = NTC_channel_cat(input, mask_generate) logits = model(input_with_mask) dist = torch.distributions.categorical.Categorical(logits=logits) + pred_input = input.clone() input = (1 - mask_generate) * input + mask_generate * dist.sample() + if (pred_input == input).min(): + break return input @@ -949,10 +951,6 @@ def test_ae(local_device=main_device): nb_iterations = 10 - def phi(rho): - # return (rho / nb_iterations)**2 - return rho / nb_iterations - for n_epoch in range(args.nb_epochs): # ---------------------- # Train @@ -970,13 +968,11 @@ def test_ae(local_device=main_device): if nb_train_samples % args.batch_size == 0: model.optimizer.zero_grad() - rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device) - - targets, input = degrade_input(input, mask_generate, phi(rho), phi(rho + 1)) - - input_with_mask = NTC_channel_cat(input, mask_generate, rho) - output = model(input_with_mask) - loss = NTC_masked_cross_entropy(output, targets, mask_loss) + phi = torch.rand((input.size(0), 1), device=input.device).clamp(min=0.25) + targets, input = degrade_input(input, mask_generate, phi - 0.25, phi) + input_with_mask = NTC_channel_cat(input, mask_generate) + logits = model(input_with_mask) + loss = NTC_masked_cross_entropy(logits, targets, mask_loss) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -1004,17 +1000,13 @@ def test_ae(local_device=main_device): local_device, "test", ): - rho = torch.randint( - nb_iterations, (input.size(0), 1), device=input.device + phi = torch.rand((input.size(0), 1), device=input.device).clamp( + min=0.25 ) - - targets, input = degrade_input( - input, mask_generate, phi(rho), phi(rho + 1) - ) - - input_with_mask = NTC_channel_cat(input, mask_generate, rho) - output = model(input_with_mask) - loss = NTC_masked_cross_entropy(output, targets, mask_loss) + targets, input = degrade_input(input, mask_generate, phi - 0.25, phi) + input_with_mask = NTC_channel_cat(input, mask_generate) + logits = model(input_with_mask) + loss = NTC_masked_cross_entropy(logits, targets, mask_loss) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) -- 2.39.5