Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 07:30:45 +0000 (09:30 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 07:30:45 +0000 (09:30 +0200)
main.py

diff --git a/main.py b/main.py
index 4b39b28..58d6287 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -97,6 +97,8 @@ parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
 
+parser.add_argument("--diffusion_noise_proba", type=float, default=0.05)
+
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=3)
@@ -1024,7 +1026,7 @@ def prioritized_rand(low):
     return y
 
 
-def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
+def ae_generate(model, input, mask_generate, nb_iterations_max=50):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -1043,7 +1045,7 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
 
         r = prioritized_rand(final != input)
 
-        mask_erased = mask_generate * (r <= noise_proba).long()
+        mask_erased = mask_generate * (r <= args.diffusion_noise_proba).long()
 
         mask_to_change = d * mask_generate + (1 - d) * mask_erased
 
@@ -1090,7 +1092,7 @@ def model_ae_proba_solutions(model, input):
 nb_diffusion_iterations = 25
 
 
-def degrade_input(input, mask_generate, nb_iterations, noise_proba):
+def degrade_input(input, mask_generate, nb_iterations):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -1100,7 +1102,7 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     result = []
 
     for n in nb_iterations:
-        proba_erased = 1 - (1 - noise_proba) ** n
+        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
         mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
         x = (1 - mask_erased) * input + mask_erased * noise
         result.append(x)
@@ -1121,9 +1123,7 @@ def targets_and_prediction(model, input, mask_generate):
     N0 = (1 - d) * N0
     N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
 
-    targets, input = degrade_input(
-        input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
-    )
+    targets, input = degrade_input(input, mask_generate, (0 * N1, N1))
 
     input_with_mask = NTC_channel_cat(input, mask_generate)
     logits = model(input_with_mask)
@@ -1168,7 +1168,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
         ):
             targets = input.clone()
             result = ae_generate(
-                model, (1 - mask_generate) * input, mask_generate, noise_proba
+                model,
+                (1 - mask_generate) * input,
+                mask_generate,
             )
             correct = (result == targets).min(dim=1).values.long()
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
@@ -1233,7 +1235,7 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
         # def change_theta(theta_A, theta_B):
         # theta
         # result = ae_generate(
-        # model, (1 - mask_generate) * input, mask_generate, noise_proba
+        # model, (1 - mask_generate) * input, mask_generate
         # )
 
 
@@ -1282,8 +1284,6 @@ def one_ae_epoch(
 
 ######################################################################
 
-noise_proba = 0.05
-
 models = []
 
 for i in range(args.nb_models):
@@ -1318,6 +1318,11 @@ def c_quiz_criterion_diff(probas):
     return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
 
 
+def c_quiz_criterion_diff2(probas):
+    v = probas.sort(dim=1).values
+    return (v[:, -2] - v[:, 0]) >= 0.5
+
+
 def c_quiz_criterion_two_certains(probas):
     return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
 
@@ -1331,9 +1336,10 @@ def c_quiz_criterion_some(probas):
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         c_quiz_criterion_one_good_one_bad,
-        # c_quiz_criterion_diff,
-        # c_quiz_criterion_two_certains,
-        # c_quiz_criterion_some,
+        c_quiz_criterion_diff,
+        # c_quiz_criterion_diff2,
+        c_quiz_criterion_two_certains,
+        c_quiz_criterion_some,
     ]
 
     for m in models:
@@ -1351,7 +1357,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    wanted_nb = 10000
+    wanted_nb = 128  # 0000
     nb_to_save = 128
 
     with torch.autograd.no_grad():
@@ -1367,7 +1373,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
             log_string(f"bag_len {bl}")
 
             model = models[torch.randint(len(models), (1,)).item()]
-            result = ae_generate(model, template, mask_generate, noise_proba)
+            result = ae_generate(model, template, mask_generate)
 
             to_keep = quiz_machine.problem.trivial(result) == False
             result = result[to_keep]