Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 18:15:48 +0000 (20:15 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 18:15:48 +0000 (20:15 +0200)
grids.py
main.py

index 5e623cb..fb31c7d 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -134,17 +134,20 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 
 
 class Grids(problem.Problem):
-    grid_gray = 64
-    thickness = 1
-    background_gray = 255
+    # grid_gray = 64
+    # thickness = 1
+    # background_gray = 255
+    # dots = False
 
     # grid_gray=240
     # thickness=1
     # background_gray=240
+    # dots = False
 
-    # grid_gray = 255
-    # thickness = 0
-    # background_gray = 240
+    grid_gray = 200
+    thickness = 0
+    background_gray = 240
+    dots = True
 
     named_colors = [
         ("white", [background_gray, background_gray, background_gray]),
@@ -161,76 +164,10 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
-    def check_order(self, quizzes, quad_order):
-        S = self.height * self.width
-
-        return (
-            (quizzes[:, 0 * (S + 1)] == self.l2tok[quad_order[0]])
-            & (quizzes[:, 1 * (S + 1)] == self.l2tok[quad_order[1]])
-            & (quizzes[:, 2 * (S + 1)] == self.l2tok[quad_order[2]])
-            & (quizzes[:, 3 * (S + 1)] == self.l2tok[quad_order[3]])
-        ).all()
-
-    def get_order(self, quizzes):
-        S = self.height * self.width
-        quad_order = tuple(
-            self.tok2l[n.item()]
-            for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
-        )
-        self.check_order(quizzes, quad_order)
-        return quad_order
-
-    def inject_noise(self, quizzes, noise, quad_order, quad_noise):
-        assert self.check_order(quizzes, quad_order=quad_order)
-        S = self.height * self.width
-
-        mask = torch.tensor(quad_noise, device=quizzes.device)
-        mask = mask[None, :, None].expand(1, 4, S + 1).clone()
-        mask[:, :, 0] = 0
-        mask = mask.reshape(1, -1).expand_as(quizzes)
-        mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
-        random = torch.randint(self.nb_colors, mask.size())
-        quizzes = mask * random + (1 - mask) * quizzes
-
-        return quizzes
-
     def pure_noise(self, nb, device):
         result = torch.randint(
-            self.nb_colors, (nb, 4 * (self.height * self.height + 1)), device=device
+            self.nb_colors, (nb, 4 * (self.height * self.height)), device=device
         )
-        result.view(nb, 4, -1)[:, 0, 0] = self.token_A
-        result.view(nb, 4, -1)[:, 1, 0] = self.token_f_A
-        result.view(nb, 4, -1)[:, 2, 0] = self.token_B
-        result.view(nb, 4, -1)[:, 3, 0] = self.token_f_B
-        return result
-
-    # What a mess
-    def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
-        if torch.is_tensor(quizzes):
-            return self.reconfigure([quizzes], quad_order=quad_order)[0]
-
-        S = self.height * self.width
-        result = [x.new(x.size()) for x in quizzes]
-
-        quad_order_from = self.get_order(quizzes[0][:1])
-        i = self.indices_select(quizzes[0], quad_order_from)
-
-        sf = dict((l, n) for n, l in enumerate(quad_order_from))
-
-        for q in range(4):
-            k = sf[quad_order[q]]
-            for x, y in zip(quizzes, result):
-                l = x.size(1) // 4
-                y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
-
-        j = i == False
-
-        if j.any():
-            for z, y in zip(
-                self.reconfigure([x[j] for x in quizzes], quad_order=quad_order), result
-            ):
-                y[j] = z
-
         return result
 
     def trivial(self, quizzes):
@@ -241,22 +178,6 @@ class Grids(problem.Problem):
             dim=1
         ).values
 
-    def make_quiz_mask(
-        self, quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=(0, 0, 0, 1)
-    ):
-        assert self.check_order(quizzes, quad_order)
-
-        ar_mask = quizzes.new_zeros(quizzes.size())
-
-        S = self.height * self.width
-        a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
-        a[:, 0, :] = quad_mask[0]
-        a[:, 1, :] = quad_mask[1]
-        a[:, 2, :] = quad_mask[2]
-        a[:, 3, :] = quad_mask[3]
-
-        return ar_mask
-
     def text2quiz(self, t):
         chr2col = [
             (".", "white"),
@@ -322,32 +243,13 @@ class Grids(problem.Problem):
         self.colors = torch.tensor([c for _, c in self.named_colors])
 
         self.nb_colors = len(self.colors)
-        self.token_A = self.nb_colors
-        self.token_f_A = self.token_A + 1
-        self.token_B = self.token_f_A + 1
-        self.token_f_B = self.token_B + 1
 
         self.nb_rec_max = 5
         self.rfree = torch.tensor([])
 
-        self.l2tok = {
-            "A": self.token_A,
-            "f_A": self.token_f_A,
-            "B": self.token_B,
-            "f_B": self.token_f_B,
-        }
-
-        self.tok2l = {
-            self.token_A: "A",
-            self.token_f_A: "f_A",
-            self.token_B: "B",
-            self.token_f_B: "f_B",
-        }
-
         self.height = 10
         self.width = 10
-        self.seq_len = 4 * (1 + self.height * self.width)
-        self.nb_token_values = self.token_f_B + 1
+        self.seq_len = 4 * self.height * self.width
 
         self.cache_rec_coo = {}
 
@@ -385,7 +287,7 @@ class Grids(problem.Problem):
     ######################################################################
 
     def vocabulary_size(self):
-        return self.nb_token_values
+        return self.nb_colors
 
     def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
@@ -398,14 +300,41 @@ class Grids(problem.Problem):
             for t in range(self.thickness):
                 y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
                 y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
+        if self.dots:
+            z = y.reshape(
+                y.size(0),
+                y.size(1),
+                y.size(2) // scale,
+                scale,
+                y.size(3) // scale,
+                scale,
+            )
+            z = z[
+                :,
+                :,
+                :,
+                scale // 2 - 2 : scale // 2 + 1,
+                :,
+                scale // 2 - 2 : scale // 2 + 1,
+            ]
+            z[...] = (z == self.background_gray) * self.grid_gray + (
+                z != self.background_gray
+            ) * z
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
-                        for k in range(3, scale - 2):
-                            y[n, :, i * scale + k, j * scale + k] = 0
-                            y[n, :, i * scale + k, j * scale + scale - k] = 0
+                    if x[n, i, j] >= self.nb_colors:
+                        # for k in range(3, scale - 2):
+                        c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
+                        # y[n, :, i * scale + k, j * scale + k] = c
+                        # y[n, :, i * scale + k, j * scale + scale - k] = c
+                        y[
+                            n,
+                            :,
+                            i * scale + 3 : i * scale + scale - 2,
+                            j * scale + 3 : j * scale + scale - 2,
+                        ] = c
 
         y = y[:, :, 1:, 1:]
 
@@ -443,37 +372,10 @@ class Grids(problem.Problem):
     ):
         quizzes = quizzes.to("cpu")
 
-        if quizzes.size(1) == 4 * self.height * self.width:
-            quizzes = torch.cat(
-                [
-                    quizzes.new_zeros(quizzes.size(0), 4, 1),
-                    quizzes.reshape(quizzes.size(0), 4, -1),
-                ],
-                dim=2,
-            )
-            quizzes[:, :, 0] = torch.tensor(
-                [self.token_A, self.token_f_A, self.token_B, self.token_f_B]
-            )[None, :]
-            quizzes = quizzes.reshape(quizzes.size(0), -1)
-
-        to_reconfigure = [quizzes]
-        if predicted_parts is not None:
-            to_reconfigure.append(predicted_parts)
-        if correct_parts is not None:
-            to_reconfigure.append(correct_parts)
-
-        to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
-
-        quizzes = to_reconfigure.pop(0)
-        if predicted_parts is not None:
-            predicted_parts = to_reconfigure.pop(0)
-        if correct_parts is not None:
-            correct_parts = to_reconfigure.pop(0)
-
         S = self.height * self.width
 
         A, f_A, B, f_B = (
-            quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+            quizzes.reshape(quizzes.size(0), 4, S)
             .reshape(quizzes.size(0), 4, self.height, self.width)
             .permute(1, 0, 2, 3)
         )
@@ -1881,7 +1783,7 @@ class Grids(problem.Problem):
         if tasks is None:
             tasks = self.all_tasks
 
-        quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
+        quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
 
         if progress_bar:
             quizzes = tqdm.tqdm(
@@ -1892,7 +1794,7 @@ class Grids(problem.Problem):
             )
 
         for quiz in quizzes:
-            q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
+            q = quiz.reshape(4, self.height, self.width)
             q[...] = 0
             A, f_A, B, f_B = q
             task = tasks[torch.randint(len(tasks), (1,)).item()]
@@ -1932,6 +1834,9 @@ if __name__ == "__main__":
     ]:
         print(t.__name__)
         w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+
