Update: replace the discrete iteration-index conditioning with a continuous corruption level, and stop ae_generate at a sampling fixed point.
author     François Fleuret <francois@fleuret.org>
           Sat, 24 Aug 2024 16:21:31 +0000 (18:21 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sat, 24 Aug 2024 16:21:31 +0000 (18:21 +0200)
main.py

diff --git a/main.py b/main.py
index 9fe01ab..7bd25cf 100755
--- a/main.py
+++ b/main.py
@@ -888,18 +888,20 @@ def NTC_channel_cat(*x):
     return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
 
 
-def ae_generate(model, input, mask_generate, n_epoch, nb_iterations):
+def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
     input = (1 - mask_generate) * input + mask_generate * noise
 
-    for it in range(nb_iterations):
-        rho = input.new_full((input.size(0),), nb_iterations - 1 - it)
-        input_with_mask = NTC_channel_cat(input, mask_generate, rho[:, None])
+    for it in range(nb_iterations_max):
+        input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
+        pred_input = input.clone()
         input = (1 - mask_generate) * input + mask_generate * dist.sample()
+        if (pred_input == input).min():
+            break
 
     return input
 
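The hunk above removes the countdown channel rho from the generation loop: instead of running exactly nb_iterations conditioned steps, the model now resamples the masked positions until a full pass leaves the sample unchanged, capped at nb_iterations_max passes. A minimal self-contained sketch of this fixed-point sampling scheme (the model signature and the nb_colors default are assumptions here; in the repository the number of colors comes from quiz_machine.problem):

import torch

def NTC_channel_cat(*x):
    # Stack per-token channels along a trailing dimension: (N, T) -> (N, T, C).
    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)

def ae_generate_sketch(model, input, mask_generate, nb_iterations_max=50, nb_colors=10):
    # Initialize the positions to generate with uniform categorical noise.
    noise = torch.randint(nb_colors, input.size(), device=input.device)
    input = (1 - mask_generate) * input + mask_generate * noise

    for _ in range(nb_iterations_max):
        input_with_mask = NTC_channel_cat(input, mask_generate)
        logits = model(input_with_mask)  # (N, T, nb_colors)
        dist = torch.distributions.categorical.Categorical(logits=logits)
        previous = input.clone()
        # Resample only the masked positions, keep the conditioning frozen.
        input = (1 - mask_generate) * input + mask_generate * dist.sample()
        # Stop as soon as an entire pass changes nothing.
        if (previous == input).all():
            break

    return input

Note that the commit tests convergence with (pred_input == input).min(), which on a boolean tensor is equivalent to the more idiomatic .all() used in the sketch.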
@@ -949,10 +951,6 @@ def test_ae(local_device=main_device):
 
     nb_iterations = 10
 
-    def phi(rho):
-        # return (rho / nb_iterations)**2
-        return rho / nb_iterations
-
     for n_epoch in range(args.nb_epochs):
         # ----------------------
         # Train
@@ -970,13 +968,11 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
-
-            targets, input = degrade_input(input, mask_generate, phi(rho), phi(rho + 1))
-
-            input_with_mask = NTC_channel_cat(input, mask_generate, rho)
-            output = model(input_with_mask)
-            loss = NTC_masked_cross_entropy(output, targets, mask_loss)
+            phi = torch.rand((input.size(0), 1), device=input.device).clamp(min=0.25)
+            targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+            input_with_mask = NTC_channel_cat(input, mask_generate)
+            logits = model(input_with_mask)
+            loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
             acc_train_loss += loss.item() * input.size(0)
             nb_train_samples += input.size(0)
 
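In the training loop, the discrete schedule phi(rho) = rho / nb_iterations is replaced by a continuous corruption level: phi is drawn uniformly in [0, 1), clamped to at least 0.25, and the model is trained to take the sample from level phi down to level phi - 0.25, i.e. to undo a fixed-size denoising step regardless of any global iteration count. degrade_input is defined elsewhere in main.py; the following is a hypothetical reconstruction consistent with how it is called here, where drawing the noise once for both levels makes the more-corrupted tensor a strict further degradation of the less-corrupted one:

def degrade_input(input, mask_generate, level_targets, level_input, nb_colors=10):
    # Hypothetical sketch, not the repository's definition. Each maskable
    # token is replaced by uniform noise with a per-sample probability equal
    # to the corruption level; levels are (N, 1) and broadcast over (N, T).
    noise = torch.randint(nb_colors, input.size(), device=input.device)
    r = torch.rand(input.size(), device=input.device)
    mask_targets = (r < level_targets).long() * mask_generate
    mask_input = (r < level_input).long() * mask_generate
    targets = (1 - mask_targets) * input + mask_targets * noise
    degraded = (1 - mask_input) * input + mask_input * noise
    return targets, degraded

Because every training pair now encodes the same 0.25-wide denoising step, the network no longer needs rho as an input channel, which is why NTC_channel_cat is now called with only the sample and the mask.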
@@ -1004,17 +1000,13 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                rho = torch.randint(
-                    nb_iterations, (input.size(0), 1), device=input.device
+                phi = torch.rand((input.size(0), 1), device=input.device).clamp(
+                    min=0.25
                 )
-
-                targets, input = degrade_input(
-                    input, mask_generate, phi(rho), phi(rho + 1)
-                )
-
-                input_with_mask = NTC_channel_cat(input, mask_generate, rho)
-                output = model(input_with_mask)
-                loss = NTC_masked_cross_entropy(output, targets, mask_loss)
+                targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+                input_with_mask = NTC_channel_cat(input, mask_generate)
+                logits = model(input_with_mask)
+                loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
                 acc_test_loss += loss.item() * input.size(0)
                 nb_test_samples += input.size(0)
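The evaluation loop mirrors the training change exactly, so train and test losses stay comparable across the commit. NTC_masked_cross_entropy is also defined elsewhere in main.py; a plausible sketch of a per-token cross-entropy restricted to the positions flagged by mask_loss (an assumption for illustration, not the repository's code):

import torch.nn.functional as F

def NTC_masked_cross_entropy(logits, targets, mask):
    # logits: (N, T, C), targets: (N, T) integer classes. cross_entropy
    # expects the class dimension second, hence the transpose.
    loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
    # Average over the masked positions only.
    return (loss * mask).sum() / mask.sum().clamp(min=1)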