Update.
author: François Fleuret <francois@fleuret.org>
Sun, 25 Aug 2024 13:55:54 +0000 (15:55 +0200)
committer: François Fleuret <francois@fleuret.org>
Sun, 25 Aug 2024 13:55:54 +0000 (15:55 +0200)
main.py

diff --git a/main.py b/main.py
index 7bd25cf..ed36efb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -894,19 +894,22 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     )
     input = (1 - mask_generate) * input + mask_generate * noise
 
+    changed = True
     for it in range(nb_iterations_max):
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        pred_input = input.clone()
-        input = (1 - mask_generate) * input + mask_generate * dist.sample()
-        if (pred_input == input).min():
+        update = (1 - mask_generate) * input + mask_generate * dist.sample()
+        if update.equal(input):
             break
+        else:
+            changed = changed & (update != input).max(dim=1).values
+            input[changed] = update[changed]
 
     return input
 
 
-def degrade_input(input, mask_generate, *phis):
+def degrade_input(input, mask_generate, noise_levels):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -915,7 +918,7 @@ def degrade_input(input, mask_generate, *phis):
 
     result = []
 
-    for phi in phis:
+    for phi in noise_levels:
         mask_diffusion_noise = mask_generate * (r <= phi).long()
         x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
         result.append(x)
@@ -968,8 +971,15 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            phi = torch.rand((input.size(0), 1), device=input.device).clamp(min=0.25)
-            targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+            deterministic = (
+                mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+            ).long()
+
+            k = torch.randint(3, (input.size(0), 1), device=input.device)
+            phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
+            phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+
+            targets, input = degrade_input(input, mask_generate, (phi0, phi1))
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1000,10 +1010,17 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                phi = torch.rand((input.size(0), 1), device=input.device).clamp(
-                    min=0.25
-                )
-                targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+                deterministic = (
+                    mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+                ).long()
+
+                k = torch.randint(3, (input.size(0), 1), device=input.device)
+                phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
+                phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+
+                phi = torch.rand((input.size(0), 1), device=input.device)
+                phi = deterministic + (1 - deterministic) * phi
+                targets, input = degrade_input(input, mask_generate, (phi0, phi1))
                 input_with_mask = NTC_channel_cat(input, mask_generate)
                 logits = model(input_with_mask)
                 loss = NTC_masked_cross_entropy(logits, targets, mask_loss)