+        w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
+
         grids.save_quizzes_as_image(
             "/tmp",
             t.__name__ + ".png",
diff --git a/main.py b/main.py
index 85f2cb6..5493b7d 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -250,18 +250,9 @@ assert args.nb_test_samples % args.batch_size == 0
 ######################################################################
 
 
-def pure_noise(nb, device):
-    r = problem.pure_noise(nb, device)
-    r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
-    return r
-
-
 def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
     if c_quizzes is None:
         quizzes = problem.generate_w_quizzes(nb_samples)
-        quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
-            quizzes.size(0), -1
-        )
         nb_w_quizzes = quizzes.size(0)
         nb_c_quizzes = 0
     else:
@@ -340,7 +331,7 @@ def add_noise_imt(imt_set):
     """Replace every component of the input by a random value with
     probability args.proba_prompt_noise."""
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
-    noise = pure_noise(input.size(0), input.device)
+    noise = problem.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
         torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
     ).long()
@@ -428,7 +419,7 @@ def samples_for_generation_imt(input):
     proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
     mask_erased = (r <= proba_erased[:, None]).long()
 
-    noise = pure_noise(nb, input.device)
+    noise = problem.pure_noise(nb, input.device)
     targets = input
     input = (1 - mask_erased) * input + mask_erased * noise
     masks = input.new_full(input.size(), 1)
@@ -452,7 +443,7 @@ def ae_generate(model, nb, local_device=main_device):
     # mini-batches second so that we keep only the samples that have
     # not stabilized
 
-    all_input = pure_noise(nb, local_device)
+    all_input = problem.pure_noise(nb, local_device)
     all_masks = all_input.new_full(all_input.size(), 1)
     all_changed = torch.full((all_input.size(0),), True, device=all_input.device)