Update.
author    François Fleuret <francois@fleuret.org>
          Thu, 26 Sep 2024 20:46:12 +0000 (22:46 +0200)
committer François Fleuret <francois@fleuret.org>
          Thu, 26 Sep 2024 20:46:12 +0000 (22:46 +0200)
main.py
world.py [new file with mode: 0755]

diff --git a/main.py b/main.py
index 7af281c..d699bc6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -509,16 +509,18 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
 
     q_p, q_g = quizzes.to(local_device).chunk(2)
 
-    # Half of the samples train the prediction. We inject noise in all
-    # to avoid drift of the culture toward "finding waldo" type of
-    # complexity, and hints in half to allow dealing with hints when
-    # validating c quizzes
+    # Half of the samples are used to train the prediction.
     b_p = samples_for_prediction_imt(q_p)
+    # We inject noise into all of them to avoid a drift of the culture
+    # toward "finding Waldo"-like complexity.
     b_p = add_input_noise_imt(b_p, args.proba_input_noise)
-    half = torch.rand(b_p.size(0)) < 0.5
+    # And we add hints to half of them so that the models learn to
+    # deal with overly complex quizzes.
+    half = torch.rand(b_p.size(0), device=b_p.device) < 0.5
     b_p[half] = add_hints_imt(b_p[half], args.proba_hints)
 
-    # The other half are denoising examples for the generation
+    # The other half are denoising examples to train the generative
+    # process.
     b_g = samples_for_generation_imt(q_g)
 
     imt_set = torch.cat([b_p, b_g])
diff --git a/world.py b/world.py
new file mode 100755 (executable)
index 0000000..3ab6944
--- /dev/null
+++ b/world.py
@@ -0,0 +1,823 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm, os, warnings, cairo, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+def text_img(height, width, text):
+    pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
+
+    surface = cairo.ImageSurface.create_for_data(
+        pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
+    )
+
+    ctx = cairo.Context(surface)
+    ctx.set_source_rgb(0, 0, 0)
+    ctx.set_font_size(16)
+    ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+    y = None
+    for line in text.split("\n"):
+        # the extent height of each line drives the baseline spacing
+        _, _, _, line_height, _, _ = ctx.text_extents(line)
+        if y is None:
+            y = line_height * 1.5
+            x = line_height * 0.5
+
+        ctx.move_to(x, y)
+        ctx.show_text(line)
+        y += line_height * 1.5
+
+    ctx.stroke()
+
+    return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
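+
+# Example (hypothetical sizes): text_img(48, 256, "hello") returns a
+# uint8 tensor of shape (1, 3, 48, 256).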
+
+
+######################################################################
+
+import problem
+
+
+class Grids(problem.Problem):
+    grid_gray = 240
+    thickness = 1
+    background_gray = 240
+    dots = False
+
+    named_colors = [
+        ("white", [background_gray, background_gray, background_gray]),
+        # ("white", [224, 224, 224]),
+        ("red", [255, 0, 0]),
+        ("green", [0, 160, 0]),
+        ("blue", [0, 0, 255]),
+        ("yellow", [255, 224, 0]),
+        ("cyan", [0, 255, 255]),
+        ("violet", [224, 128, 255]),
+        ("lightgreen", [160, 255, 160]),
+        ("brown", [165, 42, 42]),
+        ("lightblue", [192, 192, 255]),
+        ("gray", [128, 128, 128]),
+    ]
+
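+    # Uniform random tokens over the color vocabulary, presumably used
+    # as the starting point of the denoising generation.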
+    def pure_noise(self, nb, device):
+        result = torch.randint(
+            self.nb_colors, (nb, 4 * (self.height * self.width)), device=device
+        )
+        return result
+
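+    # A quiz is "trivial" when at least one of its two transformations
+    # is the identity, i.e. A == f_A or B == f_B.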
+    def trivial(self, quizzes):
+        S = self.height * self.width
+        assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
+        a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+        return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
+            dim=1
+        ).values
+
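+    # Parses a human-readable quiz: one character per cell, four grids
+    # drawn side by side, quizzes separated by blank lines, and lines
+    # starting with "#" ignored. Note that the reshape hard-codes
+    # 10 x 10 grids.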
+    def text2quiz(self, t):
+        chr2col = [
+            (".", "white"),
+            ("r", "red"),
+            ("g", "green"),
+            ("b", "blue"),
+            ("y", "yellow"),
+            ("c", "cyan"),
+            ("v", "violet"),
+            ("l", "lightgreen"),
+            ("o", "brown"),
+            ("l", "lightblue"),
+            ("a", "gray"),
+        ]
+
+        col2tok = dict([(c[0], n) for n, c in enumerate(self.named_colors)])
+        chr2tok = dict([(c, col2tok[col]) for c, col in chr2col])
+
+        t = re.sub(r"#.*\n", "", t).strip()
+        l = t.replace("\n\n", ";").split(";")
+
+        result = []
+
+        for t in l:
+            t = "".join(t.replace("\n", " ").strip().split(" "))
+            t = torch.tensor([chr2tok[c] for c in t])
+            t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1)
+            t = torch.cat(
+                [
+                    torch.tensor(
+                        [
+                            [self.token_A],
+                            [self.token_f_A],
+                            [self.token_B],
+                            [self.token_f_B],
+                        ]
+                    ),
+                    t,
+                ],
+                dim=1,
+            )
+            result.append(t.flatten()[None, :])
+
+        return torch.cat(result, dim=0)
+
+    def __init__(
+        self,
+        max_nb_cached_chunks=None,
+        chunk_size=None,
+        nb_threads=-1,
+        tasks=None,
+    ):
+        self.colors = torch.tensor([c for _, c in self.named_colors])
+
+        self.nb_colors = len(self.colors)
+
+        self.nb_rec_max = 5
+        self.rfree = torch.tensor([])
+
+        self.height = 12
+        self.width = 20
+        self.seq_len = 4 * self.height * self.width
+
+        self.cache_rec_coo = {}
+
+        all_tasks = [
+            self.task_replace_color,
+            self.task_translate,
+            self.task_grow,
+            self.task_frame,
+        ]
+
+        if tasks is None:
+            self.all_tasks = all_tasks
+        else:
+            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+
+        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
+
+    ######################################################################
+
+    def vocabulary_size(self):
+        return self.nb_colors
+
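+    # Renders a batch of grids as RGB images: every cell becomes a
+    # scale x scale block, grid lines are drawn in grid_gray, and
+    # out-of-vocabulary tokens appear as inset squares of their color.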
+    def grid2img(self, x, scale=15, grids=True):
+        m = torch.logical_and(x >= 0, x < self.nb_colors).long()
+        y = self.colors[x * m].permute(0, 3, 1, 2)
+        s = y.shape
+        y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+        if grids:
+            for t in range(self.thickness):
+                y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
+                y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
+        if self.dots:
+            z = y.reshape(
+                y.size(0),
+                y.size(1),
+                y.size(2) // scale,
+                scale,
+                y.size(3) // scale,
+                scale,
+            )
+            z = z[
+                :,
+                :,
+                :,
+                scale // 2 - 1 : scale // 2 + 2,
+                :,
+                scale // 2 - 1 : scale // 2 + 2,
+            ]
+            zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+            z[...] = zz * self.grid_gray + (~zz) * z
+
+        for n in range(m.size(0)):
+            for i in range(m.size(1)):
+                for j in range(m.size(2)):
+                    if x[n, i, j] >= self.nb_colors:
+                        # for k in range(3, scale - 2):
+                        c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
+                        # y[n, :, i * scale + k, j * scale + k] = c
+                        # y[n, :, i * scale + k, j * scale + scale - k] = c
+                        y[
+                            n,
+                            :,
+                            i * scale + 3 : i * scale + scale - 2,
+                            j * scale + 3 : j * scale + scale - 2,
+                        ] = c
+
+        y = y[:, :, 1:, 1:]
+
+        return y
+
+    def add_frame(self, img, colors, thickness):
+        if thickness > 0:
+            result = img.new(
+                img.size(0),
+                img.size(1),
+                img.size(2) + 2 * thickness,
+                img.size(3) + 2 * thickness,
+            )
+
+            result[...] = colors[:, :, None, None]
+            result[:, :, thickness:-thickness, thickness:-thickness] = img
+        else:
+            result = img
+
+        return result
+
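+    # Lays out each quiz as its four grids A, f_A, B, f_B side by side
+    # (plus optional difference maps), with frame colors encoding the
+    # predicted / correct parts, and saves the batch as a single PNG.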
+    def save_quizzes_as_image(
+        self,
+        result_dir,
+        filename,
+        quizzes,
+        predicted_parts=None,
+        correct_parts=None,
+        comments=None,
+        comment_height=48,
+        nrow=4,
+        grids=True,
+        margin=12,
+        delta=False,
+        delta_highlight=False,
+    ):
+        quizzes = quizzes.to("cpu")
+
+        S = self.height * self.width
+
+        A, f_A, B, f_B = (
+            quizzes.reshape(quizzes.size(0), 4, S)
+            .reshape(quizzes.size(0), 4, self.height, self.width)
+            .permute(1, 0, 2, 3)
+        )
+
+        frame, white, gray, green, red = torch.tensor(
+            [
+                [self.grid_gray, self.grid_gray, self.grid_gray],
+                [255, 255, 255],
+                [200, 200, 200],
+                [0, 255, 0],
+                [255, 0, 0],
+            ],
+            device=quizzes.device,
+        )
+
+        thickness = self.thickness
+
+        if delta:
+            u = (A != f_A).long()
+            img_delta_A = self.add_frame(
+                self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+            )
+            img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
+                img_delta_A
+            )
+            u = (B != f_B).long()
+            img_delta_B = self.add_frame(
+                self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+            )
+            img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
+                img_delta_B
+            )
+
+        img_A = self.add_frame(
+            self.grid2img(A, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_f_A = self.add_frame(
+            self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_B = self.add_frame(
+            self.grid2img(B, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_f_B = self.add_frame(
+            self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
+        )
+
+        if delta_highlight:
+            q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+            img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
+
+        # predicted_parts Nx4
+        # correct_parts Nx4
+
+        if predicted_parts is None:
+            colors = white[None, None, :].expand(-1, 4, -1)
+        else:
+            predicted_parts = predicted_parts.to("cpu")
+            if correct_parts is None:
+                colors = (
+                    predicted_parts[:, :, None] * gray[None, None, :]
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+                )
+            else:
+                correct_parts = correct_parts.to("cpu")
+                colors = (
+                    predicted_parts[:, :, None]
+                    * (
+                        (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+                        + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+                        + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
+                    )
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+                )
+
+        separation = 6
+
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
+
+        img_A = self.add_frame(img_A, white[None, :], thickness=2)
+        img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+        img_B = self.add_frame(img_B, white[None, :], thickness=2)
+        img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
+
+        if delta:
+            img_delta_A = self.add_frame(
+                img_delta_A, colors[:, 0], thickness=separation
+            )
+            img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
+            img_delta_B = self.add_frame(
+                img_delta_B, colors[:, 0], thickness=separation
+            )
+            img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
+            img = torch.cat(
+                [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
+            )
+        else:
+            img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
+
+        if comments is not None:
+            comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+            comment_img = torch.cat(comment_img, dim=0)
+            img = torch.cat([img, comment_img], dim=2)
+
+        image_name = os.path.join(result_dir, filename)
+
+        torchvision.utils.save_image(
+            img.float() / 255.0,
+            image_name,
+            nrow=nrow,
+            padding=margin * 4,
+            pad_value=1.0,
+        )
+
+    ######################################################################
+
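+    # Samples nb_rec rectangles by batched rejection: many candidate
+    # corner pairs are drawn at once, those that are too small or that
+    # exceed surface_max are discarded, and the surviving groups are
+    # cached so that later calls with the same signature just pop one.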
+    # @torch.compile
+    def rec_coo(
+        self,
+        nb_rec,
+        min_height=3,
+        min_width=3,
+        surface_max=None,
+    ):
+        if surface_max is None:
+            surface_max = self.height * self.width // 4
+
+        signature = (nb_rec, min_height, min_width, surface_max)
+
+        try:
+            return self.cache_rec_coo[signature].pop()
+        except IndexError:
+            pass
+        except KeyError:
+            pass
+
+        N = 10000
+        while True:
+            while True:
+                i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
+                j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
+                i[:, 1] += 1
+                j[:, 1] += 1
+                big_enough = (
+                    (i[:, 1] >= i[:, 0] + min_height)
+                    & (j[:, 1] >= j[:, 0] + min_width)
+                    & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
+                )
+
+                i, j = i[big_enough], j[big_enough]
+
+                n = i.size(0) - i.size(0) % nb_rec
+
+                if n > 0:
+                    break
+
+            i = i[:n].reshape(n // nb_rec, nb_rec, -1)
+            j = j[:n].reshape(n // nb_rec, nb_rec, -1)
+
+            if i.size(0) > 1:
+                break
+
+        self.cache_rec_coo[signature] = [
+            [
+                (
+                    i[n, k, 0].item(),
+                    j[n, k, 0].item(),
+                    i[n, k, 1].item(),
+                    j[n, k, 1].item(),
+                )
+                for k in range(nb_rec)
+            ]
+            for n in range(i.size(0))
+        ]
+
+        return self.cache_rec_coo[signature].pop()
+
+    ######################################################################
+
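+    # The first rectangle's color is replaced by a fourth, otherwise
+    # unused color in f_X; the other rectangles are unchanged.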
+    def task_replace_color(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+
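+    # The last rectangle is translated by (di, dj) in f_X; layouts are
+    # re-sampled until the translated rectangle fits inside the grid.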
+    def task_translate(self, A, f_A, B, f_B):
+        while True:
+            di, dj = torch.randint(3, (2,)) - 1
+            if di.abs() + dj.abs() > 0:
+                break
+
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
+                if (
+                    i1 + di >= 0
+                    and i2 + di < X.size(0)
+                    and j1 + dj >= 0
+                    and j2 + dj < X.size(1)
+                ):
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                if n == nb_rec - 1:
+                    f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
+
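+    # The last rectangle grows or shrinks by one cell on each side
+    # between X and f_X, depending on direction.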
+    def task_grow(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+        direction = torch.randint(2, (1,)).item()
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
+                if i1 + 3 < i2 and j1 + 3 < j2:
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                if n == nb_rec - 1:
+                    if direction == 0:
+                        X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+                        f_X[i1:i2, j1:j2] = c[n]
+                    else:
+                        X[i1:i2, j1:j2] = c[n]
+                        f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+                else:
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
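+    # The last rectangle is filled in X but reduced to its one-cell
+    # thick outline in f_X.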
+    # @torch.compile
+    def task_frame(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                if n == nb_rec - 1:
+                    f_X[i1:i2, j1] = c[n]
+                    f_X[i1:i2, j2 - 1] = c[n]
+                    f_X[i1, j1:j2] = c[n]
+                    f_X[i2 - 1, j1:j2] = c[n]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
+
+    ######################################################################
+
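+    # Quizzes in the token-marker format: four (S + 1)-token segments,
+    # each starting with the marker of the corresponding grid.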
+    def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
+        S = self.height * self.width
+        quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+        quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+        quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+        quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+        quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
+
+        return quizzes
+
+    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
+        S = self.height * self.width
+
+        if tasks is None:
+            tasks = self.all_tasks
+
+        quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
+
+        iterator = quizzes
+        if progress_bar:
+            # wrap only the iterator so that the tensor itself is returned
+            iterator = tqdm.tqdm(
+                iterator,
+                dynamic_ncols=True,
+                desc="world quizzes generation",
+                total=quizzes.size(0),
+            )
+
+        for quiz in iterator:
+            q = quiz.reshape(4, self.height, self.width)
+            q[...] = 0
+            A, f_A, B, f_B = q
+            task = tasks[torch.randint(len(tasks), (1,)).item()]
+            task(A, f_A, B, f_B)
+
+        return quizzes
+
+    def save_some_examples(self, result_dir, prefix=""):
+        nb, nrow = 256, 8
+        for t in self.all_tasks:
+            print(t.__name__)
+            quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+            self.save_quizzes_as_image(
+                result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True
+            )
+
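+    # Attempts to re-synthesize q1 and q2 as one rectangle per color,
+    # filled or frame-only depending on the per-color flag computed by
+    # corners(), and keeps a reconstruction only when it reproduces
+    # both inputs exactly.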
+    def detect_rectangles(self, q1, q2):
+        c = torch.arange(self.nb_colors)
+        I = torch.arange(self.height)[None, :, None]
+        J = torch.arange(self.width)[None, :, None]
+
+        def corners(q):
+            q = q.reshape(-1, self.height, self.width)
+            a = (q[:, :, :, None] == c[None, None, None, :]).long()
+            mi = a.max(dim=2).values
+            i = mi * I
+            i1 = (i + (1 - mi) * q.size(1)).min(dim=1).values
+            i2 = (i + (1 - mi) * (-1)).max(dim=1).values + 1
+            mj = a.max(dim=1).values
+            j = mj * J
+            j1 = (j + (1 - mj) * q.size(2)).min(dim=1).values
+            j2 = (j + (1 - mj) * (-1)).max(dim=1).values + 1
+            m = (
+                ((I > i1[:, None, :]) & (I < i2[:, None, :] - 1))[:, :, None, :]
+                & ((J > j1[:, None, :]) & (J < j2[:, None, :] - 1))[:, None, :, :]
+            ).long()
+            f = ((a * m).long().sum(dim=(1, 2)) > 0).long()
+            return i1, i2, j1, j2, f
+
+        q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1)
+        q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2)
+        u1, u2 = 0, 0
+
+        for _ in range(10):
+            r1 = q1.new_zeros(q1.size(0), self.height, self.width)
+            r2 = q2.new_zeros(q2.size(0), self.height, self.width)
+
+            m1 = (
+                ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :]))[:, :, None, :]
+                & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[:, None, :, :]
+            ).long()
+
+            f1 = (
+                (
+                    ((I == q1_i1[:, None, :]) | (I == q1_i2[:, None, :] - 1))[
+                        :, :, None, :
+                    ]
+                    & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[
+                        :, None, :, :
+                    ]
+                )
+                | (
+                    ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :] - 1))[
+                        :, :, None, :
+                    ]
+                    & ((J == q1_j1[:, None, :]) | (J == q1_j2[:, None, :] - 1))[
+                        :, None, :, :
+                    ]
+                )
+            ).long()
+
+            m2 = (
+                ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :]))[:, :, None, :]
+                & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[:, None, :, :]
+            ).long()
+
+            f2 = (
+                (
+                    ((I == q2_i1[:, None, :]) | (I == q2_i2[:, None, :] - 1))[
+                        :, :, None, :
+                    ]
+                    & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[
+                        :, None, :, :
+                    ]
+                )
+                | (
+                    ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :] - 1))[
+                        :, :, None, :
+                    ]
+                    & ((J == q2_j1[:, None, :]) | (J == q2_j2[:, None, :] - 1))[
+                        :, None, :, :
+                    ]
+                )
+            ).long()
+
+            for c in torch.randperm(self.nb_colors - 1) + 1:
+                r1[...] = q1_f[:, None, None, c] * (
+                    m1[:, :, :, c] * c + (1 - m1[:, :, :, c]) * r1
+                ) + (1 - q1_f[:, None, None, c]) * (
+                    f1[:, :, :, c] * c + (1 - f1[:, :, :, c]) * r1
+                )
+
+                r2[...] = q2_f[:, None, None, c] * (
+                    m2[:, :, :, c] * c + (1 - m2[:, :, :, c]) * r2
+                ) + (1 - q2_f[:, None, None, c]) * (
+                    f2[:, :, :, c] * c + (1 - f2[:, :, :, c]) * r2
+                )
+
+            match = (
+                (q1 == r1.flatten(1)).min(dim=1).values
+                & (q2 == r2.flatten(1)).min(dim=1).values
+            ).long()[:, None, None]
+            u1 = (1 - match) * u1 + match * r1
+            u2 = (1 - match) * u2 + match * r2
+
+        return u1.flatten(1), u2.flatten(1)
+
+
+######################################################################
+
+if __name__ == "__main__":
+    import time
+
+    grids = Grids()
+
+    nb, nrow = 64, 4
+    nb_rows = 12
+
+    # c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"]
+    # c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]]
+
+    # grids.save_quizzes_as_image(
+    # "/tmp",
+    # "c_quizzes.png",
+    # c_quizzes,
+    # delta=True,
+    # nrow=nrow,
+    # margin=10,
+    # grids=False
+    # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+    # )
+
+    w_quizzes = grids.generate_w_quizzes_(
+        16,
+        tasks=[
+            grids.task_replace_color,
+            grids.task_translate,
+            grids.task_grow,
+            grids.task_frame,
+        ],
+    )
+
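+    # Each quiz concatenates four flattened height x width grids;
+    # reshape to (N, 4, height * width) to run rectangle detection on
+    # the (A, f_A) and (B, f_B) pairs separately.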
+    q = w_quizzes.reshape(-1, 4, w_quizzes.size(1) // 4)
+    r = q.new_zeros(q.size())
+    r[:, 0], r[:, 1] = grids.detect_rectangles(q[:, 0], q[:, 1])
+    r[:, 2], r[:, 3] = grids.detect_rectangles(q[:, 2], q[:, 3])
+
+    grids.save_quizzes_as_image(
+        "/tmp",
+        "q.png",
+        q.flatten(1),
+        # delta=True,
+        nrow=nrow,
+        margin=10,
+        # grids=False
+        # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+    )
+
+    grids.save_quizzes_as_image(
+        "/tmp",
+        "r.png",
+        r.flatten(1),
+        # delta=True,
+        nrow=nrow,
+        margin=10,
+        # grids=False
+        # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+    )
+
+    exit(0)
+
+    q = grids.text2quiz(
+        """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+    )
+
+    grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False)