Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 06:50:05 +0000 (08:50 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 06:50:05 +0000 (08:50 +0200)
main.py

diff --git a/main.py b/main.py
index f609fd8..d1a1c8f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -774,6 +774,75 @@ def deterministic(mask_generate):
     return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
+######################################################################
+
+#
+# Given x_0 and t_0, t_1, ..., returns x_{t_0}, x_{t_1}, with
+#
+#    x_{t_k} ~ P(X_{t_k} | X_0=x_0)
+#
+
+
+def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x0.size(), device=x0.device)
+
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)
+
+    result = []
+
+    for n in steps_nb_iterations:
+        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
+        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * x0 + mask_erased * noise
+        result.append(x)
+
+    return result
+
+
+######################################################################
+
+# Given x_t and a mas
+
+
+def targets_and_logits(model, input, mask_generate, prompt_noise=0.0):
+    d = deterministic(mask_generate)
+
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.nb_diffusion_iterations, device=input.device
+    )
+
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+
+    # N0 = dist.sample()
+    # N1 = N0 + 1
+    # N0 = (1 - d) * N0
+    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
+
+    N0 = input.new_zeros(input.size(0))
+    N1 = dist.sample() + 1
+
+    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+
+    if prompt_noise > 0:
+        mask_prompt_noise = (
+            torch.rand(input.size(), device=input.device) <= prompt_noise
+        ).long()
+        noise = torch.randint(
+            quiz_machine.problem.nb_colors, input.size(), device=input.device
+        )
+        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
+        input = (1 - mask_generate) * noisy_input + mask_generate * input
+
+    input_with_mask = NTC_channel_cat(input, mask_generate)
+    logits = model(input_with_mask)
+
+    return targets, logits
+
+
+######################################################################
+
 # This function returns a 2d tensor of same shape as low, full of
 # uniform random values in [0,1], such that, in every row, the values
 # corresponding to the True in low are all lesser than the values
@@ -840,7 +909,7 @@ def model_ae_proba_solutions(model, input, log_proba=False):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_prediction(
+            targets, logits = targets_and_logits(
                 model, q, mask_generate, prompt_noise=args.prompt_noise
             )
             loss_per_token = F.cross_entropy(
@@ -866,7 +935,7 @@ def model_ae_argmax_nb_disagreements(model, input):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_prediction(
+            targets, logits = targets_and_logits(
                 model, q, mask_generate, prompt_noise=args.prompt_noise
             )
 
@@ -893,7 +962,7 @@ def model_ae_argmax_predictions(model, input):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_prediction(
+            targets, logits = targets_and_logits(
                 model, q, mask_generate, prompt_noise=args.prompt_noise
             )
 
@@ -907,64 +976,6 @@ def model_ae_argmax_predictions(model, input):
 ######################################################################
 
 
-def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
-    noise = torch.randint(
-        quiz_machine.problem.nb_colors, input.size(), device=input.device
-    )
-
-    r = torch.rand(mask_generate.size(), device=mask_generate.device)
-
-    result = []
-
-    for n in steps_nb_iterations:
-        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
-        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
-        x = (1 - mask_erased) * input + mask_erased * noise
-        result.append(x)
-
-    return result
-
-
-def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
-    d = deterministic(mask_generate)
-
-    probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=input.device
-    )
-
-    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(input.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-    # N0 = dist.sample()
-    # N1 = N0 + 1
-    # N0 = (1 - d) * N0
-    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
-
-    N0 = input.new_zeros(input.size(0))
-    N1 = dist.sample() + 1
-
-    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
-
-    if prompt_noise > 0:
-        mask_prompt_noise = (
-            torch.rand(input.size(), device=input.device) <= prompt_noise
-        ).long()
-        noise = torch.randint(
-            quiz_machine.problem.nb_colors, input.size(), device=input.device
-        )
-        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
-        input = (1 - mask_generate) * noisy_input + mask_generate * input
-
-    input_with_mask = NTC_channel_cat(input, mask_generate)
-    logits = model(input_with_mask)
-
-    return targets, logits
-
-
-######################################################################
-
-
 def run_ae_test(
     model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
 ):
@@ -988,7 +999,7 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            targets, logits = targets_and_prediction(model, input, mask_generate)
+            targets, logits = targets_and_logits(model, input, mask_generate)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
@@ -1032,8 +1043,7 @@ def run_ae_test(
             f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
         )
 
-        if prefix is None:
-            model.test_accuracy = nb_correct / nb_total
+        model.test_accuracy = nb_correct / nb_total
 
         # Save some images
 
@@ -1110,7 +1120,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
             model, input, mask_generate, prompt_noise=args.prompt_noise
         )