Update.
author: François Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 17:51:36 +0000 (19:51 +0200)
committer: François Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 17:51:36 +0000 (19:51 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 51e0fa2..9525bdd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -327,10 +327,6 @@ quiz_machine = quiz_machine.QuizMachine(
 )
 
 
-def mu_T_sampler(shape, device="cpu"):
-    return torch.randint(quiz_machine.problem.nb_colors, shape, device=device)
-
-
 diffuser = diffusion.Diffuser(
     mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption
 )
@@ -397,22 +393,27 @@ def masked_cross_entropy(output, targets, masks):
 
 ######################################################################
 
+
+def add_hints(masks, fraction_with_hints):
+    if fraction_with_hints > 0:
+        h = torch.rand(masks.size(), device=masks.device) * masks
+        mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
+        v = torch.rand(masks.size(0), device=masks.device)[:, None]
+        mask_hints = mask_hints * (v < fraction_with_hints).long()
+        return (1 - mask_hints) * masks
+    else:
+        return masks
+
+
 # IMT for input / masks / target
 
 
-def IMT_batch_prediction(input, proba_hints=0.0):
+def batch_prediction_imt(input, fraction_with_hints=0.0):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
     masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
-
-    if proba_hints > 0:
-        h = torch.rand(input.size(), device=input.device) * masks
-        mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
-        v = torch.rand(nb, device=input.device)[:, None]
-        mask_hints = mask_hints * (v < proba_hints).long()
-        masks = (1 - mask_hints) * masks
-
+    masks = add_hints(masks, fraction_with_hints)
     # noise = quiz_machine.problem.pure_noise(nb, input.device)
     targets = input
     input = (1 - masks) * targets  # + masks * noise
@@ -444,10 +445,32 @@ def predict(model, imt_set, local_device=main_device):
     return torch.cat(record)
 
 
+def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device):
+    boy_that_s_ugly = input.view(input.size(0), 4, -1)[:, :, 0].clone()
+    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+    nb = input.size(0)
+    masks = input.new_zeros(input.size())
+    u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+    masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+    masks_with_hints = add_hints(masks, fraction_with_hints)
+    targets = input
+    input = (1 - masks_with_hints) * targets
+    imt_set = torch.cat(
+        [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1
+    )
+
+    result = predict(model, imt_set, local_device=local_device)
+    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+    result.view(result.size(0), 4, -1)[:, :, 0] = boy_that_s_ugly
+
+    return result
+
+
 ######################################################################
 
 
-def IMT_batch_generation(input):
+def batch_generation_imt(input):
     nb = input.size(0)
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.diffusion_nb_iterations, device=input.device
@@ -516,16 +539,6 @@ def generate(model, nb, local_device=main_device):
 
 
 def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
-    if train:
-        label = "train"
-        model.train().to(local_device)
-        optimizer_to(model.optimizer, local_device)
-    else:
-        label = "test"
-        model.eval().to(local_device)
-
-    nb_samples, acc_loss = 0, 0.0
-
     quizzes = quiz_machine.quiz_set(
         args.nb_train_samples if train else args.nb_test_samples,
         c_quizzes,
@@ -535,11 +548,24 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     q1, q2 = quizzes.to(local_device).chunk(2)
 
     imt_set = torch.cat(
-        [IMT_batch_prediction(q1, proba_hints=0.5), IMT_batch_generation(q2)]
+        [
+            batch_prediction_imt(q1, fraction_with_hints=0.5),
+            batch_generation_imt(q2),
+        ]
     )
 
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
 
+    if train:
+        label = "train"
+        model.train().to(local_device)
+        optimizer_to(model.optimizer, local_device)
+    else:
+        label = "test"
+        model.eval().to(local_device)
+
+    nb_samples, acc_loss = 0, 0.0
+
     for imt in tqdm.tqdm(
         imt_set.split(args.physical_batch_size),
         dynamic_ncols=True,
@@ -574,10 +600,23 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
     one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
 
+    #!!!!!!!!!!!!!!!!!!!!!!!!!
+    quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(
+        local_device
+    )
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
+    )
+    result = predict_full(model, quizzes, local_device=local_device)
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
+    )
+    #!!!!!!!!!!!!!!!!!!!!!!!!!
+
     # predict
 
     quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
-    imt_set = IMT_batch_prediction(quizzes.to(local_device))
+    imt_set = batch_prediction_imt(quizzes.to(local_device))
     result = predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
 
@@ -638,7 +677,7 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def quiz_validation(
+def quiz_validation_(
     models,
     c_quizzes,
     local_device,
diff --git a/quiz_machine.py b/quiz_machine.py
index 781c1cf..dfedbf5 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -222,6 +222,8 @@ class QuizMachine:
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
         quizzes = quizzes[i].contiguous()
 
+        quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].contiguous()
+
         return quizzes
 
     ######################################################################