Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 18:18:24 +0000 (20:18 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 18:18:24 +0000 (20:18 +0200)
grids.py
main.py
quiz_machine.py

index 490750b..9424496 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -422,6 +422,19 @@ class Grids(problem.Problem):
     ):
         quizzes = quizzes.to("cpu")
 
+        if quizzes.size(1) == 4 * self.height * self.width:
+            quizzes = torch.cat(
+                [
+                    quizzes.new_zeros(quizzes.size(0), 4, 1),
+                    quizzes.reshape(quizzes.size(0), 4, -1),
+                ],
+                dim=2,
+            )
+            quizzes[:, :, 0] = torch.tensor(
+                [self.token_A, self.token_f_A, self.token_B, self.token_f_B]
+            )[None, :]
+            quizzes = quizzes.reshape(quizzes.size(0), -1)
+
         to_reconfigure = [quizzes]
         if predicted_parts is not None:
             to_reconfigure.append(predicted_parts)
diff --git a/main.py b/main.py
index 9525bdd..6cbb2c4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -326,11 +326,6 @@ quiz_machine = quiz_machine.QuizMachine(
     device=main_device,
 )
 
-
-diffuser = diffusion.Diffuser(
-    mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption
-)
-
 ######################################################################
 
 log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
@@ -412,11 +407,10 @@ def batch_prediction_imt(input, fraction_with_hints=0.0):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
-    masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+    masks.view(nb, 4, -1)[...] = u[:, :, None]
     masks = add_hints(masks, fraction_with_hints)
-    # noise = quiz_machine.problem.pure_noise(nb, input.device)
     targets = input
-    input = (1 - masks) * targets  # + masks * noise
+    input = (1 - masks) * targets
 
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
@@ -446,12 +440,11 @@ def predict(model, imt_set, local_device=main_device):
 
 
 def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device):
-    boy_that_s_ugly = input.view(input.size(0), 4, -1)[:, :, 0].clone()
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
-    masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+    masks.view(nb, 4, -1)[...] = u[:, :, None]
     masks_with_hints = add_hints(masks, fraction_with_hints)
     targets = input
     input = (1 - masks_with_hints) * targets
@@ -462,8 +455,6 @@ def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device
     result = predict(model, imt_set, local_device=local_device)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
 
-    result.view(result.size(0), 4, -1)[:, :, 0] = boy_that_s_ugly
-
     return result
 
 
@@ -483,12 +474,10 @@ def batch_generation_imt(input):
     proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
     mask_erased = (r <= proba_erased[:, None]).long()
 
-    noise = quiz_machine.problem.pure_noise(nb, input.device)
-
+    noise = quiz_machine.pure_noise(nb, input.device)
     targets = input
     input = (1 - mask_erased) * input + mask_erased * noise
     masks = input.new_full(input.size(), 1)
-    masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0
 
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
@@ -503,9 +492,8 @@ def prioritized_rand(low):
 
 
 def generate(model, nb, local_device=main_device):
-    all_input = quiz_machine.problem.pure_noise(nb, local_device)
+    all_input = quiz_machine.pure_noise(nb, local_device)
     all_masks = all_input.new_full(all_input.size(), 1)
-    all_masks.reshape(all_masks.size(0), 4, -1)[:, :, 0] = 0
 
     for input, masks in tqdm.tqdm(
         zip(
index dfedbf5..594b5ca 100755 (executable)
@@ -195,6 +195,11 @@ class QuizMachine:
 
     ######################################################################
 
+    def pure_noise(self, nb, device):
+        r = self.problem.pure_noise(nb, device)
+        r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
+        return r
+
     def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1):
         if c_quizzes is None:
             quizzes = self.problem.generate_w_quizzes(nb_samples)
@@ -222,7 +227,9 @@ class QuizMachine:
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
         quizzes = quizzes[i].contiguous()
 
-        quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].contiguous()
+        quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+            quizzes.size(0), -1
+        )
 
         return quizzes