Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 15:12:16 +0000 (17:12 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 15:12:16 +0000 (17:12 +0200)
main.py

diff --git a/main.py b/main.py
index 27817e8..3c8b4b9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -889,7 +889,7 @@ def NTC_channel_cat(*x):
 
 
 def deterministic(mask_generate):
-    return mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
 def ae_generate(
@@ -902,8 +902,9 @@ def ae_generate(
 
     proba_erased = noise_proba
 
-    d = deterministic(mask_generate)
+    d = deterministic(mask_generate)[:, None]
     changed = True
+
     for it in range(nb_iterations_max):
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
@@ -914,6 +915,7 @@ def ae_generate(
         mask_to_change = d * mask_generate + (1 - d) * mask_erased
 
         update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
+
         if update.equal(input):
             log_string(f"converged at iteration {it}")
             break
@@ -992,7 +994,7 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            d = deterministic = mask_generate
+            d = deterministic(mask_generate)
             p = probs_iterations.expand(input.size(0), -1)
             dist = torch.distributions.categorical.Categorical(probs=p)
             N0 = dist.sample()
@@ -1046,7 +1048,7 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                d = deterministic = mask_generate
+                d = deterministic(mask_generate)
                 p = probs_iterations.expand(input.size(0), -1)
                 dist = torch.distributions.categorical.Categorical(probs=p)
                 N0 = dist.sample()
@@ -1080,9 +1082,9 @@ def test_ae(local_device=main_device):
                     input,
                     mask_generate,
                     n_epoch,
-                    nb_iterations,
                     noise_proba=noise_proba,
                 )
+
                 correct = (input == targets).min(dim=1).values.long()
                 predicted_parts = torch.tensor(quad_generate, device=input.device)
                 predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)