Update.
author     François Fleuret <francois@fleuret.org>
           Wed, 25 Sep 2024 20:21:56 +0000 (22:21 +0200)
committer  François Fleuret <francois@fleuret.org>
           Wed, 25 Sep 2024 20:21:56 +0000 (22:21 +0200)
main.py

diff --git a/main.py b/main.py
index 230453f..7af281c 100755
--- a/main.py
+++ b/main.py
@@ -99,7 +99,7 @@ parser.add_argument("--proba_plasticity", type=float, default=0.0)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
-parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
+parser.add_argument("--proba_input_noise", type=float, default=0.05)
 
 parser.add_argument("--proba_hints", type=float, default=0.25)
 
@@ -319,10 +319,10 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 ######################################################################
 
 
-def add_hints_imt(imt_set):
-    """Set every component of the mask to zero with probability
-    args.proba_hints, and for each component set to zero, copy the
-    corresponding value from the target into the input
+def add_hints_imt(imt_set, proba_hints):
+    """Set every component of the mask to zero with probability
+    proba_hints, and for each component set to zero, copy the
+    corresponding value from the target into the input
 
     """
     input, masks, targets = imt_set.unbind(dim=1)
@@ -330,7 +330,7 @@ def add_hints_imt(imt_set):
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
     # mask_hints = (h < t).long()
     mask_hints = (
-        torch.rand(input.size(), device=input.device) < args.proba_hints
+        torch.rand(input.size(), device=input.device) < proba_hints
     ).long() * masks
 
     masks = (1 - mask_hints) * masks
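Below, a minimal standalone sketch of the hint operation described in the docstring above, run on toy tensors; the shapes, the vocabulary size, and the blanked-input construction are illustrative assumptions, not code taken from main.py.

import torch

# Toy IMT triple: 4 quizzes of 12 tokens over a 10-symbol vocabulary.
targets = torch.randint(10, (4, 12))
masks = torch.randint(2, (4, 12))   # 1 = position the model must predict
input = (1 - masks) * targets       # masked positions are blanked to 0
proba_hints = 0.25

# Reveal each masked position with probability proba_hints: clear its mask
# bit and copy the target value into the input.
mask_hints = (torch.rand(input.size()) < proba_hints).long() * masks
masks = (1 - mask_hints) * masks
input = (1 - mask_hints) * input + mask_hints * targets

# Every position that is no longer masked now carries the target value.
assert ((masks == 1) | (input == targets)).all()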
@@ -338,13 +338,15 @@ def add_hints_imt(imt_set):
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
-def add_noise_imt(imt_set):
-    """Replace every component of the input by a random value with
-    probability args.proba_prompt_noise."""
+def add_input_noise_imt(imt_set, proba_input_noise):
+    """Replace every component of the non-masked input by a random
+    value with probability proba_input_noise.
+
+    """
     input, masks, targets = imt_set.unbind(dim=1)
     noise = problem.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
-        torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+        torch.rand(input.size(), device=input.device) < proba_input_noise
     ).long()
     input = (1 - change) * input + change * noise
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
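And a similar toy sketch of the input-noise operation, with torch.randint standing in for problem.pure_noise; shapes, vocabulary, and probability are illustrative assumptions.

import torch

# Toy IMT triple as above.
targets = torch.randint(10, (4, 12))
masks = torch.randint(2, (4, 12))
input = (1 - masks) * targets
proba_input_noise = 0.05

# Stand-in for problem.pure_noise(...): random symbols of the input's shape.
noise = torch.randint(10, input.size())

# Corrupt each *non-masked* input position with probability proba_input_noise;
# masked positions are left untouched.
change = (1 - masks) * (torch.rand(input.size()) < proba_input_noise).long()
input = (1 - change) * input + change * noise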
@@ -393,7 +395,7 @@ def ae_predict(model, imt_set, local_device=main_device):
 
 
 def predict_the_four_grids(
-    model, input, with_noise=False, with_hints=False, local_device=main_device
+    model, input, proba_input_noise, proba_hints, local_device=main_device
 ):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
@@ -404,11 +406,11 @@ def predict_the_four_grids(
     input = (1 - masks) * targets
     imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
-    if with_hints:
-        imt_set = add_hints_imt(imt_set)
+    if proba_hints > 0:
+        imt_set = add_hints_imt(imt_set, proba_hints)
 
-    if with_noise:
-        imt_set = add_noise_imt(imt_set)
+    if proba_input_noise > 0:
+        imt_set = add_input_noise_imt(imt_set, proba_input_noise)
 
     result = ae_predict(model, imt_set, local_device=local_device)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
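A toy sketch of the expand / masked-sum pattern used in predict_the_four_grids: every quiz is replicated four times, each copy predicts one grid, and the four disjoint masked copies are summed back into a single sequence per quiz. The grid layout, the quad-mask construction, and the echo "prediction" are assumptions for illustration only.

import torch

# Toy batch: 3 quizzes, each made of 4 grids of 5 tokens (20 tokens per quiz).
B, G, L = 3, 4, 5
quizzes = torch.randint(10, (B, G * L))

# Replicate every quiz 4 times, one copy per grid to predict.
x = quizzes[:, None, :].expand(-1, 4, -1).reshape(-1, quizzes.size(1))  # (4*B, G*L)

# One disjoint mask per copy: copy k masks exactly grid k.
grid_masks = torch.zeros(4, G * L, dtype=torch.long)
for k in range(4):
    grid_masks[k, k * L : (k + 1) * L] = 1
masks = grid_masks.repeat(B, 1)  # row b*4+k gets the mask of grid k

# Stand-in "prediction" that simply echoes the targets.
predicted = x.clone()

# Keep each copy's predicted grid and sum the disjoint pieces back together.
result = (predicted * masks).reshape(-1, 4, predicted.size(1)).sum(dim=1)  # (B, G*L)
assert (result == quizzes).all()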
@@ -512,9 +514,9 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
     # complexity, and hints in half to allow dealing with hints when
     # validating c quizzes
     b_p = samples_for_prediction_imt(q_p)
-    b_p = add_noise_imt(b_p)
+    b_p = add_input_noise_imt(b_p, args.proba_input_noise)
     half = torch.rand(b_p.size(0)) < 0.5
-    b_p[half] = add_hints_imt(b_p[half])
+    b_p[half] = add_hints_imt(b_p[half], args.proba_hints)
 
     # The other half are denoising examples for the generation
     b_g = samples_for_generation_imt(q_g)
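For reference, a toy illustration of the boolean-indexed in-place update used above to apply the hint augmentation to a random half of the batch; flip() is a stand-in for add_hints_imt and the shapes are illustrative.

import torch

# Toy batch of 6 IMT triples (batch, 3 channels, 12 tokens).
b_p = torch.randint(10, (6, 3, 12))
before = b_p.clone()

# Select a random half of the rows and overwrite only those rows in place.
half = torch.rand(b_p.size(0)) < 0.5
b_p[half] = b_p[half].flip(dims=[2])  # stand-in transform, not add_hints_imt

# Rows outside the selected half are unchanged.
assert (b_p[~half] == before[~half]).all()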
@@ -661,8 +663,8 @@ def evaluate_quizzes(quizzes, models, with_hints, local_device):
         predicted = predict_the_four_grids(
             model=model,
             input=quizzes,
-            with_noise=False,
-            with_hints=with_hints,
+            proba_input_noise=0.0,
+            proba_hints=args.proba_hints if with_hints else 0.0,
             local_device=local_device,
         )
         nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
@@ -748,7 +750,6 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
 
     duration = time.perf_counter() - start_time
 
-    log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
     log_string(
         f"validation_rate {nb_validated} / {nb_generated} ({nb_validated*100/nb_generated:.02e}%)"
     )
@@ -1039,7 +1040,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
         )
 
-        log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
+        log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
 
         train_c_quizzes = (
             new_c_quizzes