Update.
author: François Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 12:56:55 +0000 (14:56 +0200)
committer: François Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 12:56:55 +0000 (14:56 +0200)
main.py

diff --git a/main.py b/main.py
index 3374a5b..b9ec213 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -888,18 +888,32 @@ def NTC_channel_cat(*x):
     return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
 
 
-def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
+def deterministic(mask_generate):
+    return mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+
+
+def ae_generate(
+    model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
+):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
     input = (1 - mask_generate) * input + mask_generate * noise
 
+    proba_erased = noise_proba
+
+    d = deterministic(mask_generate)
     changed = True
     for it in range(nb_iterations_max):
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        update = (1 - mask_generate) * input + mask_generate * dist.sample()
+
+        r = torch.rand(mask_generate.size(), device=mask_generate.device)
+        mask_erased = mask_generate * (r <= proba_erased).long()
+        mask_to_change = d * mask_generate + (1 - d) * mask_erased
+
+        update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
         if update.equal(input):
             break
         else:
@@ -909,7 +923,7 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     return input
 
 
-def degrade_input(input, mask_generate, nb_iterations, noise_proba=0.35):
+def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -953,7 +967,10 @@ def test_ae(local_device=main_device):
     model.to(local_device).train()
     optimizer_to(model.optimizer, local_device)
 
-    nb_iterations = 10
+    nb_iterations = 25
+    probs_iterations = torch.arange(nb_iterations, device=main_device)
+    probs_iterations = 0.1 ** (probs_iterations / nb_iterations)
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
 
     for n_epoch in range(args.nb_epochs):
         # ----------------------
@@ -962,6 +979,8 @@ def test_ae(local_device=main_device):
         model.train()
         nb_train_samples, acc_train_loss = 0, 0.0
 
+        noise_proba = 0.05
+
         for input, mask_generate, mask_loss in ae_batches(
             quiz_machine,
             args.nb_train_samples,
@@ -972,19 +991,30 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            deterministic = (
-                mask_generate.sum(dim=1) < mask_generate.size(1) // 4
-            ).long()
-
-            N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+            d = deterministic = mask_generate
+            p = probs_iterations.expand(input.size(0), -1)
+            dist = torch.distributions.categorical.Categorical(probs=p)
+            N0 = dist.sample()
             N1 = N0 + 1
+            N0 = (1 - d) * N0
+            N1 = (1 - d) * N1 + d * nb_iterations
 
-            N0 = (1 - deterministic) * N0
-            N1 = deterministic * nb_iterations + (1 - deterministic) * N1
+            targets, input = degrade_input(
+                input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+            )
 
-            # print(f"{N0.size()=} {N1.size()=} {deterministic.size()=}")
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+            # for n in ["input", "targets"]:
+            # filename = f"{n}.png"
+            # quiz_machine.problem.save_quizzes_as_image(
+            # args.result_dir,
+            # filename,
+            # quizzes=locals()[n],
+            # )
+            # log_string(f"wrote {filename}")
+            # time.sleep(1000)
+            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-            targets, input = degrade_input(input, mask_generate, (N0, N1))
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1015,17 +1045,16 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                deterministic = (
-                    mask_generate.sum(dim=1) < mask_generate.size(1) // 4
-                ).long()
-
-                N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+                d = deterministic = mask_generate
+                p = probs_iterations.expand(input.size(0), -1)
+                dist = torch.distributions.categorical.Categorical(probs=p)
+                N0 = dist.sample()
                 N1 = N0 + 1
-
-                N0 = (1 - deterministic) * N0
-                N1 = deterministic * nb_iterations + (1 - deterministic) * N1
-
-                targets, input = degrade_input(input, mask_generate, (N0, N1))
+                N0 = (1 - d) * N0
+                N1 = (1 - d) * N1 + d * nb_iterations
+                targets, input = degrade_input(
+                    input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+                )
                 input_with_mask = NTC_channel_cat(input, mask_generate)
                 logits = model(input_with_mask)
                 loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1045,7 +1074,14 @@ def test_ae(local_device=main_device):
                 )
 
                 targets = input.clone()
-                input = ae_generate(model, input, mask_generate, n_epoch, nb_iterations)
+                input = ae_generate(
+                    model,
+                    input,
+                    mask_generate,
+                    n_epoch,
+                    nb_iterations,
+                    noise_proba=noise_proba,
+                )
                 correct = (input == targets).min(dim=1).values.long()
                 predicted_parts = torch.tensor(quad_generate, device=input.device)
                 predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)