Merge branch 'dev' master
authorFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 12:33:19 +0000 (14:33 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 12:33:19 +0000 (14:33 +0200)
54 files changed:
grids.py
main.py
mygpt.py
problem.py
quiz_machine.py
report/culture.tex [new file with mode: 0644]
report/pics/4_birds_1.png [new file with mode: 0644]
report/pics/5_birds_1.png [new file with mode: 0644]
report/pics/6_birds_1.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png [new file with mode: 0644]
report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png [new file with mode: 0644]
report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png [new file with mode: 0644]
report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png [new file with mode: 0644]
report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png [new file with mode: 0644]
report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png [new file with mode: 0644]
report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png [new file with mode: 0644]
report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png [new file with mode: 0644]
report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png [new file with mode: 0644]
report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png [new file with mode: 0644]
report/pics/examples_train.png [new file with mode: 0644]
report/pics/occlusions_1.png [new file with mode: 0644]
report/pics/other_shapes_1.png [new file with mode: 0644]
report/pics/other_shapes_2.png [new file with mode: 0644]
report/pics/other_shapes_3.png [new file with mode: 0644]
report/pics/task_bounce.png [new file with mode: 0644]
report/pics/task_color_grow.png [new file with mode: 0644]
report/pics/task_count.png [new file with mode: 0644]
report/pics/task_detect.png [new file with mode: 0644]
report/pics/task_frame.png [new file with mode: 0644]
report/pics/task_grow.png [new file with mode: 0644]
report/pics/task_replace_color.png [new file with mode: 0644]
report/pics/task_scale.png [new file with mode: 0644]
report/pics/task_trajectory.png [new file with mode: 0644]
report/pics/task_translate.png [new file with mode: 0644]
tasks.py [new file with mode: 0755]

index eea8c6c..0564f3b 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm, os, warnings
+import math, sys, tqdm, os, warnings, cairo
 
 import torch, torchvision
 
@@ -14,6 +14,36 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
+def text_img(height, width, text):
+    pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
+
+    surface = cairo.ImageSurface.create_for_data(
+        pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
+    )
+
+    ctx = cairo.Context(surface)
+    ctx.set_source_rgb(0, 0, 0)
+    ctx.set_font_size(16)
+    ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+    y = None
+    for line in text.split("\n"):
+        xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
+        if y is None:
+            y = height * 1.5
+            x = height * 0.5
+
+        ctx.move_to(x, y)
+        ctx.show_text(line)
+        y += height * 1.5
+
+    ctx.stroke()
+
+    return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
 import problem
 
 
@@ -118,6 +148,102 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
+    def check_structure(self, quizzes, struct):
+        S = self.height * self.width
+
+        return (
+            (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
+            & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
+            & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
+            & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
+        ).all()
+
+    def get_structure(self, quizzes):
+        S = self.height * self.width
+        struct = tuple(
+            self.tok2l[n.item()]
+            for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
+        )
+        self.check_structure(quizzes, struct)
+        return struct
+
+    def inject_noise(self, quizzes, noise, struct, mask):
+        assert self.check_structure(quizzes, struct=struct)
+        S = self.height * self.width
+
+        mask = torch.tensor(mask, device=quizzes.device)
+        mask = mask[None, :, None].expand(1, 4, S + 1).clone()
+        mask[:, :, 0] = 0
+        mask = mask.reshape(1, -1).expand_as(quizzes)
+        mask = mask * (torch.rand(mask.size(), device=mask.device) <= noise).long()
+        random = torch.randint(self.nb_colors, mask.size())
+        quizzes = mask * random + (1 - mask) * quizzes
+
+        return quizzes
+
+    # What a mess
+    def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+        if torch.is_tensor(quizzes):
+            return self.reconfigure([quizzes], struct=struct)[0]
+
+        S = self.height * self.width
+        result = [x.new(x.size()) for x in quizzes]
+
+        struct_from = self.get_structure(quizzes[0][:1])
+        i = self.indices_select(quizzes[0], struct_from)
+
+        sf = dict((l, n) for n, l in enumerate(struct_from))
+
+        for q in range(4):
+            k = sf[struct[q]]
+            for x, y in zip(quizzes, result):
+                l = x.size(1) // 4
+                y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
+
+        j = i == False
+
+        if j.any():
+            for z, y in zip(
+                self.reconfigure([x[j] for x in quizzes], struct=struct), result
+            ):
+                y[j] = z
+
+        return result
+
+    def trivial(self, quizzes):
+        S = self.height * self.width
+        assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
+        a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+        return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
+            dim=1
+        ).values
+
+    def make_quiz_mask(
+        self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+    ):
+        assert self.check_structure(quizzes, struct)
+
+        ar_mask = quizzes.new_zeros(quizzes.size())
+
+        S = self.height * self.width
+        a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
+        a[:, 0, :] = mask[0]
+        a[:, 1, :] = mask[1]
+        a[:, 2, :] = mask[2]
+        a[:, 3, :] = mask[3]
+
+        return ar_mask
+
+    def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+        S = self.height * self.width
+        q = quizzes.reshape(quizzes.size(0), 4, S + 1)
+        return (
+            (q[:, 0, 0] == self.l2tok[struct[0]])
+            & (q[:, 1, 0] == self.l2tok[struct[1]])
+            & (q[:, 2, 0] == self.l2tok[struct[2]])
+            & (q[:, 3, 0] == self.l2tok[struct[3]])
+        )
+
     def __init__(
         self,
         max_nb_cached_chunks=None,
@@ -126,8 +252,35 @@ class Grids(problem.Problem):
         tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
+
+        self.nb_colors = len(self.colors)
+        self.token_A = self.nb_colors
+        self.token_f_A = self.token_A + 1
+        self.token_B = self.token_f_A + 1
+        self.token_f_B = self.token_B + 1
+
+        self.nb_rec_max = 5
+        self.rfree = torch.tensor([])
+
+        self.l2tok = {
+            "A": self.token_A,
+            "f_A": self.token_f_A,
+            "B": self.token_B,
+            "f_B": self.token_f_B,
+        }
+
+        self.tok2l = {
+            self.token_A: "A",
+            self.token_f_A: "f_A",
+            self.token_B: "B",
+            self.token_f_B: "f_B",
+        }
+
         self.height = 10
         self.width = 10
+        self.seq_len = 4 * (1 + self.height * self.width)
+        self.nb_token_values = self.token_f_B + 1
+
         self.cache_rec_coo = {}
 
         all_tasks = [
@@ -137,13 +290,18 @@ class Grids(problem.Problem):
             self.task_half_fill,
             self.task_frame,
             self.task_detect,
-            self.task_count,
-            self.task_trajectory,
-            self.task_bounce,
             self.task_scale,
             self.task_symbols,
+            self.task_corners,
+            self.task_contact,
+            self.task_path,
+            self.task_fill,
+            ############################################ hard ones
             self.task_isometry,
-            #            self.task_islands,
+            self.task_trajectory,
+            self.task_bounce,
+            # self.task_count, # NOT REVERSIBLE
+            # self.task_islands, # TOO MESSY
         ]
 
         if tasks is None:
@@ -155,147 +313,130 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, -1)
-        m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+    def grid2img(self, x, scale=15):
+        m = torch.logical_and(x >= 0, x < self.nb_colors).long()
+        y = self.colors[x * m].permute(0, 3, 1, 2)
+        s = y.shape
+        y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
+        y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
+        y[:, :, torch.arange(0, y.size(2), scale), :] = 64
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
                     if m[n, i, j] == 0:
-                        for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
+                        for k in range(3, scale - 2):
+                            y[n, :, i * scale + k, j * scale + k] = 0
+                            y[n, :, i * scale + k, j * scale + scale - k] = 0
+
+        y = y[:, :, 1:, 1:]
+
+        return y
+
+    def add_frame(self, img, colors, thickness):
+        result = img.new(
+            img.size(0),
+            img.size(1),
+            img.size(2) + 2 * thickness,
+            img.size(3) + 2 * thickness,
+        )
+
+        result[...] = colors[:, :, None, None]
+        result[:, :, thickness:-thickness, thickness:-thickness] = img
 
-        return x
+        return result
 
-    def save_image(
+    def save_quizzes_as_image(
         self,
         result_dir,
         filename,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
+        quizzes,
+        predicted_parts=None,
+        correct_parts=None,
+        comments=None,
+        comment_height=48,
         nrow=4,
         margin=8,
     ):
-        S = self.height * self.width
-        As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
-        f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
-            -1, self.height, self.width
-        )
-        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
-        prompts = torch.cat([As, f_As, Bs], dim=2)
-        answers = answers.reshape(answers.size(0), self.height, self.width)
+        quizzes = quizzes.to("cpu")
 
-        if predicted_prompts is None:
-            predicted_prompts = 255
+        to_reconfigure = [quizzes]
+        if predicted_parts is not None:
+            to_reconfigure.append(predicted_parts)
+        if correct_parts is not None:
+            to_reconfigure.append(correct_parts)
 
-        if predicted_answers is None:
-            predicted_answers = 255
+        to_reconfigure = self.reconfigure(to_reconfigure, ("A", "f_A", "B", "f_B"))
 
-        def add_frame(x, c, margin, bottom=False):
-            if bottom:
-                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
-            else:
-                h, w, di, dj = (
-                    x.size(2) + 2 * margin,
-                    x.size(3) + 2 * margin,
-                    margin,
-                    margin,
-                )
-
-            y = x.new_full((x.size(0), x.size(1), h, w), 0)
+        quizzes = to_reconfigure.pop(0)
+        if predicted_parts is not None:
+            predicted_parts = to_reconfigure.pop(0)
+        if correct_parts is not None:
+            correct_parts = to_reconfigure.pop(0)
 
-            if type(c) is int:
-                y[...] = c
-            else:
-                c = c.long()[:, None]
-                c = (
-                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
-                    * torch.tensor([64, 64, 64])
-                    + (c == 1).long() * torch.tensor([0, 255, 0])
-                    + (c == 0).long() * torch.tensor([255, 255, 255])
-                    + (c == -1).long() * torch.tensor([255, 0, 0])
-                )
-                y[...] = c[:, :, None, None]
-
-            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
-            return y
+        S = self.height * self.width
 
-        img_prompts = torch.cat(
-            [
-                add_frame(
-                    add_frame(self.frame2img(x), c=0, margin=1),
-                    c=predicted_prompts,
-                    margin=margin,
-                )
-                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
-            ],
-            dim=3,
+        A, f_A, B, f_B = (
+            quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
+            .reshape(quizzes.size(0), 4, self.height, self.width)
+            .permute(1, 0, 2, 3)
         )
 
-        h = img_prompts.size(2)
-        img_answers = add_frame(
-            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
-            c=predicted_answers,
-            margin=margin,
+        frame, white, gray, green, red = torch.tensor(
+            [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
+            device=quizzes.device,
         )
 
-        separator_size = 2 * margin
+        img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
+        img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
+        img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1)
+        img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1)
 
-        separator = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
+        # predicted_parts Nx4
+        # correct_parts Nx4
 
-        marker = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
+        if predicted_parts is None:
+            colors = white[None, None, :].expand(-1, 4, -1)
+        else:
+            predicted_parts = predicted_parts.to("cpu")
+            if correct_parts is None:
+                colors = (
+                    predicted_parts[:, :, None] * gray[None, None, :]
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+                )
+            else:
+                correct_parts = correct_parts.to("cpu")
+                colors = (
+                    predicted_parts[:, :, None]
+                    * (
+                        (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+                        + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+                        + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
+                    )
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+                )
 
-        # marker[:, :, 0] = 0
-        # marker[:, :, h - 1] = 0
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
 
-        for k in range(1, 2 * separator_size - 8):
-            i = k - (separator_size - 4)
-            j = separator_size - 5 - abs(i)
-            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
-            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+        img_A = self.add_frame(img_A, white[None, :], thickness=2)
+        img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+        img_B = self.add_frame(img_B, white[None, :], thickness=2)
+        img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
-        img = torch.cat(
-            [
-                img_prompts,
-                marker,
-                img_answers,
-            ],
-            dim=3,
-        )
+        img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
+
+        if comments is not None:
+            comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+            comment_img = torch.cat(comment_img, dim=0)
+            img = torch.cat([img, comment_img], dim=2)
 
         image_name = os.path.join(result_dir, filename)
+
         torchvision.utils.save_image(
             img.float() / 255.0,
             image_name,
@@ -306,9 +447,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     # @torch.compile
     def rec_coo(
         self,
@@ -335,7 +473,8 @@ class Grids(problem.Problem):
             while True:
                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
-
+                i[:, 1] += 1
+                j[:, 1] += 1
                 big_enough = (
                     (i[:, 1] >= i[:, 0] + min_height)
                     & (j[:, 1] >= j[:, 0] + min_height)
@@ -370,13 +509,13 @@ class Grids(problem.Problem):
                         j[:, 1, 0],
                         j[:, 1, 1],
                     )
-                    no_overlap = torch.logical_not(
+                    no_overlap = (
                         (A_i1 >= B_i2)
-                        & (A_i2 <= B_i1)
-                        & (A_j1 >= B_j1)
-                        & (A_j2 <= B_j1)
+                        | (A_i2 <= B_i1)
+                        | (A_j1 >= B_j2)
+                        | (A_j2 <= B_j1)
                     )
-                    i, j = i[no_overlap], j[no_overlap]
+                    i, j = (i[no_overlap], j[no_overlap])
                 elif nb_rec == 3:
                     A_i1, A_i2, A_j1, A_j2 = (
                         i[:, 0, 0],
@@ -440,10 +579,109 @@ class Grids(problem.Problem):
 
     ######################################################################
 
+    def contact_matrices(self, rn, ri, rj, rz):
+        n = torch.arange(self.nb_rec_max)
+        return (
+            (
+                (
+                    (
+                        (ri[:, :, None, 0] == ri[:, None, :, 1] + 1)
+                        | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0])
+                    )
+                    & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+                    & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+                )
+                | (
+                    (
+                        (rj[:, :, None, 0] == rj[:, None, :, 1] + 1)
+                        | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0])
+                    )
+                    & (ri[:, :, None, 0] <= ri[:, None, :, 1])
+                    & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+                )
+            )
+            # & (rz[:, :, None] == rz[:, None, :])
+            & (n[None, :, None] < rn[:, None, None])
+            & (n[None, None, :] < n[None, :, None])
+        )
+
+    def sample_rworld_states(self, N=1000):
+        while True:
+            ri = (
+                torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
+                .sort(dim=2)
+                .values
+            )
+            ri[:, :, 1] += 2
+            rj = (
+                torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
+                .sort(dim=2)
+                .values
+            )
+            rj[:, :, 1] += 2
+            rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
+            rz = torch.randint(2, (N, self.nb_rec_max))
+            rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
+            n = torch.arange(self.nb_rec_max)
+            nb_collisions = (
+                (
+                    (ri[:, :, None, 0] <= ri[:, None, :, 1])
+                    & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+                    & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+                    & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+                    & (rz[:, :, None] == rz[:, None, :])
+                    & (n[None, :, None] < rn[:, None, None])
+                    & (n[None, None, :] < n[None, :, None])
+                )
+                .long()
+                .flatten(1)
+                .sum(dim=1)
+            )
+
+            no_collision = nb_collisions == 0
+
+            if no_collision.any():
+                print(no_collision.long().sum() / N)
+                self.rn = rn[no_collision]
+                self.ri = ri[no_collision]
+                self.rj = rj[no_collision]
+                self.rz = rz[no_collision]
+                self.rc = rc[no_collision]
+
+                nb_contact = (
+                    self.contact_matrices(rn, ri, rj, rz).long().flatten(1).sum(dim=1)
+                )
+
+                self.rcontact = nb_contact > 0
+                self.rfree = torch.full((self.rn.size(0),), True)
+
+                break
+
+    def get_recworld_state(self):
+        if not self.rfree.any():
+            self.sample_rworld_states()
+        k = torch.arange(self.rn.size(0))[self.rfree]
+        k = k[torch.randint(k.size(0), (1,))].item()
+        self.rfree[k] = False
+        return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
+
+    def draw_state(self, X, rn, ri, rj, rz, rc):
+        for n in sorted(list(range(rn)), key=lambda n: rz[n].item()):
+            X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n]
+
+    def task_recworld_immobile(self, A, f_A, B, f_B):
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            rn, ri, rj, rz, rc = self.get_recworld_state()
+            self.draw_state(X, rn, ri, rj, rz, rc)
+            ri += 1
+            self.draw_state(f_X, rn, ri, rj, rz, rc)
+
+    ######################################################################
+
     # @torch.compile
     def task_replace_color(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -459,7 +697,7 @@ class Grids(problem.Problem):
                 break
 
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -484,7 +722,7 @@ class Grids(problem.Problem):
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         direction = torch.randint(2, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -510,7 +748,7 @@ class Grids(problem.Problem):
     def task_half_fill(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
         direction = torch.randint(4, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -551,7 +789,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_frame(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -568,14 +806,17 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_detect(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
                 if n < nb_rec - 1:
-                    f_X[i1, j1] = c[-1]
+                    for k in range(2):
+                        f_X[i1 + k, j1] = c[-1]
+                        f_X[i1, j1 + k] = c[-1]
 
     # @torch.compile
     def contact(self, X, i, j, q):
@@ -613,13 +854,13 @@ class Grids(problem.Problem):
 
         return no, nq, nq_diag
 
-    def task_count(self, A, f_A, B, f_B):
+    def REMOVED_task_count(self, A, f_A, B, f_B):
         while True:
             error = False
 
-            N = torch.randint(5, (1,)).item() + 1
-            c = torch.zeros(N + 1)
-            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+            N = 3
+            c = torch.zeros(N + 2, dtype=torch.int64)
+            c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
 
             for X, f_X in [(A, f_A), (B, f_B)]:
                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
@@ -629,29 +870,31 @@ class Grids(problem.Problem):
                             self.height,
                             self.width,
                             nb_seeds=self.height * self.width // 8,
-                            nb_iterations=self.height * self.width // 10,
+                            nb_iterations=self.height * self.width // 5,
                         )
                     )
 
                 X[...] = self.cache_count.pop()
 
-                k = (X.max() + 1 + (c.size(0) - 1)).item()
-                V = torch.arange(k) // (c.size(0) - 1)
-                V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
-                    c.size(0) - 1
-                ) + 1
+                # k = (X.max() + 1 + (c.size(0) - 1)).item()
+                # V = torch.arange(k) // (c.size(0) - 1)
+                # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+                # c.size(0) - 1
+                # ) + 1
+
+                V = torch.randint(N, (X.max() + 1,)) + 1
                 V[0] = 0
+                NB = F.one_hot(c[V]).sum(dim=0)
                 X[...] = c[V[X]]
-
-                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
-                    f_X[...] = 0
-                    for e in range(1, N + 1):
-                        for j in range((X == c[e]).sum() + 1):
-                            if j < self.width:
-                                f_X[e - 1, j] = c[e]
-                            else:
-                                error = True
-                                break
+                f_X[...] = X
+
+                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+                    m = NB[c[:-1]].max()
+                    if (NB[c[:-1]] == m).long().sum() == 1:
+                        for e in range(1, N + 1):
+                            if NB[c[e]] == m:
+                                a = (f_X == c[e]).long()
+                                f_X[...] = (1 - a) * f_X + a * c[-1]
                 else:
                     error = True
                     break
@@ -659,9 +902,11 @@ class Grids(problem.Problem):
             if not error:
                 break
 
+        assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
+
     # @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 di, dj = torch.randint(7, (2,)) - 3
@@ -692,7 +937,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_bounce(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             # @torch.compile
             def free(i, j):
@@ -748,6 +993,7 @@ class Grids(problem.Problem):
                     f_X[i, j] = c[2]
                     if l <= 1:
                         X[i, j] = c[2]
+                        f_X[i, j] = c[1]
 
                     if l >= self.width:
                         break
@@ -760,7 +1006,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_scale(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
 
         i, j = (
             torch.randint(self.height // 2, (1,)).item(),
@@ -783,13 +1029,16 @@ class Grids(problem.Problem):
                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
 
-            X[i, j] = c[1]
-            f_X[0:2, 0:2] = c[1]
+            for k in range(2):
+                X[i + k, j] = c[1]
+                X[i, j + k] = c[1]
+                f_X[i + k, j] = c[1]
+                f_X[i, j + k] = c[1]
 
     # @torch.compile
     def task_symbols(self, A, f_A, B, f_B):
         nb_rec = 4
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         delta = 3
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -801,26 +1050,37 @@ class Grids(problem.Problem):
                 if d.min() > delta:
                     break
 
-            for k in range(1, nb_rec):
-                X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
-
             ai, aj = i.float().mean(), j.float().mean()
 
             q = torch.randint(3, (1,)).item() + 1
 
-            X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
-            X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
-            X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
-            X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
-
             assert i[q] != ai and j[q] != aj
 
-            X[
+            for Z in [X, f_X]:
+                for k in range(0, nb_rec):
+                    Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+                # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+                # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+                # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+                # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+            # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+            f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
+            # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+            ii, jj = (
                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
-            ] = c[nb_rec]
+            )
 
-            f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+            X[ii, jj] = c[nb_rec]
+            X[i[0] + delta // 2, jj] = c[nb_rec]
+            X[ii, j[0] + delta // 2] = c[nb_rec]
+
+            f_X[ii, jj] = c[nb_rec]
+            f_X[i[0] + delta // 2, jj] = c[nb_rec]
+            f_X[ii, j[0] + delta // 2] = c[nb_rec]
 
     # @torch.compile
     def task_isometry(self, A, f_A, B, f_B):
@@ -840,7 +1100,7 @@ class Grids(problem.Problem):
                 X[...] = 0
                 f_X[...] = 0
 
-                c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+                c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
 
                 for r in range(nb_rec):
                     while True:
@@ -906,8 +1166,8 @@ class Grids(problem.Problem):
                 return dist * (1 - walls)
 
     # @torch.compile
-    def task_distance(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+    def REMOVED_task_distance(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
         dist0 = torch.empty(self.height + 2, self.width + 2)
         dist1 = torch.empty(self.height + 2, self.width + 2)
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -969,10 +1229,10 @@ class Grids(problem.Problem):
     # if
 
     # @torch.compile
-    def task_puzzle(self, A, f_A, B, f_B):
+    def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
         S = 4
         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
-        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 f_X[...] = 0
@@ -1037,8 +1297,8 @@ class Grids(problem.Problem):
                         if f_X[i + i0, j + j0] == c[d]:
                             X[ii + i, jj + j] = c[d]
 
-    def task_islands(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+    def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
                 self.cache_islands = list(
@@ -1062,72 +1322,433 @@ class Grids(problem.Problem):
                     break
 
             X[...] = (A > 0) * c[0]
-            X[i, j] = c[1]
             f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+            f_X[i, j] = X[i, j]
+            X[i, j] = c[1]
+
+    # @torch.compile
+    def TOO_HARD_task_stack(self, A, f_A, B, f_B):
+        N = 5
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            i1, j1, i2, j2 = (
+                self.height // 2 - 1,
+                self.width // 2 - 1,
+                self.height // 2 + 1,
+                self.width // 2 + 1,
+            )
+            op = torch.tensor((0, 1, 2, 3) * 4)
+            op = op[torch.randperm(op.size(0))[:9]]
+            for q in range(op.size(0)):
+                u = 3 * (q // 3)
+                v = 3 * (q % 3)
+                d = c[torch.randint(N, (1,)).item()]
+                # X[u+1,v+1]=d
+                if op[q] == 0:  # right
+                    X[u : u + 3, v + 2] = d
+                elif op[q] == 1:  # let
+                    X[u : u + 3, v] = d
+                elif op[q] == 2:  # bottom
+                    X[u + 2, v : v + 3] = d
+                elif op[q] == 3:  # top
+                    X[u, v : v + 3] = d
+
+                if q == 0:
+                    f_X[i1:i2, j1:j2] = d
+                elif op[q] == 0:  # right
+                    f_X[i1:i2, j2] = d
+                    j2 += 1
+                elif op[q] == 1:  # let
+                    j1 -= 1
+                    f_X[i1:i2, j1] = d
+                elif op[q] == 2:  # bottom
+                    f_X[i2, j1:j2] = d
+                    i2 += 1
+                elif op[q] == 3:  # top
+                    i1 -= 1
+                    f_X[i1, j1:j2] = d
+
+    def randint(self, *m):
+        m = torch.tensor(m)
+        return (torch.rand(m.size()) * m).long()
+
+    def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
+        N = 6
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            M1 = torch.randint(2, (5, 5))
+            M2 = torch.randint(2, (5, 5))
+            P = M1 @ M2
+            for i in range(5):
+                for j in range(5):
+                    X[i, j] = c[M1[i, j]]
+                    X[i, j + 5] = c[M2[i, j]]
+                    f_X[i, j] = c[M1[i, j]]
+                    f_X[i, j + 5] = c[M2[i, j]]
+                    f_X[i + 5, j + 5] = c[P[i, j]]
+
+    def TOO_HARD_task_compute(self, A, f_A, B, f_B):
+        N = 6
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            v = torch.randint((self.width - 1) // 2, (N,)) + 1
+            chain = torch.randperm(N)
+            eq = []
+            for i in range(chain.size(0) - 1):
+                i1, i2 = chain[i], chain[i + 1]
+                v1, v2 = v[i1], v[i2]
+                k = torch.arange(self.width // 2) + 1
+                d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+                d = d[torch.randint(d.size(0), (1,)).item()]
+                w1, w2 = d
+                eq.append((c[i1], w1, c[i2], w2))
+
+            ii = torch.randperm(self.height - 2)[: len(eq)]
+
+            for k, x in enumerate(eq):
+                i = ii[k]
+                c1, w1, c2, w2 = x
+                s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+                X[i, s : s + w1] = c1
+                X[i, s + w1 : s + w1 + w2] = c2
+                f_X[i, s : s + w1] = c1
+                f_X[i, s + w1 : s + w1 + w2] = c2
+
+            i1, i2 = torch.randperm(N)[:2]
+            v1, v2 = v[i1], v[i2]
+            k = torch.arange(self.width // 2) + 1
+            d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+            d = d[torch.randint(d.size(0), (1,)).item()]
+            w1, w2 = d
+            c1, c2 = c[i1], c[i2]
+            s = 0  # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+            i = self.height - 1
+            X[i, s : s + w1] = c1
+            X[i, s + w1 : s + w1 + 1] = c2
+            f_X[i, s : s + w1] = c1
+            f_X[i, s + w1 : s + w1 + w2] = c2
+
+    # @torch.compile
+    # [ai1,ai2] [bi1,bi2]
+    def task_contact(self, A, f_A, B, f_B):
+        def rec_dist(a, b):
+            ai1, aj1, ai2, aj2 = a
+            bi1, bj1, bi2, bj2 = b
+            v = max(ai1 - bi2, bi1 - ai2)
+            h = max(aj1 - bj2, bj1 - aj2)
+            return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
+
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
+                if min(d[1:]) == 0:
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+                if d[n] == 0:
+                    f_X[i1, j1:j2] = c[0]
+                    f_X[i2 - 1, j1:j2] = c[0]
+                    f_X[i1:i2, j1] = c[0]
+                    f_X[i1:i2, j2 - 1] = c[0]
+
+    # @torch.compile
+    # [ai1,ai2] [bi1,bi2]
+    def task_corners(self, A, f_A, B, f_B):
+        polarity = torch.randint(2, (1,)).item()
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                for k in range(2):
+                    if polarity == 0:
+                        X[i1 + k, j1] = c[n]
+                        X[i2 - 1 - k, j2 - 1] = c[n]
+                        X[i1, j1 + k] = c[n]
+                        X[i2 - 1, j2 - 1 - k] = c[n]
+                    else:
+                        X[i1 + k, j2 - 1] = c[n]
+                        X[i2 - 1 - k, j1] = c[n]
+                        X[i1, j2 - 1 - k] = c[n]
+                        X[i2 - 1, j1 + k] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+    def compdist(self, X, i, j):
+        dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
+        d = dd[1:-1, 1:-1]
+        m = (X > 0).long()
+        d[i, j] = 0
+        e = d.clone()
+        while True:
+            e[...] = d
+            d[...] = (
+                d.min(dd[:-2, 1:-1] + 1)
+                .min(dd[2:, 1:-1] + 1)
+                .min(dd[1:-1, :-2] + 1)
+                .min(dd[1:-1, 2:] + 1)
+            )
+            d[...] = (1 - m) * d + m * self.height * self.width
+            if e.equal(d):
+                break
+
+        return d
+
+    # @torch.compile
+    def task_path(self, A, f_A, B, f_B):
+        nb_rec = 2
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+                i1, i2 = torch.randint(self.height, (2,))
+                j1, j2 = torch.randint(self.width, (2,))
+                if (
+                    abs(i1 - i2) + abs(j1 - j2) > 2
+                    and X[i1, j1] == 0
+                    and X[i2, j2] == 0
+                ):
+                    d2 = self.compdist(X, i2, j2)
+                    d = self.compdist(X, i1, j1)
+
+                    if d2[i1, j1] < 2 * self.width:
+                        break
+
+            m = ((d + d2) == d[i2, j2]).long()
+            f_X[...] = m * c[-1] + (1 - m) * f_X
+
+            X[i1, j1] = c[-2]
+            X[i2, j2] = c[-2]
+            f_X[i1, j1] = c[-2]
+            f_X[i2, j2] = c[-2]
+
+    # @torch.compile
+    def task_fill(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            accept_full = torch.rand(1) < 0.5
+
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+                while True:
+                    i, j = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i, j] == 0:
+                        break
+
+                d = self.compdist(X, i, j)
+                m = (d < self.height * self.width).long()
+                X[i, j] = c[-1]
+                f_X[...] = m * c[-1] + (1 - m) * f_X
+                f_X[i, j] = 0
+
+                if accept_full or (d * (X == 0)).max() == self.height * self.width:
+                    break
+
+    def TOO_HARD_task_addition(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+            N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+            S = N1 + N2
+            for j in range(self.width):
+                r1 = (N1 // (2**j)) % 2
+                X[0, -j - 1] = c[r1]
+                f_X[0, -j - 1] = c[r1]
+                r2 = (N2 // (2**j)) % 2
+                X[1, -j - 1] = c[r2]
+                f_X[1, -j - 1] = c[r2]
+                rs = (S // (2**j)) % 2
+                f_X[2, -j - 1] = c[2 + rs]
+
+    def task_science_implicit(self, A, f_A, B, f_B):
+        nb_rec = 5
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                i1, i2 = torch.randint(self.height, (2,)).sort().values
+                if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
+                    break
+
+            while True:
+                j1, j2 = torch.randint(self.width, (2,)).sort().values
+                if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
+                    break
+
+            f_X[i1:i2, j1:j2] = c[0]
+
+            # ---------------------
+
+            while True:
+                ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+                if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+                    break
+            jj = torch.randint(j1, (1,))
+            X[ii1:ii2, jj:j1] = c[1]
+            f_X[ii1:ii2, jj:j1] = c[1]
+
+            while True:
+                ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+                if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+                    break
+            jj = torch.randint(self.width - j2, (1,)) + j2 + 1
+            X[ii1:ii2, j2:jj] = c[2]
+            f_X[ii1:ii2, j2:jj] = c[2]
+
+            # ---------------------
+
+            while True:
+                jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+                if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+                    break
+            ii = torch.randint(i1, (1,))
+            X[ii:i1, jj1:jj2] = c[3]
+            f_X[ii:i1, jj1:jj2] = c[3]
+
+            while True:
+                jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+                if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+                    break
+            ii = torch.randint(self.height - i2, (1,)) + i2 + 1
+            X[i2:ii, jj1:jj2] = c[4]
+            f_X[i2:ii, jj1:jj2] = c[4]
+
+    def task_science_dot(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                i, j = (
+                    torch.randint(self.height, (1,)).item(),
+                    torch.randint(self.width, (1,)).item(),
+                )
+                q = 0
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+                    if i >= i1 and i < i2:
+                        q += 1
+                        f_X[i, j1:j2] = c[-1]
+                    if j >= j1 and j < j2:
+                        q += 1
+                        f_X[i1:i2, j] = c[-1]
+                X[i, j] = c[-1]
+                f_X[i, j] = c[-1]
+                if q >= 2:
+                    break
+
+    def collide(self, s, r, rs):
+        i, j = r
+        for i2, j2 in rs:
+            if abs(i - i2) < s and abs(j - j2) < s:
+                return True
+        return False
+
+    def task_science_tag(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            rs = []
+            while len(rs) < 4:
+                i, j = (
+                    torch.randint(self.height - 3, (1,)).item(),
+                    torch.randint(self.width - 3, (1,)).item(),
+                )
+                if not self.collide(s=3, r=(i, j), rs=rs):
+                    rs.append((i, j))
+
+            for k in range(len(rs)):
+                i, j = rs[k]
+                q = min(k, 2)
+                X[i, j : j + 3] = c[q]
+                X[i + 2, j : j + 3] = c[q]
+                X[i : i + 3, j] = c[q]
+                X[i : i + 3, j + 2] = c[q]
+
+                f_X[i, j : j + 3] = c[q]
+                f_X[i + 2, j : j + 3] = c[q]
+                f_X[i : i + 3, j] = c[q]
+                f_X[i : i + 3, j + 2] = c[q]
+                if q == 2:
+                    f_X[i + 1, j + 1] = c[-1]
+
+    # end_tasks
 
     ######################################################################
 
-    def trivial_prompts_and_answers(self, prompts, answers):
+    def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
-        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
-        f_Bs = answers
-        return (Bs == f_Bs).long().min(dim=-1).values > 0
+        quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+        quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
+        quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
+        quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
+        quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
 
-    def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
-        if tasks is None:
-            tasks = self.all_tasks
+        return quizzes
 
+    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         S = self.height * self.width
-        prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
-        answers = torch.zeros(nb, S, dtype=torch.int64)
 
-        bunch = zip(prompts, answers)
+        if tasks is None:
+            tasks = self.all_tasks
+
+        quizzes = self.create_empty_quizzes(nb, ("A", "f_A", "B", "f_B"))
 
         if progress_bar:
-            bunch = tqdm.tqdm(
-                bunch,
+            quizzes = tqdm.tqdm(
+                quizzes,
                 dynamic_ncols=True,
-                desc="world generation",
-                total=prompts.size(0),
+                desc="world quizzes generation",
+                total=quizzes.size(0),
             )
 
-        for prompt, answer in bunch:
-            A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
-            f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
-            B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
-            f_B = answer.view(self.height, self.width)
+        for quiz in quizzes:
+            q = quiz.reshape(4, S + 1)[:, 1:].reshape(4, self.height, self.width)
+            q[...] = 0
+            A, f_A, B, f_B = q
             task = tasks[torch.randint(len(tasks), (1,)).item()]
             task(A, f_A, B, f_B)
 
-        return prompts.flatten(1), answers.flatten(1)
-
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-        nrow=4,
-    ):
-        self.save_image(
-            result_dir,
-            filename_prefix + ".png",
-            prompts,
-            answers,
-            predicted_prompts,
-            predicted_answers,
-            nrow,
-        )
+        return quizzes
 
-    def save_some_examples(self, result_dir):
-        nb, nrow = 72, 4
+    def save_some_examples(self, result_dir, prefix=""):
+        nb, nrow = 128, 4
         for t in self.all_tasks:
             print(t.__name__)
-            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quiz_illustrations(
-                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+            quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+            self.save_quizzes_as_image(
+                result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
             )
 
 
@@ -1137,41 +1758,86 @@ if __name__ == "__main__":
     import time
 
     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
+
     grids = Grids()
 
+    # nb = 5
+    # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
+    # print(quizzes)
+    # print(grids.get_structure(quizzes))
+    # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+    # print("DEBUG2", quizzes)
+    # print(grids.get_structure(quizzes))
+    # print(quizzes)
+
+    # i = torch.rand(quizzes.size(0)) < 0.5
+
+    # quizzes[i] = grids.reconfigure(quizzes[i], struct=("f_B", "f_A", "B", "A"))
+
+    # j = grids.indices_select(quizzes, struct=("f_B", "f_A", "B", "A"))
+
+    # print(
+    # i.equal(j),
+    # grids.get_structure(quizzes[j]),
+    # grids.get_structure(quizzes[j == False]),
+    # )
+
+    #   exit(0)
+
     # nb = 1000
     # grids = problem.MultiThreadProblem(
     # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
     # )
     #    time.sleep(10)
     # start_time = time.perf_counter()
-    # prompts, answers = grids.generate_prompts_and_answers(nb)
+    # prompts, answers = grids.generate_w_quizzes(nb)
     # delay = time.perf_counter() - start_time
     # print(f"{prompts.size(0)/delay:02f} seq/s")
     # exit(0)
 
     # if True:
-    nb, nrow = 72, 4
+    nb, nrow = 128, 4
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_distance]:
+
+    for t in [grids.task_recworld_immobile]:
         print(t.__name__)
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        grids.save_quiz_illustrations(
-            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+        grids.save_quizzes_as_image(
+            "/tmp",
+            t.__name__ + ".png",
+            w_quizzes,
+            comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
         )
 
-    exit(0)
+    exit(0)
 
     nb = 1000
 
-    # for t in grids.all_tasks:
-    for t in [grids.task_distance]:
+    for t in [
+        # grids.task_bounce,
+        # grids.task_contact,
+        # grids.task_corners,
+        # grids.task_detect,
+        # grids.task_fill,
+        # grids.task_frame,
+        # grids.task_grow,
+        # grids.task_half_fill,
+        # grids.task_isometry,
+        # grids.task_path,
+        # grids.task_replace_color,
+        # grids.task_scale,
+        grids.task_symbols,
+        # grids.task_trajectory,
+        # grids.task_translate,
+    ]:
+        # for t in [grids.task_path]:
         start_time = time.perf_counter()
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
-        print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
+        print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
+        grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
 
     exit(0)
 
@@ -1179,9 +1845,9 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    grids.save_quiz_illustrations(
+    grids.save_quizzes_as_image(
         "/tmp",
-        "test",
+        "test.png",
         prompts[:nb],
         answers[:nb],
         # You can add a bool to put a frame around the predicted parts
diff --git a/main.py b/main.py
index 6b00bbf..40772c2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -3,6 +3,9 @@
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
+# > A > f(A) > B ; > f(B)
+# < f(B) ; < B < f(A) < A
+
 # Written by Francois Fleuret <francois@fleuret.org>
 
 import math, sys, argparse, time, tqdm, os, datetime, warnings
@@ -16,7 +19,9 @@ import ffutils
 import mygpt
 import sky, grids, quiz_machine
 
-import threading
+from quiz_machine import one_batch_masked_inplace_autoregression
+
+import threading, subprocess
 
 import torch.multiprocessing as mp
 
@@ -36,7 +41,9 @@ parser.add_argument("--resume", action="store_true", default=False)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
-########################################
+parser.add_argument("--log_command", type=str, default=None)
+
+# ----------------------------------
 
 parser.add_argument("--nb_epochs", type=int, default=10000)
 
@@ -44,6 +51,8 @@ parser.add_argument("--batch_size", type=int, default=None)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
+parser.add_argument("--inference_batch_size", type=int, default=None)
+
 parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
@@ -54,8 +63,9 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
-########################################
+parser.add_argument("--schedule_free", action="store_true", default=False)
 
+# ----------------------------------
 parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--dim_model", type=int, default=None)
@@ -70,8 +80,7 @@ parser.add_argument("--nb_blocks", type=int, default=None)
 
 parser.add_argument("--dropout", type=float, default=0.1)
 
-########################################
-
+# ----------------------------------
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 parser.add_argument("--problem", type=str, default="grids")
@@ -80,18 +89,28 @@ parser.add_argument("--nb_threads", type=int, default=1)
 
 parser.add_argument("--gpus", type=str, default="all")
 
+# ----------------------------------
+
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--max_fail_to_validate", type=int, default=3)
+
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
+
+parser.add_argument("--proba_understands", type=float, default=0.95)
 
-parser.add_argument("--proba_understands", type=float, default=0.9)
+parser.add_argument("--proba_not_understands", type=float, default=0.1)
 
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--temperature_hot", type=float, default=1.5)
 
-parser.add_argument("--generation_temperature", type=float, default=1.0)
+parser.add_argument("--temperature_cold", type=float, default=1)
+
+parser.add_argument("--prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
+parser.add_argument("--test", type=str, default=None)
+
 ######################################################################
 
 grids_tasks = ", ".join(
@@ -99,10 +118,17 @@ grids_tasks = ", ".join(
 )
 
 parser.add_argument(
-    "--grids_tasks",
+    "--grids_world_tasks",
+    type=str,
+    default="replace_color,translate,grow,frame",
+    help="A comma-separated subset of: " + grids_tasks + ".",
+)
+
+parser.add_argument(
+    "--grids_science_tasks",
     type=str,
     default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+    help="A comma-separated subset of: " + grids_tasks + ", or None.",
 )
 
 ######################################################################
@@ -124,13 +150,22 @@ args = parser.parse_args()
 if args.result_dir is None:
     args.result_dir = f"results_culture"
 
+assert not args.grids_science_tasks or (
+    len(
+        set(args.grids_world_tasks.split(","))
+        & set(args.grids_science_tasks.split(","))
+    )
+    == 0
+), "World and science tasks have to be disjoint"
+
 ######################################################################
 
 default_args = {
     "model": "37M",
     "batch_size": 25,
-    "nb_train_samples": 100000,
-    "nb_test_samples": 10000,
+    "inference_batch_size": 50,
+    "nb_train_samples": 40000,
+    "nb_test_samples": 1000,
 }
 
 for k, v in default_args.items():
@@ -220,9 +255,17 @@ def log_string(s):
     sys.stdout.flush()
 
 
+######################################################################
+# Create a time-stamped archive of the source code
+
+with open("this_run.sh", "w") as f:
+    f.write(f"{' '.join(sys.argv)}\n")
+
 now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
 
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py")
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+######################################################################
 
 log_string(f"argv {' '.join(sys.argv)}")
 
@@ -268,27 +311,41 @@ if args.problem == "sky":
         chunk_size=100,
         nb_threads=args.nb_threads,
     )
-    back_accuracy = False
+
 elif args.problem == "grids":
     problem = grids.Grids(
         max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
         chunk_size=100,
         nb_threads=args.nb_threads,
-        tasks=args.grids_tasks,
+        tasks=args.grids_world_tasks,
     )
-    back_accuracy = True
+
+    if args.grids_science_tasks is None:
+        science_w_quizzes = None
+    else:
+        science_problem = grids.Grids(
+            max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+            chunk_size=100,
+            nb_threads=args.nb_threads,
+            tasks=args.grids_science_tasks,
+        )
+        science_w_quizzes = science_problem.generate_w_quizzes(100)
+
+        if not args.resume:
+            science_problem.save_some_examples(args.result_dir, "science_")
+
+
 else:
     raise ValueError
 
-problem.save_some_examples(args.result_dir)
+if not args.resume:
+    problem.save_some_examples(args.result_dir)
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
-    nb_train_samples=args.nb_train_samples,
-    nb_test_samples=args.nb_test_samples,
-    back_accuracy=back_accuracy,
-    batch_size=args.physical_batch_size,
+    batch_size=args.inference_batch_size,
     result_dir=args.result_dir,
+    prompt_noise=args.prompt_noise,
     logger=log_string,
     device=main_device,
 )
@@ -304,23 +361,54 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
+def optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
+
+######################################################################
+
+
+def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
-        model.eval().to(local_device)
+        model.to(local_device).eval()
+        if args.schedule_free:
+            model.optimizer.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        for input in quiz_machine.batches(model, split="test"):
-            input = input.to(local_device)
-
-            bs = model(mygpt.BracketedSequence(input))
-            output = bs.x
-
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+        full_input, full_mask_loss = quiz_machine.data_input(model, split="test")
+        src = zip(
+            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+        )
 
+        for input, mask_loss in tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc="test",
+            total=full_input.size(0) // args.batch_size,
+        ):
+            input = input.to(local_device)
+            mask_loss = mask_loss.to(local_device)
+            targets = input
+
+            output = model(mygpt.BracketedSequence(input)).x
+            loss_per_token = F.cross_entropy(
+                output.transpose(1, 2), targets, reduction="none"
+            )
+            loss = (loss_per_token * mask_loss).mean()
             acc_test_loss += loss.item() * input.size(0)
-
             nb_test_samples += input.size(0)
 
         test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
@@ -330,149 +418,616 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de
         model.main_test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
             model=model,
+            input=full_input[:2000],
             result_dir=args.result_dir,
-            deterministic_synthesis=deterministic_synthesis,
         )
 
 
+######################################################################
+
+
 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
+    optimizer_to(model.optimizer, local_device)
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    if args.schedule_free:
+        model.optimizer.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in quiz_machine.batches(model, split="train"):
+    hard_w_quizzes = []
+
+    full_input, full_mask_loss = quiz_machine.data_input(model, split="train")
+    src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
+
+    for input, mask_loss in tqdm.tqdm(
+        src,
+        dynamic_ncols=True,
+        desc="training",
+        total=full_input.size(0) // args.batch_size,
+    ):
         input = input.to(local_device)
+        mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
+            model.optimizer.zero_grad()
+
+        targets = input
 
         output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
+        loss_per_token = F.cross_entropy(
+            output.transpose(1, 2), targets, reduction="none"
+        )
+        loss = (loss_per_token * mask_loss).mean() + model.loss
         acc_train_loss += loss.item() * input.size(0)
 
+        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
+
         nb_train_samples += input.size(0)
 
         loss.backward()
 
         if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
+            model.optimizer.step()
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
     log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
 
-    run_tests(model, quiz_machine, deterministic_synthesis=False)
+    run_tests(model, quiz_machine)
+
+    # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
+    # threshold = threshold[threshold.size(0) // 2]
+
+    # model.hard_w_quizzes = torch.cat(
+    # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
+    # )
 
     model.to(main_device)
+    optimizer_to(model.optimizer, main_device)
 
 
 ######################################################################
 
-# This is the key routine that decides what generated quizzes to keep
 
+def model_transformer_hot(model):
+    model.temperature = args.temperature_hot
+    # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
 
-# token_logprobas are NxMxT where M is the number of models
 
+def model_transformer_cold(model):
+    model.temperature = args.temperature_cold
+    # pass
 
-def compute_valid_quizzes_(token_logprobas):
-    warnings.warn("validation with uniform constraints", RuntimeWarning)
-    l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
-    return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
 
+c_quizzes_procedure = [
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+]
+
+######################################################################
 
-def compute_valid_quizzes(token_logprobas):
-    l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-    return (l[:, 0] < math.log(args.proba_not_understands)) & (
-        l[:, 1] > math.log(args.proba_understands)
+
+def save_additional_results(model, models, science_w_quizzes):
+    # Save generated quizzes with the successive steps
+
+    recorder = []
+
+    c_quizzes = quiz_machine.generate_c_quizzes(
+        64,
+        model_for_generation=model,
+        procedure=c_quizzes_procedure,
+        recorder=recorder,
+    )
+
+    # This is nb_quizzes x nb_models
+
+    seq_logproba = quiz_machine.models_logprobas(
+        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+    ) + quiz_machine.models_logprobas(
+        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+    )
+
+    probas = seq_logproba.exp()
+
+    comments = []
+
+    for l in seq_logproba:
+        comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+
+    ##
+
+    c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
+    predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
+    nb_steps = c_quizzes.size(1)
+    c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
+    predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
+
+    # We have comments only for the final quiz, not the successive
+    # steps, so we have to add nb_steps-1 empty comments
+
+    steps_comments = []
+    for c in comments:
+        steps_comments += [""] * (nb_steps - 1) + [c]
+
+    filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes,
+        predicted_parts=predicted_parts,
+        comments=steps_comments,
+        nrow=nb_steps * 2,  # two quiz per row
     )
 
+    log_string(f"wrote {filename}")
 
-def extract_valid_quizzes_and_logprobas(recorded):
-    validated_quizzes, validated_logprobas = [], []
-    for quizzes, token_logprobas in recorded:
-        validated_indices = compute_valid_quizzes(token_logprobas)
-        validated_quizzes.append(quizzes[validated_indices])
-        validated_logprobas.append(token_logprobas[validated_indices])
+    ######################################################################
 
-    if len(validated_quizzes) > 0:
-        return torch.cat(validated_quizzes, dim=0), torch.cat(
-            validated_logprobas, dim=0
+    if science_w_quizzes is not None:
+        struct = ("A", "f_A", "B", "f_B")
+        mask = (0, 0, 0, 1)
+        result, correct = quiz_machine.predict(
+            model=model,
+            quizzes=science_w_quizzes.to(main_device),
+            struct=struct,
+            mask=mask,
         )
-    else:
-        return None, None
 
+        predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
+            correct.size(0), -1
+        )
+        correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
 
-######################################################################
+        nb_correct = (correct == 1).long().sum()
+        nb_total = (correct != 0).long().sum()
+
+        log_string(
+            f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+        )
+
+        i = correct == 1
+        j = correct != 1
 
+        result = torch.cat([result[i], result[j]], dim=0)
+        correct = torch.cat([correct[i], correct[j]], dim=0)
+        correct_parts = predicted_parts * correct[:, None]
 
-def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
-    nb_to_create = nb_for_train + nb_for_test
+        result = result[:128]
+        predicted_parts = predicted_parts[:128]
+        correct_parts = correct_parts[:128]
 
-    recorded_quizzes_logprobas = []
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
+            quizzes=result,
+            predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
+        )
 
+
+######################################################################
+
+
+def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
+    nb_to_validate = nb_for_train + nb_for_test
+    nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)
     nb_validated = 0
 
-    while nb_validated < nb_to_create:
-        model_for_generation = models[torch.randint(len(models), (1,))]
+    recorded_validated = []
+
+    start_time = time.perf_counter()
+
+    nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
+
+    while nb_validated_per_model.sum() < nb_to_validate:
+        # We use the model that has generated the fewest quizzes to
+        # balance the number of quizzes per model overall
 
-        c_quizzes = quiz_machine.generate_quizzes(
-            nb_to_create,
-            model_for_generation=model_for_generation,
-            temperature=args.generation_temperature,
+        # model_for_generation = sorted(
+        # models, key=lambda m: nb_validated_per_model[m.id]
+        # )[0]
+
+        model_for_generation = models[torch.randint(len(models), (1,)).item()]
+
+        # We generate quizzes with a procedure that injects some
+        # structured noise
+
+        c_quizzes = quiz_machine.generate_c_quizzes(
+            nb_to_generate_per_iteration,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
         )
 
-        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+        # We discard the trivial ones, according to a criterion
+        # specific to the world quizzes (e.g. B=f(B))
 
-        if c_quizzes.size(0) > 0:
-            token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
-            recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
+        to_keep = quiz_machine.problem.trivial(c_quizzes) == False
 
-            (
-                validated_quizzes,
-                validated_logprobas,
-            ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
+        c_quizzes = c_quizzes[to_keep]
 
-            if validated_quizzes is not None:
-                nb_validated = validated_quizzes.size(0)
+        # This is nb_quizzes x nb_models
+
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+
+        probas = seq_logproba.exp()
+
+        nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
+        nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
+
+        to_keep = (
+            (nb_succeed + nb_fail == probas.size(1))
+            & (nb_fail >= 1)
+            & (nb_fail <= args.max_fail_to_validate)
+        )
+
+        c_quizzes = c_quizzes[to_keep]
+
+        if c_quizzes.size(0) > 0:
+            nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
+            recorded_validated.append(c_quizzes)
+            nb_validated = c_quizzes.size(0)
+        else:
+            nb_validated = 0
+
+        total_nb_validated = nb_validated_per_model.sum().item()
+
+        duration = time.perf_counter() - start_time
+
+        if total_nb_validated > 0:
+            if total_nb_validated < nb_to_validate:
+                d = (
+                    (nb_to_validate - total_nb_validated)
+                    * duration
+                    / total_nb_validated
+                )
+                e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
+                    "%a %H:%M"
+                )
+            else:
+                e = "now!"
+        else:
+            e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+            f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
         )
 
+    validated_quizzes = torch.cat(recorded_validated, dim=0)
+
+    ######################################################################
     # store the new c_quizzes which have been validated
 
-    quiz_machine.reverse_random_half_in_place(validated_quizzes)
-    quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
-    quiz_machine.store_c_quizzes(
-        validated_quizzes[nb_for_train:nb_to_create], for_train=False
-    )
+    v_train = validated_quizzes[:nb_for_train]
+    quiz_machine.store_c_quizzes(v_train, for_train=True)
+
+    v_test = validated_quizzes[nb_for_train:nb_to_validate]
+    quiz_machine.store_c_quizzes(v_test, for_train=False)
 
     ######################################################################
-    # save images with their logprobas
+    # save images
 
-    vq = validated_quizzes[:72]
-    vl = validated_logprobas[:72]
+    vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
 
     if vq.size(0) > 0:
-        prefix = f"culture_c_quiz_{n_epoch:04d}"
-        filename = os.path.join(args.result_dir, prefix + "_logp.pth")
-        torch.save(vl, filename)
-        # with open(file_name, "w") as logp_file:
-        # for l in vl:
-        # s = " ".join([str(x.item()) for x in l])
-        # logp_file.write(s + "\n")
+        seq_logproba = quiz_machine.models_logprobas(
+            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        ) + quiz_machine.models_logprobas(
+            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
 
-        quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
+        probas = seq_logproba.exp()
+
+        comments = []
+
+        for l in seq_logproba:
+            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+
+        filename = f"culture_c_quiz_{n_epoch:04d}.png"
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir, filename, vq, comments=comments
+        )
+
+
+######################################################################
+
+# The generator is very similar to a "solving GPT" except that it
+# deals with quizzes prologued with one token per solving GPT that
+# indicates if the said model solves it or not.
+#
+# There are three levels of solving 0->proba<=proba_not_understands,
+# 2->proba>=proba_understands and 1 otherwise.
+
+
+def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
+    generator.to(main_device)
+
+    struct = ("A", "f_A", "B", "f_B")
+
+    c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
+    ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1))
+
+    i = F.one_hot(
+        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
+        num_classes=args.nb_gpts,
+    )
+
+    prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
+    prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1))
+
+    prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to(
+        main_device
+    )
+    prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device)
+
+    seq_logproba = torch.zeros(
+        prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
+    )
+
+    generator.temperature = args.temperature_hot
+
+    with torch.autograd.no_grad():
+        t = generator.training
+        generator.eval()
+
+        one_batch_masked_inplace_autoregression(
+            generator,
+            prologued_c_quizzes,
+            prologued_ar_mask,
+            seq_logproba,
+            deterministic_synthesis=False,
+        )
+
+        generator.train(t)
+
+    generator.reset_transformations()
+
+    prologued_c_quizzes = (
+        prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
+    )
+
+    c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :]
+
+    return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu")
+
+
+def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
+    samples = []
+
+    for _ in range(args.nb_train_samples // args.batch_size):
+        while sum([x.size(0) for x in samples]) < args.batch_size:
+            # Generate a bunch of quizzes
+
+            if torch.rand(1).item() <= fraction_w_quizzes:
+                # Either we start with the world quizzes
+                c_quizzes = quiz_machine.problem.generate_w_quizzes(
+                    args.batch_size, progress_bar=False
+                )
+            else:
+                # Or we use the generator itself to generate them
+                c_quizzes, _ = generate_c_quizzes_with_generator(
+                    generator, quiz_machine, args.batch_size
+                )
+
+            # We remove the trivial ones
+            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+            c_quizzes = c_quizzes[to_keep]
+
+            # If there are remaining ones, we compute the true prolog
+            # that indicates how the GPTs solve it
+
+            if c_quizzes.size(0) > 0:
+                seq_logproba = quiz_machine.models_logprobas(
+                    models,
+                    c_quizzes,
+                    ("A", "f_A", "B", "f_B"),
+                    (0, 0, 0, 1),
+                    (0, 0, 1, 0),
+                ) + quiz_machine.models_logprobas(
+                    models,
+                    c_quizzes,
+                    ("f_A", "A", "f_B", "B"),
+                    (0, 0, 0, 1),
+                    (0, 0, 1, 0),
+                )
+
+                probas = seq_logproba.exp()
+
+                u0 = probas <= args.proba_not_understands
+                u2 = probas >= args.proba_understands
+                u1 = (u0 | u2) == False
+
+                prologs = (
+                    (u0.long() * token_prolog_0)
+                    + (u1.long() * token_prolog_1)
+                    + (u2.long() * token_prolog_2)
+                )
+
+                prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1)
+
+                # nb_u2 = u2.long().sum(dim=1)
+                # nb_u0 = u0.long().sum(dim=1)
+                # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)]
+
+                if prologued_c_quizzes.size(0) > 0:
+                    samples.append(prologued_c_quizzes)
+
+        # Now we yield a batch
+
+        x = torch.cat(samples, dim=0)
+        samples = [x[args.batch_size :]]
+
+        yield x[: args.batch_size]
+
+
+def one_generator_epoch(
+    generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device
+):
+    model.to(local_device).train()
+
+    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    src = batches_for_generator(
+        generator=generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        fraction_w_quizzes=fraction_w_quizzes,
+    )
+
+    for input in tqdm.tqdm(
+        src,
+        dynamic_ncols=True,
+        desc="training",
+        total=args.nb_train_samples // args.batch_size,
+    ):
+        input = input.to(local_device)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        targets = input
+
+        output = generator(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), targets)
+        acc_train_loss += loss.item() * input.size(0)
+        nb_train_samples += input.size(0)
+
+        loss.backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
+
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+    log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}")
+
+    generator.to(main_device)
+
+
+######################################################################
+
+
+def train_complexifier(model_gen, model_pred1, model_pred2):
+    samples = []
+    perf = []
+
+    optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate)
+
+    nb_train_samples, acc_train_loss = 0, 0.0
+
+    for n_epoch in range(args.nb_epochs):
+        for b in range(args.nb_train_samples // args.batch_size):
+            while sum([x.size(0) for x in samples]) < args.batch_size:
+                c_quizzes = quiz_machine.generate_c_quizzes(
+                    args.inference_batch_size,
+                    model_for_generation=model_gen,
+                    procedure=c_quizzes_procedure,
+                )
+                to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+                c_quizzes = c_quizzes[to_keep]
+                if c_quizzes.size(0) > 0:
+                    seq_logproba = quiz_machine.models_logprobas(
+                        [model_pred1, model_pred2],
+                        c_quizzes,
+                        ("A", "f_A", "B", "f_B"),
+                        (0, 0, 0, 1),
+                    ) + quiz_machine.models_logprobas(
+                        [model_pred1, model_pred2],
+                        c_quizzes,
+                        ("f_A", "A", "f_B", "B"),
+                        (0, 0, 0, 1),
+                    )
+                    probas = seq_logproba.exp()
+                    to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & (
+                        probas[:, model_pred2.id] <= args.proba_not_understands
+                    )
+                    log_string(
+                        f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}"
+                    )
+                    c_quizzes = c_quizzes[to_keep]
+                    if c_quizzes.size(0):
+                        samples.append(c_quizzes)
+
+            log_string(f"full batch {sum([x.size(0) for x in samples])}")
+
+            x = torch.cat(samples, dim=0)
+
+            input = x[: args.batch_size]
+            samples = [x[args.batch_size :]]
+
+            # -------------------
+
+            seq_logproba = quiz_machine.models_logprobas(
+                [model_pred1, model_pred2],
+                input,
+                ("A", "f_A", "B", "f_B"),
+                (0, 0, 0, 1),
+            ) + quiz_machine.models_logprobas(
+                [model_pred1, model_pred2],
+                input,
+                ("f_A", "A", "f_B", "B"),
+                (0, 0, 0, 1),
+            )
+
+            comments = []
+
+            for l in seq_logproba:
+                comments.append(
+                    f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}"
+                )
+
+            filename = f"batch_{n_epoch:04d}_{b:04d}.png"
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir, filename, input, comments=comments
+            )
+            log_string(f"wrote {filename}")
+
+            # ------------------------
+
+            input = input.to(main_device)
+
+            if nb_train_samples % args.batch_size == 0:
+                optimizer.zero_grad()
+
+            output = model_gen(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_train_loss += loss.item() * input.size(0)
+            nb_train_samples += input.size(0)
+
+            loss.backward()
+
+            if nb_train_samples % args.batch_size == 0:
+                optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
 
 
 ######################################################################
 
 models = []
 
+
+def compute_causal_attzero(t_q, t_k):
+    return t_q < t_k
+
+
+if args.schedule_free:
+    import schedulefree
+
 for k in range(args.nb_gpts):
     log_string(f"creating model {k} and its w_quizzes")
+
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -480,47 +1035,86 @@ for k in range(args.nb_gpts):
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=True,
+        compute_attzero=compute_causal_attzero,
         dropout=args.dropout,
     ).to(main_device)
 
-    model.main_test_accuracy = 0.0
     model.id = k
 
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
-    quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
-    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+    if args.schedule_free:
+        model.optimizer = schedulefree.AdamWScheduleFree(
+            model.parameters(), lr=args.learning_rate
+        )
+    else:
+        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    model.main_test_accuracy = 0.0
+
+    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
+        args.nb_train_samples
+    )
+
+    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
 
     models.append(model)
 
 ######################################################################
 
-if args.resume:
-    try:
-        for model in models:
-            filename = f"gpt_{model.id:03d}.pth"
+if args.test == "quant":
+    nb_bits = 8
+    for model in models:
+        model.trunk.insert(
+            12,
+            mygpt.CacheWrapper(
+                mygpt.RandomBypass(
+                    nn.Sequential(
+                        nn.Linear(args.dim_model, nb_bits),
+                        mygpt.BSQ(nb_bits),
+                        nn.Linear(nb_bits, args.dim_model),
+                    ),
+                    0.1,
+                )
+            ),
+        )
+
+        print(model)
+        exit(0)
 
-            try:
-                d = torch.load(os.path.join(args.result_dir, filename))
-                model.load_state_dict(d[0])
-                model.main_test_accuracy = d[1]
-                log_string(f"successfully loaded {filename}")
-            except FileNotFoundError:
-                log_string(f"cannot find {filename}")
-                pass
+
+######################################################################
+
+current_epoch = 0
+
+if args.resume:
+    for model in models:
+        filename = f"gpt_{model.id:03d}.pth"
 
         try:
-            filename = "c_quizzes.pth"
-            quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+            d = torch.load(os.path.join(args.result_dir, filename))
+            model.load_state_dict(d["state_dict"])
+            model.optimizer.load_state_dict(d["optimizer_state_dict"])
+            model.main_test_accuracy = d["main_test_accuracy"]
             log_string(f"successfully loaded {filename}")
         except FileNotFoundError:
             log_string(f"cannot find {filename}")
             pass
 
-    except:
-        log_string(f"error when loading {filename}.")
-        exit(1)
+    try:
+        filename = "c_quizzes.pth"
+        quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
+
+    try:
+        filename = "state.pth"
+        state = torch.load(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+        current_epoch = state["current_epoch"]
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
 
 ######################################################################
 
@@ -529,59 +1123,11 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-# Compute the entropy of the training tokens
-
-token_count = 0
-for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
-    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
-        (0, 1)
-    )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
-
-######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
-    def subsets_as_tuples(batches, cs):
-        s = set()
-        for batch in batches:
-            for x in batch:
-                s.add(tuple([v.item() for v in x]))
-                if len(s) == cs:
-                    yield s
-                    s = set()
-        yield s
-
-    nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(
-        quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
-    ):
-        in_train = set()
-        for train_subset in subsets_as_tuples(
-            quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
-        ):
-            in_train.update(test_subset.intersection(train_subset))
-        nb_in_train += len(in_train)
-        nb_test += len(test_subset)
-
-    log_string(
-        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
-    )
-
-    assert (
-        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
-    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
-
-######################################################################
-
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
@@ -595,10 +1141,161 @@ if args.dirty_debug:
     args.nb_new_c_quizzes_for_train = 100
     args.nb_new_c_quizzes_for_test = 10
 
+######################################################################
+
+if args.test == "tsne":
+    model = models[0]
+
+    quizzes = []
+    labels = []
+    nb_samples_per_task = 1000
+
+    for n, t in enumerate(args.grids_world_tasks.split(",")):
+        quizzes.append(
+            quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t])
+        )
+        labels.append(torch.full((quizzes[-1].size(0),), n))
+
+    quizzes = torch.cat(quizzes, dim=0)
+    labels = torch.cat(labels, dim=0)
+
+    with torch.autograd.no_grad():
+        model.eval().to(main_device)
+        record = []
+        for input, targets in zip(
+            quizzes.split(args.batch_size), labels.split(args.batch_size)
+        ):
+            input = input.to(main_device)
+            bs = mygpt.BracketedSequence(input)
+            bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = model.embedding(bs)
+            bs = model.trunk[args.nb_blocks // 2](bs)
+            record.append((bs.x.to("cpu"), targets))
+
+    x = torch.cat([x for x, y in record], dim=0).flatten(1)
+    y = torch.cat([y for x, y in record], dim=0)
+
+    print(f"{x.size()=} {y.size()=}")
+    # torch.save((x,y), "/tmp/embed.pth")
+    # exit(0)
+
+    from sklearn.manifold import TSNE
+
+    x_np = x.numpy()
+    z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np)
+    z = torch.from_numpy(z_np)
+
+    print(f"{z.size()=}")
+
+    with open("/tmp/result.dat", "w") as f:
+        for k in range(z.size(0)):
+            f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n")
+
+    exit(0)
 
 ######################################################################
 
-for n_epoch in range(args.nb_epochs):
+if args.test == "generator":
+    token_prolog_0 = vocabulary_size + 0
+    token_prolog_1 = vocabulary_size + 1
+    token_prolog_2 = vocabulary_size + 2
+    generator_vocabulary_size = vocabulary_size + 3
+
+    generator = mygpt.MyGPT(
+        vocabulary_size=generator_vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        compute_attzero=compute_causal_attzero,
+        dropout=args.dropout,
+    ).to(main_device)
+
+    generator.main_test_accuracy = 0.0
+
+    filename = f"generator.pth"
+
+    try:
+        d = torch.load(os.path.join(args.result_dir, filename))
+        generator.load_state_dict(d[0])
+        generator.main_test_accuracy = d[1]
+        log_string(f"successfully loaded {filename}")
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
+
+    for n_epoch in range(args.nb_epochs):
+        one_generator_epoch(
+            generator,
+            quiz_machine=quiz_machine,
+            models=models,
+            fraction_w_quizzes=1 if n_epoch < 25 else 0.5,
+            local_device=main_device,
+        )
+
+        filename = f"generator.pth"
+        torch.save(
+            (generator.state_dict(), generator.main_test_accuracy),
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
+
+        c_quizzes, prologs = generate_c_quizzes_with_generator(
+            generator, quiz_machine, args.batch_size
+        )
+
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+
+        probas = seq_logproba.exp()
+
+        u0 = probas <= args.proba_not_understands
+        u2 = probas >= args.proba_understands
+        u1 = (u0 | u2) == False
+
+        predicted_prologs = (
+            (u0.long() * token_prolog_0)
+            + (u1.long() * token_prolog_1)
+            + (u2.long() * token_prolog_2)
+        )
+
+        comments = []
+
+        nb_errors = (predicted_prologs != prologs).long().sum()
+        nb_total = prologs.numel()
+
+        log_string(f"generator_error {nb_errors} / {nb_total}")
+
+        def readable(prologs):
+            return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)
+
+        for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)):
+            sa = "prolog " + " ".join(
+                [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)]
+            )
+            sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa])
+            comments.append(sa + "\n" + sp)
+
+        filename = f"generator_batch_{n_epoch:04d}.png"
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir, filename, c_quizzes, comments=comments
+        )
+        log_string(f"wrote {filename}")
+
+    exit(0)
+
+######################################################################
+
+for n_epoch in range(current_epoch, args.nb_epochs):
+    state = {"current_epoch": n_epoch}
+    filename = "state.pth"
+    torch.save(state, os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
+
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
     cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
@@ -609,7 +1306,7 @@ for n_epoch in range(args.nb_epochs):
     # re-compute the test errors
 
     if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        create_c_quizzes(
+        record_new_c_quizzes(
             models,
             quiz_machine,
             nb_for_train=args.nb_new_c_quizzes_for_train,
@@ -625,7 +1322,7 @@ for n_epoch in range(args.nb_epochs):
             model.main_test_accuracy = 0.0
 
     ##################################################
-    # Select, improve, and eval the worst model
+    # Select, improve, and eval the worst model(s)
 
     ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
 
@@ -652,15 +1349,28 @@ for n_epoch in range(args.nb_epochs):
     for model in weakest_models:
         filename = f"gpt_{model.id:03d}.pth"
         torch.save(
-            (model.state_dict(), model.main_test_accuracy),
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "main_test_accuracy": model.main_test_accuracy,
+            },
             os.path.join(args.result_dir, filename),
         )
         log_string(f"wrote {filename}")
 
+    for model in weakest_models:
+        save_additional_results(model, models, science_w_quizzes)
+
+    ######################################################################
+
     # Renew the training samples
 
     for model in weakest_models:
-        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
+        quiz_machine.renew_train_w_quizzes(model=model)
 
+    if args.log_command is not None:
+        s = args.log_command.split()
+        s.insert(1, args.result_dir)
+        subprocess.run(s)
 
 ######################################################################
index d0fda7e..041d28c 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -19,6 +19,45 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
+class BSQ(nn.Module):
+    def __init__(self, L):
+        super().__init__()
+        self.L = L
+
+    def forward(self, input, indexes=False):
+        norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
+        u = input / norm
+
+        if indexes:
+            return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1)
+
+        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
+        if self.training:
+            self.loss += u.mean(dim=0).tanh().pow(2).mean()
+            return hat_u + u - u.detach()
+        else:
+            return hat_u
+
+
+class RandomBypass(nn.Module):
+    def __init__(self, m, p):
+        super().__init__()
+        self.m = m
+        self.p = p
+
+    def forward(self, x):
+        y = self.m(x)
+
+        if self.training:
+            u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
+            return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
+        else:
+            return y
+
+
+######################################################################
+
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
@@ -114,6 +153,30 @@ class AddPositionalEncoding(nn.Module):
 ##############################
 
 
+class EncoderHead(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, bs):
+        z = self.fc(bs.x).mean(dim=1)
+        return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, z_shape):
+        z, shape = z_shape
+        y = self.fc(z)[:, None, :].expand(shape)
+        return BracketedSequence(y)
+
+
+##############################
+
+
 class QKVAttention(nn.Module):
     def __init__(
         self,
@@ -121,7 +184,7 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
-        causal=False,
+        compute_attzero=None,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -129,7 +192,7 @@ class QKVAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.causal = causal
+        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -141,10 +204,6 @@ class QKVAttention(nn.Module):
     def forward(self, bs_q):
         x_q = bs_q.x
 
-        assert (
-            self.causal or bs_q.complete()
-        ), "Partial evaluation is only possible for causal models"
-
         if bs_q.first == 0:
             self.cache_k = x_q.new_zeros(
                 x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
@@ -169,12 +228,12 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.causal:
+        if self.compute_attzero is not None:
             if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
+                self.cache_attzero = self.compute_attzero(
+                    torch.arange(x_q.size(1), device=q.device)[:, None],
+                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
                     :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@ -202,22 +261,19 @@ class QKVAttention(nn.Module):
 
 
 class NoiseInjector(nn.Module):
-    def __init__(self):
+    def __init__(self, identifier=None):
         super().__init__()
         self.noise_std = 0.0
+        self.identifier = identifier
 
     def forward(self, x):
         if self.noise_std > 0:
-            x = x + torch.randn(x.size(), device=x.device) * self.noise_std
+            x = x * (
+                1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
+            )
         return x
 
 
-def set_noise_injection(model, noise_std):
-    for m in model.modules():
-        if isinstance(m, NoiseInjector):
-            m.noise_std = noise_std
-
-
 ##############################
 
 
@@ -230,7 +286,8 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        causal=False,
+        compute_attzero=None,
+        autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -238,6 +295,8 @@ class MyGPT(nn.Module):
 
         assert dim_model % nb_heads == 0
 
+        self.temperature = 1.0
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
             AddPositionalEncoding(len_max),
@@ -250,21 +309,21 @@ class MyGPT(nn.Module):
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
-                        NoiseInjector(),
+                        NoiseInjector(identifier=("attention", b)),
                     ),
                     QKVAttention(
                         dim_in=dim_model,
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        causal=causal,
+                        compute_attzero=compute_attzero,
                         attention_dropout=dropout,
                     ),
                 ),
                 WithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
-                        NoiseInjector(),
+                        NoiseInjector(identifier=("ffw", b)),
                         nn.Linear(in_features=dim_model, out_features=dim_hidden),
                         nn.ReLU(),
                         nn.Linear(in_features=dim_hidden, out_features=dim_model),
@@ -279,6 +338,26 @@ class MyGPT(nn.Module):
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
+        # -------------------------------------------------------
+        if autoencoder_dim > 0:
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, autoencoder_dim)]
+                )
+            )
+
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(autoencoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+        # -------------------------------------------------------
+
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
@@ -288,13 +367,59 @@ class MyGPT(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
-        # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+        for m in self.modules():
+            m.loss = 0
+
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
+        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
+
+        for m in self.modules():
+            self.loss += m.loss
+
+        return bs
+
+    def encode(self, bs):
+        bs = self.embedding(bs)
+        z = self.encoder(bs)
+        return z
+
+    def decode(self, z_shape):
+        bs = self.decoder(z_shape)
+        bs = self.readout(bs)
         return bs
 
+    def partial_forward(self, bs, start_layer=None, end_layer=None):
+        if start_layer is None:
+            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
+            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+            bs = self.embedding(bs)
+            if end_layer is not None:
+                return self.trunk[:end_layer](bs)
+            else:
+                bs = self.trunk(bs)
+                bs = self.readout(bs)
+                return bs
+        else:
+            bs = self.trunk[start_layer:](bs)
+            bs = self.trunk(bs)
+            bs = self.readout(bs)
+            return bs
+
+    def reset_transformations(self):
+        self.temperature = 1.0
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                m.noise_std = 0.0
+
+    def set_noise_injection(self, noise_std, identifier=None):
+        for m in self.modules():
+            if isinstance(m, NoiseInjector):
+                if identifier is None or identifier == m.identifier:
+                    m.noise_std = noise_std
+
     def record_attention(self, v=True):
         for m in self.modules():
             if isinstance(m, QKVAttention):
@@ -324,7 +449,6 @@ if __name__ == "__main__":
         nb_heads=2,
         nb_blocks=2,
         dropout=0.1,
-        causal=True,
     )
 
     model.eval()
index 05f3b20..9bee5b2 100755 (executable)
@@ -25,14 +25,59 @@ class Problem:
         else:
             return self.queue.qsize() * self.chunk_size
 
-    def nb_token_values(self):
-        pass
+    def fill_cache(self):
+        while True:
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
+
+    def generate_w_quizzes(self, nb, progress_bar=True):
+        if self.queue is None:
+            return self.generate_w_quizzes_(nb)
+
+        if self.rest is not None:
+            quizzes = rest
+        else:
+            quizzes = []
+
+        self.rest = None
+
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb,
+                dynamic_ncols=True,
+                desc="world generation",
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
+            while n < nb:
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
+
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            rest = quizzes[-k:]
+            quizzes = quizzes[:-k]
+
+        return quizzes
+
+    ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
     # The one to implement, returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers_(self, nb):
+    def generate_w_quizzes_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
@@ -47,47 +92,7 @@ class Problem:
     ):
         pass
 
-    def fill_cache(self):
-        while True:
-            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
-
-            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
-
-    def generate_prompts_and_answers(self, nb):
-        if self.queue is None:
-            return self.generate_prompts_and_answers_(nb)
-
-        if self.rest is not None:
-            prompts, answers = rest
-        else:
-            prompts, answers = [], []
-
-        self.rest = None
-
-        n = sum([p.size(0) for p in prompts])
-
-        with tqdm.tqdm(
-            total=nb,
-            dynamic_ncols=True,
-            desc="world generation",
-        ) as pbar:
-            while n < nb:
-                p, s = self.queue.get(block=True)
-                prompts.append(p)
-                answers.append(s)
-                n += p.size(0)
-                pbar.update(p.size(0))
-
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-        assert n == prompts.size(0)
-
-        k = n - nb
-
-        if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
-
-        return prompts, answers
-
     def save_some_examples(self, result_dir):
         pass
+
+    ######################################################################
index bc468d3..92da03d 100755 (executable)
@@ -17,36 +17,6 @@ from mygpt import BracketedSequence
 
 import threading
 
-######################################################################
-# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
-# | X != Y)
-
-
-# output is NxCxT and target is NxT
-def confusion(output, target, reduction="mean"):
-    N, C, T = output.shape
-    output = output.permute(0, 2, 1).reshape(-1, C)
-    target = target.flatten()
-    all_t = torch.arange(N * T, device=output.device)
-    output = output.log_softmax(dim=-1)
-    result = -output[all_t, target]
-
-    output[all_t, target] = float("-inf")
-    output = output.log_softmax(dim=-1)
-    e = output.exp()
-    output[all_t, target] = 0
-    result = result - (output * e).sum(-1)
-
-    if reduction == "none":
-        return result.reshape(N, T)
-    elif reduction == "mean":
-        return result.reshape(N, T).mean()
-    elif reduction == "sum":
-        return result.reshape(N, T).sum()
-    else:
-        raise ValueError(f"unknown reduction '{reduction}'.")
-
-
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -59,9 +29,11 @@ def one_batch_masked_inplace_autoregression(
     input,
     ar_mask,
     seq_logproba,
-    temperature,
-    deterministic_synthesis,
+    deterministic_synthesis=False,
 ):
+    if input.size(0) == 0:
+        return
+
     to_generate = (ar_mask.sum(0) > 0).nonzero()
 
     if to_generate.min() > 0:
@@ -73,8 +45,6 @@ def one_batch_masked_inplace_autoregression(
 
         logits = output[:, s]
 
-        logits = (logits / temperature).log_softmax(dim=-1)
-
         if deterministic_synthesis:
             t_next = logits.argmax(-1)
         else:
@@ -88,229 +58,99 @@ def one_batch_masked_inplace_autoregression(
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
 
-def masked_inplace_autoregression(
-    model,
-    batch_size,
-    input,
-    ar_mask,
-    seq_logproba,
-    temperature,
-    deterministic_synthesis,
-    forbidden_tokens=None,
-    logit_biases=None,
-    progress_bar_desc=None,
-    device=torch.device("cpu"),
-):
-    assert input.size() == ar_mask.size()
-
-    batches = zip(
-        input.split(batch_size),
-        ar_mask.split(batch_size),
-        seq_logproba.split(batch_size),
-    )
-
-    if progress_bar_desc is not None:
-        batches = tqdm.tqdm(
-            batches,
-            dynamic_ncols=True,
-            desc=progress_bar_desc,
-            total=(input.size(0) + batch_size - 1) // batch_size,
-        )
-
-    with torch.autograd.no_grad():
-        t = model.training
-        model.eval()
-
-        for input, ar_mask, seq_logproba in batches:
-            one_batch_masked_inplace_autoregression(
-                model=model,
-                input=input,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=deterministic_synthesis,
-            )
-
-        model.train(t)
-
-
 ######################################################################
 
 
 class QuizMachine:
-    def indices_forward_and_backward(self, quizzes):
-        i_forward = quizzes[:, 0] == self.token_forward
-        j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
-        i_backward = quizzes[:, 0] == self.token_backward
-        j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
-        assert torch.logical_or(
-            torch.logical_and(i_forward, j_forward),
-            torch.logical_and(i_backward, j_backward),
-        ).all()
-        return i_forward, i_backward
-
-    def non_trivial(self, quizzes):
-        quizzes = quizzes.clone()
-        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
-        n_backward = quizzes[:, 0] == self.token_backward
-        backward = quizzes[n_backward]
-        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-        return torch.logical_not(
-            self.problem.trivial_prompts_and_answers(
-                quizzes[:, 1 : 1 + self.prompt_len],
-                quizzes[:, 2 + self.prompt_len :],
-            )
-        )
-
-    def reverse_time(self, quizzes):
-        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
-        forward_to_backward = torch.cat(
-            [
-                quizzes[:, 0:1],
-                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
-                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
-                quizzes[:, 1 : 1 + self.prompt_len],
-            ],
-            dim=1,
-        )
-
-        forward_to_backward[:, 0] = self.token_backward
-        forward_to_backward[:, 1 + self.answer_len] = self.token_backward
-
-        backward_to_forward = torch.cat(
-            [
-                quizzes[:, 0:1],
-                quizzes[:, 2 + self.answer_len :],
-                quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
-                quizzes[:, 1 : 1 + self.answer_len],
-            ],
-            dim=1,
-        )
-
-        backward_to_forward[:, 0] = self.token_forward
-        backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
-
-        m = i_forward.long()[:, None]
-
-        return m * forward_to_backward + (1 - m) * backward_to_forward
-
-    def reverse_random_half_in_place(self, quizzes):
-        i = torch.rand(quizzes.size(0)) < 0.5
-        if i.any():
-            quizzes[i] = self.reverse_time(quizzes[i])
-
-    def make_ar_mask(self, quizzes, first=False):
-        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
-        t = torch.arange(quizzes.size(1), device=quizzes.device)
-
-        if first:
-            m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
-            m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
-        else:
-            m_forward = (t >= 2 + self.prompt_len).long()
-            m_backward = (t >= 2 + self.answer_len).long()
-
-        m = i_forward.long()[:, None]
-
-        return m * m_forward + (1 - m) * m_backward
-
-    def generate_token_sequences(self, nb):
-        prompts, answers = self.problem.generate_prompts_and_answers(nb)
-
-        if self.prompt_len is None:
-            self.prompt_len = prompts.size(1)
-
-        if self.answer_len is None:
-            self.answer_len = answers.size(1)
-
-        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
-
-        result = []
-
-        for prompt, answer in zip(prompts, answers):
-            a = [
-                torch.tensor([self.token_forward]),
-                prompt,
-                torch.tensor([self.token_forward]),
-                answer,
-            ]
-
-            result.append(torch.cat(a, dim=0)[None, :])
-
-        return torch.cat(result, dim=0)
-
     def __init__(
         self,
         problem,
-        nb_train_samples,
-        nb_test_samples,
-        back_accuracy,
         batch_size,
         result_dir,
+        prompt_noise,
         logger,
         device=torch.device("cpu"),
     ):
         super().__init__()
 
-        v = problem.nb_token_values()
-        self.token_forward = v
-        self.token_backward = v + 1
-        self.nb_token_values = v + 2
-
         self.problem = problem
-        self.back_accuracy = back_accuracy
         self.batch_size = batch_size
         self.device = device
         self.logger = logger
         self.prompt_len = None
         self.answer_len = None
+        self.prompt_noise = prompt_noise
+
+        # struct, mask_generate, mask_noise, mask_loss
+        self.train_structures = [
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+            # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+            # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+        ]
+
+        self.test_structures = self.train_structures
 
         self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-    def save_quiz_illustrations(
+    def vocabulary_size(self):
+        return self.problem.nb_token_values
+
+    ######################################################################
+
+    def autoregression(
         self,
-        result_dir,
-        filename_prefix,
-        quizzes,
-        mistakes=None,
+        model,
+        input,
+        ar_mask,
+        seq_logproba=None,
+        progress_bar_desc=None,
     ):
-        quizzes = quizzes.clone().to("cpu")
-        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
-        n_backward = quizzes[:, 0] == self.token_backward
-        backward = quizzes[n_backward]
-        assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
-        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-
-        predicted_prompts = n_backward.long()
-        predicted_answers = 1 - predicted_prompts
-        if mistakes is not None:
-            # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes.to("cpu")
-            predicted_answers *= mistakes.to("cpu")
-        else:
-            # 0/2 ~ not-to-predict / to predict
-            predicted_prompts *= 2
-            predicted_answers *= 2
+        assert input.size() == ar_mask.size()
 
-        self.problem.save_quiz_illustrations(
-            result_dir,
-            filename_prefix,
-            quizzes[:, 1 : 1 + self.prompt_len],
-            quizzes[:, 2 + self.prompt_len :],
-            predicted_prompts,
-            predicted_answers,
+        if seq_logproba is None:
+            seq_logproba = torch.empty(input.size(0), device=self.device)
+
+        batches = zip(
+            input.split(self.batch_size),
+            ar_mask.split(self.batch_size),
+            seq_logproba.split(self.batch_size),
         )
 
-    def vocabulary_size(self):
-        return self.nb_token_values
+        if progress_bar_desc is not None:
+            batches = tqdm.tqdm(
+                batches,
+                dynamic_ncols=True,
+                desc=progress_bar_desc,
+                total=(input.size(0) + self.batch_size - 1) // self.batch_size,
+            )
+
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            for input, ar_mask, seq_logproba in batches:
+                one_batch_masked_inplace_autoregression(
+                    model=model,
+                    input=input,
+                    ar_mask=ar_mask,
+                    seq_logproba=seq_logproba,
+                    deterministic_synthesis=False,
+                )
+
+            model.train(t)
 
     ######################################################################
 
-    def batches(self, model, split="train", desc=None):
+    def data_input(self, model, split="train"):
         assert split in {"train", "test"}
 
         with self.LOCK_C_QUIZZES:
@@ -323,6 +163,7 @@ class QuizMachine:
 
             if len(c_quizzes) > 0:
                 c_quizzes = torch.cat(c_quizzes, dim=0)
+
                 if c_quizzes.size(0) > w_quizzes.size(0) // 2:
                     i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
                     c_quizzes = c_quizzes[i]
@@ -332,106 +173,154 @@ class QuizMachine:
                 ]
                 w_quizzes = w_quizzes[i]
 
-                self.nb_batch_w_quizzes = w_quizzes.size(0)
-                self.nb_batch_c_quizzes = c_quizzes.size(0)
+                quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+                from_w = torch.arange(
+                    quizzes.size(0), device=quizzes.device
+                ) < w_quizzes.size(0)
 
-                input = torch.cat([w_quizzes, c_quizzes], dim=0)
             else:
-                input = w_quizzes
-                self.nb_batch_w_quizzes = w_quizzes.size(0)
-                self.nb_batch_c_quizzes = 0
+                quizzes = w_quizzes.clone()
+                from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
+
+        i = torch.randperm(quizzes.size(0), device=quizzes.device)
+        quizzes, from_w = quizzes[i], from_w[i]
 
-        # Shuffle
-        input = input[torch.randperm(input.size(0))]
+        self.randomize_configuations_inplace(
+            quizzes, structs=[s for s, _, _, _ in self.train_structures]
+        )
 
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
+        quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
+
+        if self.prompt_noise > 0.0:
+            for struct, _, mask_noise, mask_loss in self.train_structures:
+                i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+                if i.any():
+                    quizzes[i] = self.problem.inject_noise(
+                        quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
+                    )
+                    quiz_mask_loss[i] = self.make_quiz_mask(
+                        quizzes=quizzes[i], struct=struct, mask=mask_loss
+                    )
+
+        return quizzes, quiz_mask_loss
 
     ######################################################################
 
-    def produce_results(
-        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
-    ):
-        def compute_accuracy(input, log_prefix=None):
-            input = input.to(self.device)
-            ar_mask = self.make_ar_mask(input)
-            result = input.clone() * (1 - ar_mask)
-            seq_logproba = torch.empty(input.size(0), device=self.device)
+    def make_quiz_mask(self, quizzes, struct, mask):
+        assert struct in [s for s, _, _, _ in self.train_structures]
+        return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
 
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
+    ######################################################################
 
-            correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
+    def predict(self, model, quizzes, struct, mask):
+        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
+        result = quizzes * (1 - ar_mask)
 
-            n_forward = input[:, 0] == self.token_forward
-            n_backward = input[:, 0] == self.token_backward
+        seq_logproba = torch.empty(quizzes.size(0), device=self.device)
 
-            correct[n_forward] = (
-                (input[n_forward] == result[n_forward]).long().min(dim=1).values
-            )
+        self.autoregression(
+            model=model,
+            input=result,
+            ar_mask=ar_mask,
+            seq_logproba=seq_logproba,
+            progress_bar_desc="accuracy",
+        )
 
-            if self.back_accuracy and n_backward.any():
-                # accuracy of B->A*->B*=B instead of B->A*=A
-                back_input = self.reverse_time(result[n_backward])
-                back_input[:, 2 + self.prompt_len :] = input[
-                    n_backward, 1 : 1 + self.answer_len
-                ]
-                _, correct[n_backward] = compute_accuracy(back_input)
+        correct = (result == quizzes).min(dim=1).values.long()
 
-            if log_prefix is not None:
-                forward_nb_correct = correct[n_forward].sum()
-                forward_nb_total = correct[n_forward].size(0)
-                backward_nb_correct = correct[n_backward].sum()
-                backward_nb_total = correct[n_backward].size(0)
+        return result, correct
 
-                self.logger(
-                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
-                )
+    ######################################################################
+
+    def produce_results(self, n_epoch, model, input, result_dir):
+        input = input.to(self.device)
+        result = input.new(input.size())
+        correct = input.new(input.size(0))
+        predicted_parts = input.new(input.size(0), 4)
 
-            return result, correct
+        nb = 0
 
-        # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
+        # We consider all the configurations that we train for
+        for struct, mask_generate, _, _ in self.test_structures:
+            i = self.problem.indices_select(quizzes=input, struct=struct)
+            nb += i.long().sum()
+            result[i], correct[i] = self.predict(
+                model=model, quizzes=input[i], struct=struct, mask=mask_generate
+            )
+            predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
+                None, :
+            ]
+            solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
+            correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
+
+        assert nb == input.size(0)
 
-        test_result, test_correct = compute_accuracy(
-            model.test_w_quizzes[:nmax], log_prefix="test"
+        nb_correct = (correct == 1).long().sum()
+        nb_total = (correct != 0).long().sum()
+        self.logger(
+            f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
         )
 
-        main_test_accuracy = test_correct.sum() / test_correct.size(0)
-        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
+        main_test_accuracy = nb_correct / nb_total
 
         ##############################
 
-        self.save_quiz_illustrations(
+        correct_parts = predicted_parts * correct[:, None]
+
+        result = result[:128]
+        predicted_parts = predicted_parts[:128]
+        correct_parts = correct_parts[:128]
+
+        self.problem.save_quizzes_as_image(
             result_dir,
-            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            quizzes=test_result[:72],
-            mistakes=test_correct[:72] * 2 - 1,
+            f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
+            quizzes=result,
+            predicted_parts=predicted_parts,
+            correct_parts=correct_parts,
         )
 
         return main_test_accuracy
 
     ######################################################################
 
-    def renew_w_quizzes(self, model, nb, for_train=True):
-        input = model.train_w_quizzes if for_train else model.test_w_quizzes
-        nb = min(nb, input.size(0))
-        input[:-nb] = input[nb:].clone()
-        fresh_w_quizzes = self.generate_token_sequences(nb)
-        self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to("cpu")
+    def randomize_configuations_inplace(self, quizzes, structs):
+        r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device)
+        for c in range(len(structs)):
+            quizzes[r == c] = self.problem.reconfigure(
+                quizzes[r == c], struct=structs[c]
+            )
+
+    ######################################################################
+
+    def renew_train_w_quizzes(self, model):
+        if hasattr(model, "hard_w_quizzes"):
+            hard_w_quizzes = self.problem.reconfigure(
+                model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B")
+            )
+            self.logger(
+                f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
+            )
+            if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
+                nb_to_generate = 0
+                model.train_w_quizzes[...] = hard_w_quizzes[
+                    torch.randperm(hard_w_quizzes.size(0))[
+                        model.train_w_quizzes.size(0)
+                    ]
+                ]
+            else:
+                nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0)
+                model.train_w_quizzes[...] = torch.cat(
+                    [
+                        hard_w_quizzes,
+                        self.problem.generate_w_quizzes(nb_to_generate),
+                    ],
+                    dim=0,
+                )
+        else:
+            nb_to_generate = 0
+            model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
+                model.train_w_quizzes.size(0)
+            )
 
     ######################################################################
 
@@ -450,162 +339,92 @@ class QuizMachine:
 
     ######################################################################
 
-    def solution_token_logprobas(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(
+    def models_logprobas(
+        self,
+        models_for_validation,
+        c_quizzes,
+        struct,
+        mask_loss,
+        mask_noise=None,
+        device=None,
+    ):
+        if device is None:
+            device = self.device
+
+        c_quizzes = self.problem.reconfigure(c_quizzes, struct)
+
+        seq_logproba = torch.zeros(
             c_quizzes.size(0),
-            len(models),
-            c_quizzes.size(1),
-            device=self.device,
-            dtype=torch.float32,
+            max([m.id for m in models_for_validation]) + 1,
+            device=device,
         )
 
-        for model in models:
+        # if self.prompt_noise > 0.0 and mask_noise is not None:
+        # c_quizzes = self.problem.inject_noise(
+        # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
+        # )
+
+        for model in models_for_validation:
             with torch.autograd.no_grad():
                 t = model.training
                 model.eval()
 
                 for input, l in zip(
-                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                    c_quizzes.split(self.batch_size),
+                    seq_logproba.split(self.batch_size),
                 ):
-                    input = input.to(self.device)
-                    ar_mask = self.make_ar_mask(input)
+                    input = input.to(device)
+                    quiz_mask_loss = self.make_quiz_mask(
+                        input, struct=struct, mask=mask_loss
+                    )
                     output = model(mygpt.BracketedSequence(input)).x
                     l[:, model.id] = (
                         -F.cross_entropy(
                             output.transpose(1, 2), input, reduction="none"
                         )
-                        * ar_mask
-                    )
+                        * quiz_mask_loss
+                    ).sum(dim=1)
 
                 model.train(t)
 
-        return logproba.to("cpu")
+        return seq_logproba.to("cpu")
 
-    ###############################################################
+    ######################################################################
 
-    def compute_correctness(
-        self,
-        c_quizzes,
-        models_for_validation,
-        bidirectional_validation=False,
-        deterministic_validation=True,
-    ):
-        if bidirectional_validation:
-            backward_c_quizzes = self.forward_to_backward(c_quizzes)
+    def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
+        seq_logproba = torch.zeros(nb, device=self.device)
 
-        seq_logproba = torch.zeros(
-            c_quizzes.size(0),
-            max([m.id for m in models_for_validation]) + 1,
-            device=self.device,
-        )
+        c_quizzes = None
 
-        nb_correct = 0
+        for s, m, mt in procedure:
+            if c_quizzes is None:
+                c_quizzes = self.problem.create_empty_quizzes(nb, s)
+                c_quizzes = c_quizzes.to(self.device)
+            elif s != pred_s:
+                c_quizzes = self.problem.reconfigure(c_quizzes, s)
+            pred_s = s
 
-        seq_logproba[...] = 0.0
+            if mt is not None:
+                mt(model_for_generation)
 
-        for model in models_for_validation:
-            result = c_quizzes.clone()
-
-            ar_mask = self.make_ar_mask(result)
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba[:, model.id],
-                temperature=1.0,
-                deterministic_synthesis=deterministic_validation,
-                # progress_bar_desc="solving c_quizzes",
-                device=self.device,
+            self.autoregression(
+                model=model_for_generation,
+                input=c_quizzes,
+                ar_mask=self.make_quiz_mask(c_quizzes, s, m),
+                seq_logproba=seq_logproba,
             )
 
-            correct = (c_quizzes == result).long().min(dim=-1).values
-
-            if bidirectional_validation:
-                backward_result = backward_c_quizzes.clone()
+            model_for_generation.reset_transformations()
 
-                ar_mask = self.make_ar_mask(backward_result)
-
-                masked_inplace_autoregression(
-                    model=model,
-                    batch_size=self.batch_size,
-                    input=backward_result,
-                    ar_mask=ar_mask,
-                    seq_logproba=seq_logproba[:, model.id],
-                    temperature=1.0,
-                    deterministic_synthesis=deterministic_validation,
-                    # progress_bar_desc="solving backward c_quizzes",
-                    device=self.device,
-                )
-
-                backward_correct = (
-                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
+            if recorder is not None:
+                x = c_quizzes.clone()
+                t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1)
+                recorder.append(
+                    self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
                 )
 
-                correct *= backward_correct
-
-            # endif
-
-            nb_correct += correct
-
-        return nb_correct, seq_logproba
-
-    ###############################################################
-
-    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
-        c_quizzes = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len + 2,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        # First, we generate the answer at high temperature
-
-        c_quizzes[:, 0] = self.token_backward
-        c_quizzes[:, 1 + self.answer_len] = self.token_backward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes, first=True),
-            seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        # Then, we generate the prompt at low temperature
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
-            temperature=1 / temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        # Then we return the quizz, and re-generate the response, now
-        # at low temperature
-
-        c_quizzes = self.reverse_time(c_quizzes)
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
-            temperature=1 / temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
+        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
 
         return c_quizzes.to("cpu")
+
+    ######################################################################
diff --git a/report/culture.tex b/report/culture.tex
new file mode 100644 (file)
index 0000000..d9f39e3
--- /dev/null
@@ -0,0 +1,596 @@
+%% -*- mode: latex; mode: reftex; mode: flyspell; coding: utf-8; tex-command: "pdflatex.sh" -*-
+
+%% Any copyright is dedicated to the Public Domain.
+%% https://creativecommons.org/publicdomain/zero/1.0/
+%% Written by Francois Fleuret <francois@fleuret.org>
+
+\documentclass[11pt,a4paper,oneside]{article}
+\usepackage[paperheight=15cm,paperwidth=8cm,top=2mm,bottom=15mm,right=5mm,left=5mm]{geometry}
+%\usepackage[a4paper,top=2.5cm,bottom=2cm,left=2.5cm,right=2.5cm]{geometry}
+\usepackage[utf8]{inputenc}
+\usepackage{amsmath,amssymb,dsfont}
+\usepackage[pdftex]{graphicx}
+\usepackage[colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue]{hyperref}
+\urlstyle{same}
+\usepackage{tikz}
+\usetikzlibrary{arrows,arrows.meta,calc}
+\usetikzlibrary{patterns,backgrounds}
+\usetikzlibrary{positioning,fit}
+\usetikzlibrary{shapes.geometric,shapes.multipart}
+\usetikzlibrary{patterns.meta,decorations.pathreplacing,calligraphy}
+\usetikzlibrary{tikzmark}
+\usetikzlibrary{decorations.pathmorphing}
+\usepackage[round]{natbib}
+\usepackage[osf]{libertine}
+\usepackage{microtype}
+
+\usepackage{mleftright}
+
+\usepackage{enumitem}
+\setlist[itemize]{leftmargin=0pt,itemindent=1em,itemsep=2ex}
+\setlist{nosep} % or \setlist{noitemsep} to leave space around whole list
+
+\newcommand{\setmuskip}[2]{#1=#2\relax}
+\setmuskip{\thinmuskip}{1.5mu} % by default it is equal to 3 mu
+\setmuskip{\medmuskip}{2mu} % by default it is equal to 4 mu
+\setmuskip{\thickmuskip}{3.5mu} % by default it is equal to 5 mu
+
+\setlength{\parindent}{0cm}
+\setlength{\parskip}{1ex}
+%\renewcommand{\baselinestretch}{1.3}
+%\setlength{\tabcolsep}{0pt}
+%\renewcommand{\arraystretch}{1.0}
+
+\def\argmax{\operatornamewithlimits{argmax}}
+\def\argmin{\operatornamewithlimits{argmin}}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\def\given{\,\middle\vert\,}
+\def\proba{\operatorname{P}}
+\newcommand{\seq}{{S}}
+\newcommand{\expect}{\mathds{E}}
+\newcommand{\variance}{\mathds{V}}
+\newcommand{\empexpect}{\hat{\mathds{E}}}
+\newcommand{\mutinf}{\mathds{I}}
+\newcommand{\empmutinf}{\hat{\mathds{I}}}
+\newcommand{\entropy}{\mathds{H}}
+\newcommand{\empentropy}{\hat{\mathds{H}}}
+\newcommand{\ganG}{\mathbf{G}}
+\newcommand{\ganD}{\mathbf{D}}
+\newcommand{\ganF}{\mathbf{F}}
+
+\newcommand{\dkl}{\mathds{D}_{\mathsf{KL}}}
+\newcommand{\djs}{\mathds{D}_{\mathsf{JS}}}
+
+\newcommand*{\vertbar}{\rule[-1ex]{0.5pt}{2.5ex}}
+\newcommand*{\horzbar}{\rule[.5ex]{2.5ex}{0.5pt}}
+
+\def\positionalencoding{\operatorname{pos-enc}}
+\def\concat{\operatorname{concat}}
+\def\crossentropy{\LL_{\operatorname{ce}}}
+
+\newcommand{\separator}{\begin{center}
+*
+\end{center}}
+
+\newcommand{\pic}[2]{%
+\hspace*{\stretch{1}}
+%
+\includegraphics[scale=0.25]{#1}
+%
+\hspace*{\stretch{1}}%
+}
+
+\newcommand{\birdpic}[2]{%
+\hspace*{\stretch{1}}
+%
+\includegraphics[scale=0.35]{#1}
+%
+\hspace*{\stretch{1}}%
+}
+
+\newenvironment{example}{%
+
+\vspace*{2ex}
+
+\begin{minipage}{\textwidth}
+
+\setlength{\parindent}{0cm}
+\setlength{\parskip}{1ex}
+}{%
+\end{minipage}
+}
+
+\begin{document}
+
+\vspace*{-3ex}
+
+\begin{center}
+
+{\Large Self-Generated Culture}
+
+Fran\c cois Fleuret
+
+\today
+
+\vspace*{2ex}
+
+\centerline{\color{red}(work in progress, to be updated)}
+
+\medskip
+
+\centerline{\url{https://fleuret.org/public/culture/culture.pdf}}
+
+\end{center}
+
+\section{Introduction}
+
+The hypothesis behind this experiment is that high-level abstract
+thinking is fueled by social competition.
+
+A group of communicating agents that try to demonstrate their
+cognitive superiority would end up developing a rich and consistent
+culture.
+
+\subsection{Setup}
+
+The experiment is designed with a group of GPTs that alternatively
+learn to solve quizzes and generate new ones.
+
+A ``quiz'' is a pair composed of a prompt and a solution, both being
+sequence of tokens.
+
+We differentiate \textbf{world quizzes} that follow pre-defined and
+fixed regularities, and mimic the world's physical and environmental
+patterns that an organism has to grasp to survive, and \textbf{culture
+  quizzes} that are generated by the GPTs, and mimic the knowledge one
+has to master to perform socially.
+
+
+We train five GPTs on a a very large set of ``world quizzes''
+generated randomly. These models are trained to generate both the
+solution given the prompt, and the prompt given the solution.
+
+This is achieved by using for training both ``forward sequences'',
+composed of a token \texttt{[fwd]}, followed by the prompt's tokens,
+followed by another token \texttt{[fwd]}, followed by the solution's
+tokens, or ``backward sequences'' composed of a token \texttt{[bck]},
+followed by the solution's tokens, followed by another token
+\texttt{[bck]}, followed by the prompt's tokens,
+
+\subsection{Generating Culture Quizzes}
+
+When their accuracy get above $95\%$ we generate new quizzes as follows:
+%
+\begin{enumerate}
+
+\item generate a solution (without conditioning) at temperature $T=2$,
+  then generate a prompt for that solution at temperature $T=1/2$, and
+  then generate a solution for that prompt at temperature $T=1/2$.
+
+\item generate one solution for that prompt with each of the $5$ GPTs
+  at temperature $T=1$, if $4$ of them generate the correct solution,
+  validate that quiz and include it in the training data.
+
+\end{enumerate}
+
+This criterion assures that the new quizzes are both solvable and
+sophisticated, and incrementally complexify the culture. Imposing both
+direction prevents the generation of quizzes which are not trivial
+only because the prompt has been randomly degraded.
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Grid Quizzes}
+
+\subsection{World Quizzes}
+
+We define several types of quizzes and implement algorithmic
+procedures to generate randomly as many examples from each that we
+need.
+
+In these quizzes, the prompt is made of three grids $A, f(A), B$ and
+the solution is a single grid $f(B)$.
+
+\subsubsection{Half Fill}
+
+\pic{pics/task_color_grow.png}{``half fill''}
+
+The first grid contains three rectangles, each with a vertical or an
+horizontal line of another color in its middle. The second grid is
+identical with one of the rectangle having one half filled. The third
+grid contains three rectangles of identical colors as the firs grid,
+of different size and locations. The solution is obtained by filling
+similarly one of the half of a rectangle of the third image.
+
+\subsubsection{Detect}
+
+\pic{pics/task_detect.png}{``detect''}
+
+The first grid contains three rectangles, the second has two pixels of
+same colors located in the top-left corner of two of them. The
+solution is obtained by marking in the fourth image the top-left
+corners of the rectangles of same colors in the third.
+
+\subsubsection{Frame}
+
+\pic{pics/task_frame.png}{``frame''}
+
+The first grid contains three rectangles, and the second is identical
+except that one rectangle has been replaced by its frame. The same
+should be done to the similarly colored rectangles of the third grid
+to obtain the solution.
+
+\subsubsection{Grow}
+
+\pic{pics/task_grow.png}{``grow''}
+
+The first grid contains three rectangles, one of them getting one
+pixel thicker or thinner in the second. The same should be done to the
+similarly colored rectangles of the third grid to get the solution.
+
+\subsubsection{Replace color}
+
+\pic{pics/task_replace_color.png}{``replace color''}
+
+The first grid contains three rectangles, the second is obtained by
+changing one of the colors. The same should be done to the third grid
+to obtain the solution.
+
+\subsubsection{Translate}
+
+\pic{pics/task_translate.png}{``translate''}
+
+The first grid contains three rectangles. The second is obtained by
+displacing one of them by one pixel in both direction. The solution is
+obtained by applying the same motion to the similarly colored
+rectangle in the third grid.
+
+%% \subsubsection{Bounce}
+
+%% \pic{pics/task_bounce.png}{``bounce''}
+
+%% The solution should join the two pixels of same color, with a path of
+%% another color, starting in the direction indicated by a pixel of that
+%% color, and changing direction only when colliding with a pixel of a
+%% third color or one of the lattice border.
+
+%% \subsubsection{count}
+
+%% \pic{pics/task_count.png}{``count''}
+
+%% \subsubsection{scale}
+
+%% \pic{pics/task_scale.png}{``scale''}
+
+%% \subsubsection{trajectory}
+
+%% \pic{pics/task_trajectory.png}{``trajectory''}
+
+\subsection{Culture Quizzes}
+
+We list here some generated quizzes that exhibit features that were not present in the ``world quizzes'' used for training.
+
+\bigskip
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0110_N4_validated/quiz_63.png}{0110/63}
+
+\pic{pics/culture_c_quiz_0115_N4_validated/quiz_37.png}{0115/37}
+
+The quizzes ``frame'' and ``half fill'' have been combined in a single
+quiz.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0120_N4_validated/quiz_05.png}{0110/05}
+
+The ``frame'' quiz has been generalized to non-rectangular shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_01.png}{0078/01}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_02.png}{0078/02}
+
+More rectangles were added as distractors.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0087_N4_validated/quiz_62.png}{0087/62}
+
+\pic{pics/culture_c_quiz_0102_N4_validated/quiz_04.png}{0102/04}
+
+\pic{pics/culture_c_quiz_0102_N4_validated/quiz_11.png}{0102/11}
+
+\pic{pics/culture_c_quiz_0108_N4_validated/quiz_31.png}{0108/31}
+
+Variation of ``Detect'' with location markers colored according to the
+color of the rectangle they mark.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_16.png}{0078/16}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_21.png}{0084/21}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_42.png}{0078/42}
+
+\pic{pics/culture_c_quiz_0089_N4_validated/quiz_28.png}{0089/28}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_00.png}{0084/00}
+
+Variations of ``Half Fill'', ``Detect'', ``Translate'', ``Grow'', and
+``Frame'' with a number of rectangles not equal to three.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_27.png}{0078/27}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_18.png}{0078/18}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_45.png}{0086/45}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_37.png}{0078/37}
+
+Variations of ``Half Fill'' where the shapes to change have more
+complex coloring.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_30.png}{0078/30}
+
+Variation of ``Translate'' where the moving part is occluded, which
+was never the case.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_31.png}{0078/31}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_10.png}{0084/10}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_12.png}{0084/12}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_23.png}{0086/23}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_28.png}{0086/28}
+
+Variations of ``Half Fill'' with non-rectangular shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_60.png}{0078/60}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_41.png}{0084/41}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_49.png}{0084/49}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_04.png}{0086/04}
+
+Variations of ``Half Fill'' with two colors or two rectangles have to
+be modified.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0111_N4_validated/quiz_23.png}{0111/23}
+
+Variation of ``Frame'' with no rectangle of adequate size to be
+modified.
+
+\end{example}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Bird World}
+
+These results were obtained with a slightly different procedure. In
+particular the quizzes were validated if the models could predict both
+the solution from the prompt and the prompt from the solution. We
+report them since they exhibit the same patterns of generalization
+although they are quite different.
+
+\subsection{World Quizzes}
+
+The initial set of quizzes consist of predicting the dynamics of a
+very simple world: A $6 \times 8$ grid with three colored ``birds'' moving in
+a straight line, possibly bouncing on the grid's borders. There are
+ten different colors.
+%
+\birdpic{pics/examples_train.png}{}
+%
+
+In each on these quizzes, $A$ is the left image serialized in
+raster-scan order as a sequence of $6 \times 8 = 48$ tokens, $d$ is
+either the token ``forward'' or the token ``backward'', and $B$ is the
+right image, also serialized. The direction of prediction is chosen at
+random.
+
+\subsection{Culture quizzes}
+
+This procedure results in the discovery of patterns which are not
+present in the original quizzes:
+
+\begin{example}
+
+\birdpic{pics/4_birds_1.png}{}
+
+\birdpic{pics/5_birds_1.png}{}
+
+\birdpic{pics/6_birds_1.png}{}
+
+More birds.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\birdpic{pics/other_shapes_2.png}{}
+
+\birdpic{pics/other_shapes_3.png}{}
+
+New bird shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\birdpic{pics/other_shapes_1.png}{}
+
+\birdpic{pics/occlusions_1.png}{}
+
+Occlusions.
+
+\end{example}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Various thoughts}
+
+\begin{itemize}
+
+\item The whole process can be envisioned as natural selection of
+  quizzes in the representation landscape of GPTs. There probably is a
+  subtle relation between the temperature (mutation rate) and the
+  number of models used to validate with the ``all but one'' criterion
+  (survival criterion).
+
+\item The ``all but one'' could be ``all but K'', and there may be
+  some information-theoretical thing, where the goal is to maximize
+  mutual information, with $K=N$ being total randomness, so high
+  entropy but no structure, and $K=0$ is total determinism, so no
+  information to share.
+
+\item The setup does not push toward any specific invariance or
+  property in the generated quizzes, their consistency is entirely due
+  to the statistics of the ``world quizzes'' that remain in the
+  training set, and to the GPTs' inductive biased.
+
+\item The GPTs obviously get a sense of objectness and 2d topology
+  early on, since they rapidly increase the number of birds and
+  ``discover'' occlusion even though they never was in the world
+  quizzes.
+
+\item There may not be so many problems that can be cast as pairs of
+  patterns that are each a deterministic function of the other, which
+  is probably critical here.
+
+\item This overall process probably fight the ``simplicity bias'': If
+  a model is lacking a ``cue'' that the others have, there will
+  rapidly be quizzes that require this cue, they will be added to the
+  training data, and that model will catch up.
+
+\item The randomness of the process probably allow to even go beyond
+  just synchronizing the abilities of the models. There may be some
+  additional complexification of quizzes that get accepted by chance.
+
+\item It can be parallelized by dispatching the GPTs across multiples
+  nodes, and avoiding a quadratic cost by limiting the validation of
+  the quizzes to a subset of them.
+
+\item The current process to generate new quizzes, which simply
+  samples them at random is very rudimentary and probably not
+  sufficient in a real-data setup. It can probably be supplemented
+  with a MCTS-type search.
+
+\item There may be already in the generated quizzes some structure
+  that \emph{we} do not pick up (e.g. certain color or motion
+  patterns).
+
+\end{itemize}
+
+\section*{Appendix}
+
+The code is available at
+
+\medskip
+
+\centerline{\url{https://fleuret.org/git/culture}}
+
+The experiments are done with a GTX 4090.
+
+The GPT used has 37M parameters and the following structure:
+
+\begin{center}
+\begin{tabular}{lc}
+    \texttt{dim\_model}  & 512  \\
+    \texttt{dim\_keys}   & 64   \\
+    \texttt{dim\_hidden} & 2048 \\
+    \texttt{nb\_heads}   & 8    \\
+    \texttt{nb\_blocks}  & 12
+\end{tabular}
+\end{center}
+
+Adam, $\eta = 1e-4$, no scheduling.
+
+There are $N_{\text{train}}=250'000$ original quizzes for training and
+$N_{\text{test}} = 10'000$ for test.
+
+At each epoch, for both train and test samples, we mix original
+quizzes and the generated ones.
+
+For training for instance, if there are less than $N_{\text{train}}/2$
+new quizzes, we take all of them, otherwise we sample
+$N_{\text{train}}/2$ of them without replacement, and then we sample
+without replacement enough original quizzes to get $N_{\text{train}}$
+samples in total.
+
+We proceed similarly to get $N_{\text{test}}$ samples for test.
+
+\end{document}
diff --git a/report/pics/4_birds_1.png b/report/pics/4_birds_1.png
new file mode 100644 (file)
index 0000000..961b95d
Binary files /dev/null and b/report/pics/4_birds_1.png differ
diff --git a/report/pics/5_birds_1.png b/report/pics/5_birds_1.png
new file mode 100644 (file)
index 0000000..09870c7
Binary files /dev/null and b/report/pics/5_birds_1.png differ
diff --git a/report/pics/6_birds_1.png b/report/pics/6_birds_1.png
new file mode 100644 (file)
index 0000000..3717298
Binary files /dev/null and b/report/pics/6_birds_1.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png
new file mode 100644 (file)
index 0000000..23aeac0
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png
new file mode 100644 (file)
index 0000000..d17d796
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png
new file mode 100644 (file)
index 0000000..8727132
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png
new file mode 100644 (file)
index 0000000..7eef189
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png
new file mode 100644 (file)
index 0000000..3b8eb8d
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png
new file mode 100644 (file)
index 0000000..8f93b14
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png
new file mode 100644 (file)
index 0000000..7deacbf
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png
new file mode 100644 (file)
index 0000000..762bd89
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png
new file mode 100644 (file)
index 0000000..ef92823
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png
new file mode 100644 (file)
index 0000000..4ec6305
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png
new file mode 100644 (file)
index 0000000..6272aeb
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png
new file mode 100644 (file)
index 0000000..72cd85e
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png
new file mode 100644 (file)
index 0000000..b7b386d
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png
new file mode 100644 (file)
index 0000000..a4eeb42
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png
new file mode 100644 (file)
index 0000000..3edad5a
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png
new file mode 100644 (file)
index 0000000..ee10aea
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png
new file mode 100644 (file)
index 0000000..fa153ae
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png
new file mode 100644 (file)
index 0000000..aab6285
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png
new file mode 100644 (file)
index 0000000..f0edfb2
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png
new file mode 100644 (file)
index 0000000..bc9c9d5
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png differ
diff --git a/report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png b/report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png
new file mode 100644 (file)
index 0000000..6a3c50f
Binary files /dev/null and b/report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png differ
diff --git a/report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png b/report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png
new file mode 100644 (file)
index 0000000..f460aba
Binary files /dev/null and b/report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png differ
diff --git a/report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png b/report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png
new file mode 100644 (file)
index 0000000..ccfe8cc
Binary files /dev/null and b/report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png differ
diff --git a/report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png b/report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png
new file mode 100644 (file)
index 0000000..c92f874
Binary files /dev/null and b/report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png differ
diff --git a/report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png b/report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png
new file mode 100644 (file)
index 0000000..9a0c00e
Binary files /dev/null and b/report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png differ
diff --git a/report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png b/report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png
new file mode 100644 (file)
index 0000000..00a0650
Binary files /dev/null and b/report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png differ
diff --git a/report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png b/report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png
new file mode 100644 (file)
index 0000000..3ddd703
Binary files /dev/null and b/report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png differ
diff --git a/report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png b/report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png
new file mode 100644 (file)
index 0000000..8507fa4
Binary files /dev/null and b/report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png differ
diff --git a/report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png b/report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png
new file mode 100644 (file)
index 0000000..7a27afe
Binary files /dev/null and b/report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png differ
diff --git a/report/pics/examples_train.png b/report/pics/examples_train.png
new file mode 100644 (file)
index 0000000..d1b349f
Binary files /dev/null and b/report/pics/examples_train.png differ
diff --git a/report/pics/occlusions_1.png b/report/pics/occlusions_1.png
new file mode 100644 (file)
index 0000000..28c39ba
Binary files /dev/null and b/report/pics/occlusions_1.png differ
diff --git a/report/pics/other_shapes_1.png b/report/pics/other_shapes_1.png
new file mode 100644 (file)
index 0000000..620fd45
Binary files /dev/null and b/report/pics/other_shapes_1.png differ
diff --git a/report/pics/other_shapes_2.png b/report/pics/other_shapes_2.png
new file mode 100644 (file)
index 0000000..fa1e3d4
Binary files /dev/null and b/report/pics/other_shapes_2.png differ
diff --git a/report/pics/other_shapes_3.png b/report/pics/other_shapes_3.png
new file mode 100644 (file)
index 0000000..5779ebb
Binary files /dev/null and b/report/pics/other_shapes_3.png differ
diff --git a/report/pics/task_bounce.png b/report/pics/task_bounce.png
new file mode 100644 (file)
index 0000000..d62d165
Binary files /dev/null and b/report/pics/task_bounce.png differ
diff --git a/report/pics/task_color_grow.png b/report/pics/task_color_grow.png
new file mode 100644 (file)
index 0000000..d872af2
Binary files /dev/null and b/report/pics/task_color_grow.png differ
diff --git a/report/pics/task_count.png b/report/pics/task_count.png
new file mode 100644 (file)
index 0000000..84da321
Binary files /dev/null and b/report/pics/task_count.png differ
diff --git a/report/pics/task_detect.png b/report/pics/task_detect.png
new file mode 100644 (file)
index 0000000..0beb8af
Binary files /dev/null and b/report/pics/task_detect.png differ
diff --git a/report/pics/task_frame.png b/report/pics/task_frame.png
new file mode 100644 (file)
index 0000000..8d3e015
Binary files /dev/null and b/report/pics/task_frame.png differ
diff --git a/report/pics/task_grow.png b/report/pics/task_grow.png
new file mode 100644 (file)
index 0000000..ce08e0d
Binary files /dev/null and b/report/pics/task_grow.png differ
diff --git a/report/pics/task_replace_color.png b/report/pics/task_replace_color.png
new file mode 100644 (file)
index 0000000..d6c9582
Binary files /dev/null and b/report/pics/task_replace_color.png differ
diff --git a/report/pics/task_scale.png b/report/pics/task_scale.png
new file mode 100644 (file)
index 0000000..b0e3820
Binary files /dev/null and b/report/pics/task_scale.png differ
diff --git a/report/pics/task_trajectory.png b/report/pics/task_trajectory.png
new file mode 100644 (file)
index 0000000..20b7c2b
Binary files /dev/null and b/report/pics/task_trajectory.png differ
diff --git a/report/pics/task_translate.png b/report/pics/task_translate.png
new file mode 100644 (file)
index 0000000..5f2bc2a
Binary files /dev/null and b/report/pics/task_translate.png differ
diff --git a/tasks.py b/tasks.py
new file mode 100755 (executable)
index 0000000..80ffdbb
--- /dev/null
+++ b/tasks.py
@@ -0,0 +1,374 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, os, tqdm, warnings
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+from mygpt import BracketedSequence
+
+######################################################################
+
+
+def masked_inplace_autoregression(
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    summed_logits,
+    temperature,
+    deterministic_synthesis,
+    forbidden_tokens=None,
+    logit_biases=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
+):
+    assert input.size() == ar_mask.size()
+
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+
+    if progress_bar_desc is not None:
+        batches = tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=(input.size(0) + batch_size - 1) // batch_size,
+        )
+
+    with torch.autograd.no_grad():
+        t = model.training
+        model.eval()
+
+        for input, ar_mask in batches:
+            model.masked_inplace_autoregression(
+                input=input,
+                ar_mask=ar_mask,
+                summed_logits=summed_logits,
+                temperature=temperature,
+                deterministic_synthesis=deterministic_synthesis,
+                forbidden_tokens=forbidden_tokens,
+                forced_biases=logit_biases,
+            )
+
+        model.train(t)
+
+
+######################################################################
+
+
+class Task:
+    def batches(self, split="train", nb_to_use=-1, desc=None):
+        pass
+
+    def vocabulary_size(self):
+        pass
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis
+    ):
+        pass
+
+
+######################################################################
+
+import world
+
+
+class World(Task):
+    def save_image(self, input, result_dir, filename, logger):
+        img = world.seq2img(input.to("cpu"), self.height, self.width)
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+        logger(f"wrote {image_name}")
+
+    def make_ar_mask(self, input):
+        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
+        return b.long()[None, :].expand_as(input)
+
+    def __init__(
+        self,
+        nb_train_samples,
+        nb_test_samples,
+        batch_size,
+        result_dir=None,
+        logger=None,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.device = device
+        self.height = 6
+        self.width = 8
+
+        self.train_input = world.generate_seq(
+            nb_train_samples, height=self.height, width=self.width
+        ).to(device)
+
+        self.test_input = world.generate_seq(
+            nb_test_samples, height=self.height, width=self.width
+        ).to(device)
+
+        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+        self.train_quizzes = []
+        self.test_quizzes = []
+
+        if result_dir is not None:
+            self.save_image(
+                self.train_input[:72], result_dir, f"world_train.png", logger
+            )
+
+    def batches(self, split="train", desc=None):
+        assert split in {"train", "test"}
+        if split == "train":
+            input = self.train_input
+            quizzes = self.train_quizzes
+        else:
+            input = self.test_input
+            quizzes = self.test_quizzes
+
+        if len(quizzes) > 0:
+            quizzes = torch.cat(quizzes, dim=0)
+            if quizzes.size(0) > input.size(0) // 2:
+                i = torch.randperm(input.size(0))[: input.size(0) // 2]
+                quizzes = quizzes[i]
+
+            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
+            input = input[i]
+
+            self.nb_batch_samples_world = input.size(0)
+            self.nb_batch_samples_quizzes = quizzes.size(0)
+
+            input = torch.cat([input, quizzes], dim=0)
+        else:
+            self.nb_batch_samples_world = input.size(0)
+            self.nb_batch_samples_quizzes = 0
+
+        # Shuffle
+        input = input[torch.randperm(input.size(0))]
+
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_codes
+
+    def produce_results(
+        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
+    ):
+        def compute_accuracy(input, logger=None):
+            input = input[:nmax]
+            ar_mask = self.make_ar_mask(input)
+            result = input.clone() * (1 - ar_mask)
+
+            masked_inplace_autoregression(
+                model=model,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                summed_logits=None,
+                temperature=1.0,
+                deterministic_synthesis=deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            nb_total, nb_correct = (
+                input.size(0),
+                (input == result).long().min(dim=1).values.sum(),
+            )
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
+
+        main_test_accuracy = test_nb_correct / test_nb_total
+        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
+
+        ##############################
+
+        input = self.test_input[:96]
+        ar_mask = self.make_ar_mask(input)
+        result = input.clone() * (1 - ar_mask)
+
+        masked_inplace_autoregression(
+            model=model,
+            batch_size=self.batch_size,
+            input=result,
+            ar_mask=ar_mask,
+            summed_logits=None,
+            temperature=1.0,
+            deterministic_synthesis=deterministic_synthesis,
+            progress_bar_desc=None,
+            device=self.device,
+        )
+
+        self.save_image(
+            result[:72],
+            result_dir,
+            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
+            logger,
+        )
+
+        return main_test_accuracy
+
+    def renew_samples(self, nb, for_train=True):
+        input = self.train_input if for_train else self.test_input
+        nb = min(nb, input.size(0))
+        input[:-nb] = input[nb:].clone()
+        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+            self.device
+        )
+
+    def store_new_quizzes(self, new_quizzes, for_train=True):
+        if for_train:
+            self.train_quizzes.append(new_quizzes)
+        else:
+            self.test_quizzes.append(new_quizzes)
+
+    def create_new_quizzes(
+        self,
+        n_epoch,
+        result_dir,
+        logger,
+        nb,
+        model,
+        other_models,
+        desired_average_logits=None,
+    ):
+        ###############################################################
+        # Generate quizzes with model
+
+        quizzes = torch.empty(
+            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
+        )
+
+        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+        summed_logits = torch.empty(nb, device=self.device)
+
+        temperature = 1
+        d_temperature = 1
+
+        while True:
+            summed_logits[...] = 0
+
+            masked_inplace_autoregression(
+                model=model,
+                batch_size=self.batch_size,
+                input=quizzes,
+                ar_mask=ar_mask,
+                summed_logits=summed_logits,
+                temperature=temperature,
+                deterministic_synthesis=False,
+                progress_bar_desc="creating quizzes",
+                device=self.device,
+            )
+
+            average_logits = summed_logits.mean()
+
+            logger(f"{average_logits=} {desired_average_logits=}")
+
+            if desired_average_logits is None:
+                break
+
+            # Oh man that's ugly
+            if average_logits < desired_average_logits * 1.1:
+                if d_temperature > 0:
+                    d_temperature *= -0.5
+                temperature += d_temperature
+            elif average_logits > desired_average_logits:
+                if d_temperature < 0:
+                    d_temperature *= -0.5
+                temperature += d_temperature
+            else:
+                break
+
+            logger(f"changing temperature to {temperature}")
+
+        ###############################################################
+        # Create the reverse quizzes
+
+        l = self.height * self.width
+        direction = quizzes[:, l : l + 1]
+        direction = world.token_forward * (
+            direction == world.token_backward
+        ) + world.token_backward * (direction == world.token_forward)
+        reverse_quizzes = torch.cat(
+            [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+        )
+
+        ar_mask = self.make_ar_mask(quizzes)
+
+        ###############################################################
+        # Check how many of the other models can solve them in both
+        # directions
+
+        nb_correct = []
+
+        for m in other_models:
+            result = quizzes.clone()
+
+            masked_inplace_autoregression(
+                model=m,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                summed_logits=None,
+                temperature=1.0,
+                deterministic_synthesis=True,
+                progress_bar_desc="solving quizzes",
+                device=self.device,
+            )
+
+            correct = (quizzes == result).long().min(dim=-1).values
+
+            reverse_result = reverse_quizzes.clone()
+
+            masked_inplace_autoregression(
+                model=m,
+                batch_size=self.batch_size,
+                input=reverse_result,
+                ar_mask=ar_mask,
+                summed_logits=None,
+                temperature=1.0,
+                deterministic_synthesis=True,
+                progress_bar_desc="solving reversed quizzes",
+                device=self.device,
+            )
+
+            reverse_correct = (
+                (reverse_quizzes == reverse_result).long().min(dim=-1).values
+            )
+
+            nb_correct.append((correct * reverse_correct)[None, :])
+
+        nb_correct = torch.cat(nb_correct, dim=0)
+
+        # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+        # with open(filename, "w") as f:
+        # for k in nb_correct:
+        # f.write(f"{k}\n")
+
+        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()