Update. (master)
author    François Fleuret <francois@fleuret.org>
          Sat, 21 Sep 2024 03:15:35 +0000 (05:15 +0200)
committer François Fleuret <francois@fleuret.org>
          Sat, 21 Sep 2024 03:15:35 +0000 (05:15 +0200)
56 files changed:
attae.py [new file with mode: 0755]
grids.py
main.py
mygpt.py [deleted file]
problem.py
quiz_machine.py [deleted file]
report/culture.tex [new file with mode: 0644]
report/pics/4_birds_1.png [new file with mode: 0644]
report/pics/5_birds_1.png [new file with mode: 0644]
report/pics/6_birds_1.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png [new file with mode: 0644]
report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png [new file with mode: 0644]
report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png [new file with mode: 0644]
report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png [new file with mode: 0644]
report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png [new file with mode: 0644]
report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png [new file with mode: 0644]
report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png [new file with mode: 0644]
report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png [new file with mode: 0644]
report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png [new file with mode: 0644]
report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png [new file with mode: 0644]
report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png [new file with mode: 0644]
report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png [new file with mode: 0644]
report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png [new file with mode: 0644]
report/pics/examples_train.png [new file with mode: 0644]
report/pics/occlusions_1.png [new file with mode: 0644]
report/pics/other_shapes_1.png [new file with mode: 0644]
report/pics/other_shapes_2.png [new file with mode: 0644]
report/pics/other_shapes_3.png [new file with mode: 0644]
report/pics/task_bounce.png [new file with mode: 0644]
report/pics/task_color_grow.png [new file with mode: 0644]
report/pics/task_count.png [new file with mode: 0644]
report/pics/task_detect.png [new file with mode: 0644]
report/pics/task_frame.png [new file with mode: 0644]
report/pics/task_grow.png [new file with mode: 0644]
report/pics/task_replace_color.png [new file with mode: 0644]
report/pics/task_scale.png [new file with mode: 0644]
report/pics/task_trajectory.png [new file with mode: 0644]
report/pics/task_translate.png [new file with mode: 0644]
sky.py [deleted file]
wireworld.py [deleted file]

diff --git a/attae.py b/attae.py
new file mode 100755 (executable)
index 0000000..c04c5d3
--- /dev/null
+++ b/attae.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    def forward(self, x):
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2  # works with float, weird
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
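+        # for even j (k = 0) this gives sin(t / len_max^(j/d)); the pi/2
+        # phase for odd j turns it into the matching cosine, i.e. the
+        # interleaved sin/cos encoding of Vaswani et al. (2017)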
+        y = x + pe
+        return y
+
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
+
+######################################################################
+
+
+def vanilla_attention(q, k, v):
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+    a = a.softmax(dim=3)
+    y = torch.einsum("nhts,nhsd->nhtd", a, v)
+    return y
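+
+# A sanity check (hypothetical, for illustration): with no mask and the
+# default 1/sqrt(D) scaling, this should match PyTorch's fused
+# implementation up to numerical tolerance:
+#
+#   q = k = v = torch.randn(2, 3, 5, 7)
+#   torch.testing.assert_close(
+#       vanilla_attention(q, k, v),
+#       F.scaled_dot_product_attention(q, k, v),
+#   )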
+
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention = attention
+        self.attention_dropout = attention_dropout
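+        # note: attention_dropout is stored but the default vanilla_attention
+        # does not apply it; only a custom `attention` callable could use it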
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
+
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
+
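+        # x_q: (N, T, dim_model); q, k, v: (N, nb_heads, T or S, head dim)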
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+        y = self.attention(q, k, v)
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return y
+
+
+######################################################################
+
+
+class AttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=vanilla_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = self.readout(x)
+        return x
+
+
+######################################################################
+
+
+class WithMaskedResidual(nn.Module):
+    def __init__(self, masker, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.masker = masker
+        self.mask = None
+
+    def forward(self, x):
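+        # the mask is computed once and cached, so all batches are assumed
+        # to have the same sequence length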
+        if self.mask is None:
+            self.mask = self.masker(x)
+        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        nb_work_tokens=100,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.nb_work_tokens = nb_work_tokens
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        def no_peek_attention(q, k, v):
+            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+            n = self.nb_work_tokens
+            s = (q.size(2) - n) // 2
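+            # the sequence is [work tokens | first half | second half];
+            # attention between the two halves is forbidden, so cross-half
+            # information must flow through the work tokens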
+            a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a = a.softmax(dim=3)
+            y = torch.einsum("nhts,nhsd->nhtd", a, v)
+            return y
+
+        def masker(x):
+            m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+            return m[None, :, None]
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=no_peek_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = self.embedding(x)
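+        # prepend nb_work_tokens zero vectors; the negative pad after the
+        # trunk removes them again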
+        x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+        x = self.readout(x)
+        return x
+
+
+######################################################################
+
+
+if __name__ == "__main__":
+    model = FunctionalAttentionAE(
+        vocabulary_size=100,
+        dim_model=16,
+        dim_keys=64,
+        dim_hidden=32,
+        nb_heads=4,
+        nb_work_tokens=10,
+        nb_blocks=4,
+        dropout=0.1,
+    )
+
+    x = torch.randint(100, (10, 50))
+    y = model(x)
+
+    with torch.no_grad():
+        model.eval()
+        x = torch.randint(100, (10, 50))
+        y = model(x)
+
+        print(y)
index eea8c6c..78d9297 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm, os, warnings
+import math, sys, tqdm, os, warnings, cairo, re
 
 import torch, torchvision
 
@@ -14,6 +14,36 @@ from torch.nn import functional as F
 
 ######################################################################
 
+
+def text_img(height, width, text):
+    pixel_map = torch.full((height, width, 4), 255, dtype=torch.uint8)
+
+    surface = cairo.ImageSurface.create_for_data(
+        pixel_map.numpy(), cairo.FORMAT_ARGB32, pixel_map.size(1), pixel_map.size(0)
+    )
+
+    ctx = cairo.Context(surface)
+    ctx.set_source_rgb(0, 0, 0)
+    ctx.set_font_size(16)
+    ctx.select_font_face("courier", cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
+    y = None
+    for line in text.split("\n"):
+        xbearing, ybearing, width, height, dx, dy = ctx.text_extents(line)
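+        # note: width/height here are the text extents of the line and
+        # shadow the image-size arguments, which are no longer needed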
+        if y is None:
+            y = height * 1.5
+            x = height * 0.5
+
+        ctx.move_to(x, y)
+        ctx.show_text(line)
+        y += height * 1.5
+
+    ctx.stroke()
+
+    return pixel_map.permute(2, 0, 1)[None, :3].contiguous()
+
+
+######################################################################
+
 import problem
 
 
@@ -104,20 +134,105 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 
 
 class Grids(problem.Problem):
+    # grid_gray = 64
+    # thickness = 1
+    # background_gray = 255
+    # dots = False
+
+    grid_gray = 240
+    thickness = 0
+    background_gray = 240
+    dots = False
+
+    # grid_gray = 192
+    # thickness = 0
+    # background_gray = 255
+    # dots = True
+
     named_colors = [
-        ("white", [255, 255, 255]),
+        ("white", [background_gray, background_gray, background_gray]),
+        # ("white", [224, 224, 224]),
         ("red", [255, 0, 0]),
-        ("green", [0, 192, 0]),
+        ("green", [0, 160, 0]),
         ("blue", [0, 0, 255]),
         ("yellow", [255, 224, 0]),
         ("cyan", [0, 255, 255]),
         ("violet", [224, 128, 255]),
-        ("lightgreen", [192, 255, 192]),
+        ("lightgreen", [160, 255, 160]),
         ("brown", [165, 42, 42]),
         ("lightblue", [192, 192, 255]),
         ("gray", [128, 128, 128]),
     ]
 
+    def pure_noise(self, nb, device):
+        result = torch.randint(
+            self.nb_colors, (nb, 4 * (self.height * self.width)), device=device
+        )
+        return result
+
+    def trivial(self, quizzes):
+        S = self.height * self.width
+        assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
+        a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
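+        # a quiz is trivial when one of its two pairs is unchanged by the
+        # transformation, i.e. A == f_A or B == f_B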
+        return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
+            dim=1
+        ).values
+
+    def text2quiz(self, t):
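+        # parse a textual quiz: "#" comments are stripped, quizzes are
+        # separated by blank lines (or ";"), and each one is 10 rows of
+        # 4 x 10 characters, one per cell of the A / f_A / B / f_B grids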
+        chr2col = [
+            (".", "white"),
+            ("r", "red"),
+            ("g", "green"),
+            ("b", "blue"),
+            ("y", "yellow"),
+            ("c", "cyan"),
+            ("v", "violet"),
+            ("l", "lightgreen"),
+            ("o", "brown"),
+            ("l", "lightblue"),
+            ("a", "gray"),
+        ]
+
+        col2tok = dict([(c[0], n) for n, c in enumerate(self.named_colors)])
+        chr2tok = dict([(c, col2tok[col]) for c, col in chr2col])
+
+        t = re.sub(r"#.*\n", "", t).strip()
+        l = t.replace("\n\n", ";").split(";")
+
+        result = []
+
+        for t in l:
+            t = "".join(t.replace("\n", " ").strip().split(" "))
+            t = torch.tensor([chr2tok[c] for c in t])
+            t = t.reshape(10, 4, 10).permute(1, 0, 2).flatten(1)
+            t = torch.cat(
+                [
+                    torch.tensor(
+                        [
+                            [self.token_A],
+                            [self.token_f_A],
+                            [self.token_B],
+                            [self.token_f_B],
+                        ]
+                    ),
+                    t,
+                ],
+                dim=1,
+            )
+            result.append(t.flatten()[None, :])
+
+        return torch.cat(result, dim=0)
+
+    def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
+        S = self.height * self.width
+        q = quizzes.reshape(quizzes.size(0), 4, S + 1)
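+        # keep the quizzes whose four structure tokens match quad_order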
+        return (
+            (q[:, 0, 0] == self.l2tok[quad_order[0]])
+            & (q[:, 1, 0] == self.l2tok[quad_order[1]])
+            & (q[:, 2, 0] == self.l2tok[quad_order[2]])
+            & (q[:, 3, 0] == self.l2tok[quad_order[3]])
+        )
+
     def __init__(
         self,
         max_nb_cached_chunks=None,
@@ -126,24 +241,40 @@ class Grids(problem.Problem):
         tasks=None,
     ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
+
+        self.nb_colors = len(self.colors)
+
+        self.nb_rec_max = 5
+        self.rfree = torch.tensor([])
+
         self.height = 10
         self.width = 10
+        self.seq_len = 4 * self.height * self.width
+
         self.cache_rec_coo = {}
 
         all_tasks = [
+            ############################################ fundamental ones
             self.task_replace_color,
             self.task_translate,
             self.task_grow,
-            self.task_half_fill,
             self.task_frame,
+            ############################################
+            ############################################
+            self.task_half_fill,
             self.task_detect,
-            self.task_count,
-            self.task_trajectory,
-            self.task_bounce,
             self.task_scale,
             self.task_symbols,
+            self.task_corners,
+            self.task_contact,
+            self.task_path,
+            self.task_fill,
+            ############################################ hard ones
             self.task_isometry,
-            #            self.task_islands,
+            self.task_trajectory,
+            self.task_bounce,
+            # self.task_count, # NOT REVERSIBLE
+            # self.task_islands, # TOO MESSY
         ]
 
         if tasks is None:
@@ -155,147 +286,206 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, -1)
-        m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
+    def vocabulary_size(self):
+        # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+        # return self.nb_colors+4
+        return self.nb_colors
+
+    def grid2img(self, x, scale=15, grids=True):
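+        # render each cell as a scale x scale colour block, with optional
+        # grid lines and centre dots; out-of-range tokens (>= nb_colors)
+        # are drawn as an inset square in their corresponding colour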
+        m = torch.logical_and(x >= 0, x < self.nb_colors).long()
+        y = self.colors[x * m].permute(0, 3, 1, 2)
+        s = y.shape
+        y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
+        y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
+
+        if grids:
+            for t in range(self.thickness):
+                y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
+                y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
+        if self.dots:
+            z = y.reshape(
+                y.size(0),
+                y.size(1),
+                y.size(2) // scale,
+                scale,
+                y.size(3) // scale,
+                scale,
+            )
+            z = z[
+                :,
+                :,
+                :,
+                scale // 2 - 1 : scale // 2 + 2,
+                :,
+                scale // 2 - 1 : scale // 2 + 2,
+            ]
+            zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+            z[...] = zz * self.grid_gray + (~zz) * z
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
-                        for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
+                    if x[n, i, j] >= self.nb_colors:
+                        # for k in range(3, scale - 2):
+                        c = self.colors[x[n, i, j] - self.nb_colors][:, None, None]
+                        # y[n, :, i * scale + k, j * scale + k] = c
+                        # y[n, :, i * scale + k, j * scale + scale - k] = c
+                        y[
+                            n,
+                            :,
+                            i * scale + 3 : i * scale + scale - 2,
+                            j * scale + 3 : j * scale + scale - 2,
+                        ] = c
+
+        y = y[:, :, 1:, 1:]
+
+        return y
+
+    def add_frame(self, img, colors, thickness):
+        if thickness > 0:
+            result = img.new(
+                img.size(0),
+                img.size(1),
+                img.size(2) + 2 * thickness,
+                img.size(3) + 2 * thickness,
+            )
 
-        return x
+            result[...] = colors[:, :, None, None]
+            result[:, :, thickness:-thickness, thickness:-thickness] = img
+        else:
+            result = img
 
-    def save_image(
+        return result
+
+    def save_quizzes_as_image(
         self,
         result_dir,
         filename,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
+        quizzes,
+        predicted_parts=None,
+        correct_parts=None,
+        comments=None,
+        comment_height=48,
         nrow=4,
-        margin=8,
+        grids=True,
+        margin=12,
+        delta=False,
+        delta_highlight=False,
     ):
+        quizzes = quizzes.to("cpu")
+
         S = self.height * self.width
-        As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
-        f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
-            -1, self.height, self.width
-        )
-        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
-        prompts = torch.cat([As, f_As, Bs], dim=2)
-        answers = answers.reshape(answers.size(0), self.height, self.width)
 
-        if predicted_prompts is None:
-            predicted_prompts = 255
+        A, f_A, B, f_B = (
+            quizzes.reshape(quizzes.size(0), 4, S)
+            .reshape(quizzes.size(0), 4, self.height, self.width)
+            .permute(1, 0, 2, 3)
+        )
 
-        if predicted_answers is None:
-            predicted_answers = 255
+        frame, white, gray, green, red = torch.tensor(
+            [
+                [self.grid_gray, self.grid_gray, self.grid_gray],
+                [255, 255, 255],
+                [200, 200, 200],
+                [0, 255, 0],
+                [255, 0, 0],
+            ],
+            device=quizzes.device,
+        )
 
-        def add_frame(x, c, margin, bottom=False):
-            if bottom:
-                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
-            else:
-                h, w, di, dj = (
-                    x.size(2) + 2 * margin,
-                    x.size(3) + 2 * margin,
-                    margin,
-                    margin,
-                )
+        thickness = self.thickness
 
-            y = x.new_full((x.size(0), x.size(1), h, w), 0)
+        if delta:
+            u = (A != f_A).long()
+            img_delta_A = self.add_frame(
+                self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+            )
+            img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
+                img_delta_A
+            )
+            u = (B != f_B).long()
+            img_delta_B = self.add_frame(
+                self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+            )
+            img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
+                img_delta_B
+            )
 
-            if type(c) is int:
-                y[...] = c
-            else:
-                c = c.long()[:, None]
-                c = (
-                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
-                    * torch.tensor([64, 64, 64])
-                    + (c == 1).long() * torch.tensor([0, 255, 0])
-                    + (c == 0).long() * torch.tensor([255, 255, 255])
-                    + (c == -1).long() * torch.tensor([255, 0, 0])
-                )
-                y[...] = c[:, :, None, None]
+        img_A = self.add_frame(
+            self.grid2img(A, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_f_A = self.add_frame(
+            self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_B = self.add_frame(
+            self.grid2img(B, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_f_B = self.add_frame(
+            self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
+        )
 
-            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+        if delta_highlight:
+            q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+            img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
 
-            return y
+        # predicted_parts Nx4
+        # correct_parts Nx4
 
-        img_prompts = torch.cat(
-            [
-                add_frame(
-                    add_frame(self.frame2img(x), c=0, margin=1),
-                    c=predicted_prompts,
-                    margin=margin,
+        if predicted_parts is None:
+            colors = white[None, None, :].expand(-1, 4, -1)
+        else:
+            predicted_parts = predicted_parts.to("cpu")
+            if correct_parts is None:
+                colors = (
+                    predicted_parts[:, :, None] * gray[None, None, :]
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
+                )
+            else:
+                correct_parts = correct_parts.to("cpu")
+                colors = (
+                    predicted_parts[:, :, None]
+                    * (
+                        (correct_parts[:, :, None] == 1).long() * green[None, None, :]
+                        + (correct_parts[:, :, None] == 0).long() * gray[None, None, :]
+                        + (correct_parts[:, :, None] == -1).long() * red[None, None, :]
+                    )
+                    + (1 - predicted_parts[:, :, None]) * white[None, None, :]
                 )
-                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
-            ],
-            dim=3,
-        )
 
-        h = img_prompts.size(2)
-        img_answers = add_frame(
-            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
-            c=predicted_answers,
-            margin=margin,
-        )
+        separation = 6
 
-        separator_size = 2 * margin
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
 
-        separator = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
+        img_A = self.add_frame(img_A, white[None, :], thickness=2)
+        img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
+        img_B = self.add_frame(img_B, white[None, :], thickness=2)
+        img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
-        marker = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                separator_size,
-            ),
-            255,
-        )
-
-        # marker[:, :, 0] = 0
-        # marker[:, :, h - 1] = 0
-
-        for k in range(1, 2 * separator_size - 8):
-            i = k - (separator_size - 4)
-            j = separator_size - 5 - abs(i)
-            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
-            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+        if delta:
+            img_delta_A = self.add_frame(
+                img_delta_A, colors[:, 0], thickness=separation
+            )
+            img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
+            img_delta_B = self.add_frame(
+                img_delta_B, colors[:, 0], thickness=separation
+            )
+            img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
+            img = torch.cat(
+                [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
+            )
+        else:
+            img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
 
-        img = torch.cat(
-            [
-                img_prompts,
-                marker,
-                img_answers,
-            ],
-            dim=3,
-        )
+        if comments is not None:
+            comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
+            comment_img = torch.cat(comment_img, dim=0)
+            img = torch.cat([img, comment_img], dim=2)
 
         image_name = os.path.join(result_dir, filename)
+
         torchvision.utils.save_image(
             img.float() / 255.0,
             image_name,
@@ -306,9 +496,6 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     # @torch.compile
     def rec_coo(
         self,
@@ -335,7 +522,8 @@ class Grids(problem.Problem):
             while True:
                 i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
                 j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
-
+                i[:, 1] += 1
+                j[:, 1] += 1
                 big_enough = (
                     (i[:, 1] >= i[:, 0] + min_height)
                     & (j[:, 1] >= j[:, 0] + min_height)
@@ -370,13 +558,13 @@ class Grids(problem.Problem):
                         j[:, 1, 0],
                         j[:, 1, 1],
                     )
-                    no_overlap = torch.logical_not(
+                    no_overlap = (
                         (A_i1 >= B_i2)
-                        & (A_i2 <= B_i1)
-                        & (A_j1 >= B_j1)
-                        & (A_j2 <= B_j1)
+                        | (A_i2 <= B_i1)
+                        | (A_j1 >= B_j2)
+                        | (A_j2 <= B_j1)
                     )
-                    i, j = i[no_overlap], j[no_overlap]
+                    i, j = (i[no_overlap], j[no_overlap])
                 elif nb_rec == 3:
                     A_i1, A_i2, A_j1, A_j2 = (
                         i[:, 0, 0],
@@ -440,10 +628,109 @@ class Grids(problem.Problem):
 
     ######################################################################
 
+    def contact_matrices(self, rn, ri, rj, rz):
+        n = torch.arange(self.nb_rec_max)
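+        # rectangles p < q are in contact when they are edge-adjacent (their
+        # rows or columns differ by exactly one) and their extents overlap on
+        # the other axis; only pairs of active rectangles (index < rn) count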
+        return (
+            (
+                (
+                    (
+                        (ri[:, :, None, 0] == ri[:, None, :, 1] + 1)
+                        | (ri[:, :, None, 1] + 1 == ri[:, None, :, 0])
+                    )
+                    & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+                    & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+                )
+                | (
+                    (
+                        (rj[:, :, None, 0] == rj[:, None, :, 1] + 1)
+                        | (rj[:, :, None, 1] + 1 == rj[:, None, :, 0])
+                    )
+                    & (ri[:, :, None, 0] <= ri[:, None, :, 1])
+                    & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+                )
+            )
+            # & (rz[:, :, None] == rz[:, None, :])
+            & (n[None, :, None] < rn[:, None, None])
+            & (n[None, None, :] < n[None, :, None])
+        )
+
+    def sample_rworld_states(self, N=1000):
+        while True:
+            ri = (
+                torch.randint(self.height - 2, (N, self.nb_rec_max, 2))
+                .sort(dim=2)
+                .values
+            )
+            ri[:, :, 1] += 2
+            rj = (
+                torch.randint(self.width - 2, (N, self.nb_rec_max, 2))
+                .sort(dim=2)
+                .values
+            )
+            rj[:, :, 1] += 2
+            rn = torch.randint(self.nb_rec_max - 1, (N,)) + 2
+            rz = torch.randint(2, (N, self.nb_rec_max))
+            rc = torch.randint(self.nb_colors - 1, (N, self.nb_rec_max)) + 1
+            n = torch.arange(self.nb_rec_max)
+            nb_collisions = (
+                (
+                    (ri[:, :, None, 0] <= ri[:, None, :, 1])
+                    & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+                    & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+                    & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+                    & (rz[:, :, None] == rz[:, None, :])
+                    & (n[None, :, None] < rn[:, None, None])
+                    & (n[None, None, :] < n[None, :, None])
+                )
+                .long()
+                .flatten(1)
+                .sum(dim=1)
+            )
+
+            no_collision = nb_collisions == 0
+
+            if no_collision.any():
+                print(no_collision.long().sum() / N)
+                self.rn = rn[no_collision]
+                self.ri = ri[no_collision]
+                self.rj = rj[no_collision]
+                self.rz = rz[no_collision]
+                self.rc = rc[no_collision]
+
+                # use the filtered tensors so rcontact lines up with rfree
+                nb_contact = (
+                    self.contact_matrices(self.rn, self.ri, self.rj, self.rz)
+                    .long()
+                    .flatten(1)
+                    .sum(dim=1)
+                )
+
+                self.rcontact = nb_contact > 0
+                self.rfree = torch.full((self.rn.size(0),), True)
+
+                break
+
+    def get_recworld_state(self):
+        if not self.rfree.any():
+            self.sample_rworld_states()
+        k = torch.arange(self.rn.size(0))[self.rfree]
+        k = k[torch.randint(k.size(0), (1,))].item()
+        self.rfree[k] = False
+        return self.rn[k], self.ri[k], self.rj[k], self.rz[k], self.rc[k]
+
+    def draw_state(self, X, rn, ri, rj, rz, rc):
+        for n in sorted(list(range(rn)), key=lambda n: rz[n].item()):
+            X[ri[n, 0] : ri[n, 1] + 1, rj[n, 0] : rj[n, 1] + 1] = rc[n]
+
+    def task_recworld_immobile(self, A, f_A, B, f_B):
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            rn, ri, rj, rz, rc = self.get_recworld_state()
+            self.draw_state(X, rn, ri, rj, rz, rc)
+            ri += 1
+            self.draw_state(f_X, rn, ri, rj, rz, rc)
+
+    ######################################################################
+
     # @torch.compile
     def task_replace_color(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -451,6 +738,29 @@ class Grids(problem.Problem):
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
 
+    # @torch.compile
+    def task_symmetry(self, A, f_A, B, f_B):
+        a, b = torch.randint(2, (2,))
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                if min([x[2] for x in r]) > self.height // 2 + 1:
+                    break
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+            X[: self.height // 2] = 0
+            f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2]
+            if a == 1:
+                X[...] = X.flip((0,))
+                f_X[...] = f_X.flip((0,))
+            if b == 1:
+                X[...] = X.clone().t()
+                f_X[...] = f_X.clone().t()
+
     # @torch.compile
     def task_translate(self, A, f_A, B, f_B):
         while True:
@@ -459,7 +769,7 @@ class Grids(problem.Problem):
                 break
 
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -484,7 +794,7 @@ class Grids(problem.Problem):
     def task_grow(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
         direction = torch.randint(2, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -510,7 +820,7 @@ class Grids(problem.Problem):
     def task_half_fill(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
+        c = torch.randperm(self.nb_colors - 1)[: 2 * nb_rec] + 1
         direction = torch.randint(4, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
@@ -551,7 +861,7 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_frame(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
@@ -568,14 +878,17 @@ class Grids(problem.Problem):
     # @torch.compile
     def task_detect(self, A, f_A, B, f_B):
         nb_rec = 3
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
                 if n < nb_rec - 1:
-                    f_X[i1, j1] = c[-1]
+                    for k in range(2):
+                        f_X[i1 + k, j1] = c[-1]
+                        f_X[i1, j1 + k] = c[-1]
 
     # @torch.compile
     def contact(self, X, i, j, q):
@@ -613,13 +926,13 @@ class Grids(problem.Problem):
 
         return no, nq, nq_diag
 
-    def task_count(self, A, f_A, B, f_B):
+    def REMOVED_task_count(self, A, f_A, B, f_B):
         while True:
             error = False
 
-            N = torch.randint(5, (1,)).item() + 1
-            c = torch.zeros(N + 1)
-            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+            N = 3
+            c = torch.zeros(N + 2, dtype=torch.int64)
+            c[1:] = torch.randperm(self.nb_colors - 1)[: N + 1] + 1
 
             for X, f_X in [(A, f_A), (B, f_B)]:
                 if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
@@ -629,29 +942,31 @@ class Grids(problem.Problem):
                             self.height,
                             self.width,
                             nb_seeds=self.height * self.width // 8,
-                            nb_iterations=self.height * self.width // 10,
+                            nb_iterations=self.height * self.width // 5,
                         )
                     )
 
                 X[...] = self.cache_count.pop()
 
-                k = (X.max() + 1 + (c.size(0) - 1)).item()
-                V = torch.arange(k) // (c.size(0) - 1)
-                V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
-                    c.size(0) - 1
-                ) + 1
+                # k = (X.max() + 1 + (c.size(0) - 1)).item()
+                # V = torch.arange(k) // (c.size(0) - 1)
+                # V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+                # c.size(0) - 1
+                # ) + 1
+
+                V = torch.randint(N, (X.max() + 1,)) + 1
                 V[0] = 0
+                NB = F.one_hot(c[V]).sum(dim=0)
                 X[...] = c[V[X]]
-
-                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
-                    f_X[...] = 0
-                    for e in range(1, N + 1):
-                        for j in range((X == c[e]).sum() + 1):
-                            if j < self.width:
-                                f_X[e - 1, j] = c[e]
-                            else:
-                                error = True
-                                break
+                f_X[...] = X
+
+                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() >= 3:
+                    m = NB[c[:-1]].max()
+                    if (NB[c[:-1]] == m).long().sum() == 1:
+                        for e in range(1, N + 1):
+                            if NB[c[e]] == m:
+                                a = (f_X == c[e]).long()
+                                f_X[...] = (1 - a) * f_X + a * c[-1]
                 else:
                     error = True
                     break
@@ -659,9 +974,11 @@ class Grids(problem.Problem):
             if not error:
                 break
 
+        assert F.one_hot(A.flatten()).max(dim=0).values.sum() >= 3
+
     # @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 di, dj = torch.randint(7, (2,)) - 3
@@ -692,7 +1009,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_bounce(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             # @torch.compile
             def free(i, j):
@@ -748,6 +1065,7 @@ class Grids(problem.Problem):
                     f_X[i, j] = c[2]
                     if l <= 1:
                         X[i, j] = c[2]
+                        f_X[i, j] = c[1]
 
                     if l >= self.width:
                         break
@@ -760,7 +1078,7 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_scale(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
 
         i, j = (
             torch.randint(self.height // 2, (1,)).item(),
@@ -783,13 +1101,16 @@ class Grids(problem.Problem):
                 X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
                 f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
 
-            X[i, j] = c[1]
-            f_X[0:2, 0:2] = c[1]
+            for k in range(2):
+                X[i + k, j] = c[1]
+                X[i, j + k] = c[1]
+                f_X[i + k, j] = c[1]
+                f_X[i, j + k] = c[1]
 
     # @torch.compile
     def task_symbols(self, A, f_A, B, f_B):
         nb_rec = 4
-        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         delta = 3
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
@@ -801,26 +1122,37 @@ class Grids(problem.Problem):
                 if d.min() > delta:
                     break
 
-            for k in range(1, nb_rec):
-                X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
-
             ai, aj = i.float().mean(), j.float().mean()
 
             q = torch.randint(3, (1,)).item() + 1
 
-            X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
-            X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
-            X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
-            X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
-
             assert i[q] != ai and j[q] != aj
 
-            X[
+            for Z in [X, f_X]:
+                for k in range(0, nb_rec):
+                    Z[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]
+                # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
+                # Z[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
+                # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
+                # Z[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]
+
+            # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+            f_X[i[0] + delta // 2, j[0] + delta // 2] = c[q]
+            # f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+            ii, jj = (
                 i[0] + delta // 2 + (i[q] - ai).sign().long(),
                 j[0] + delta // 2 + (j[q] - aj).sign().long(),
-            ] = c[nb_rec]
+            )
+
+            X[ii, jj] = c[nb_rec]
+            X[i[0] + delta // 2, jj] = c[nb_rec]
+            X[ii, j[0] + delta // 2] = c[nb_rec]
 
-            f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+            f_X[ii, jj] = c[nb_rec]
+            f_X[i[0] + delta // 2, jj] = c[nb_rec]
+            f_X[ii, j[0] + delta // 2] = c[nb_rec]
 
     # @torch.compile
     def task_isometry(self, A, f_A, B, f_B):
@@ -840,7 +1172,7 @@ class Grids(problem.Problem):
                 X[...] = 0
                 f_X[...] = 0
 
-                c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+                c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
 
                 for r in range(nb_rec):
                     while True:
@@ -906,8 +1238,8 @@ class Grids(problem.Problem):
                 return dist * (1 - walls)
 
     # @torch.compile
-    def task_distance(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+    def REMOVED_task_distance(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:3] + 1
         dist0 = torch.empty(self.height + 2, self.width + 2)
         dist1 = torch.empty(self.height + 2, self.width + 2)
         for X, f_X in [(A, f_A), (B, f_B)]:
@@ -969,10 +1301,10 @@ class Grids(problem.Problem):
     # if
 
     # @torch.compile
-    def task_puzzle(self, A, f_A, B, f_B):
+    def TOO_HARD_task_puzzle(self, A, f_A, B, f_B):
         S = 4
         i0, j0 = (self.height - S) // 2, (self.width - S) // 2
-        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 f_X[...] = 0
@@ -1037,8 +1369,8 @@ class Grids(problem.Problem):
                         if f_X[i + i0, j + j0] == c[d]:
                             X[ii + i, jj + j] = c[d]
 
-    def task_islands(self, A, f_A, B, f_B):
-        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+    def TOO_MESSY_task_islands(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:2] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
                 self.cache_islands = list(
@@ -1062,72 +1394,433 @@ class Grids(problem.Problem):
                     break
 
             X[...] = (A > 0) * c[0]
-            X[i, j] = c[1]
             f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+            f_X[i, j] = X[i, j]
+            X[i, j] = c[1]
+
+    # @torch.compile
+    def TOO_HARD_task_stack(self, A, f_A, B, f_B):
+        N = 5
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            i1, j1, i2, j2 = (
+                self.height // 2 - 1,
+                self.width // 2 - 1,
+                self.height // 2 + 1,
+                self.width // 2 + 1,
+            )
+            op = torch.tensor((0, 1, 2, 3) * 4)
+            op = op[torch.randperm(op.size(0))[:9]]
+            for q in range(op.size(0)):
+                u = 3 * (q // 3)
+                v = 3 * (q % 3)
+                d = c[torch.randint(N, (1,)).item()]
+                # X[u+1,v+1]=d
+                if op[q] == 0:  # right
+                    X[u : u + 3, v + 2] = d
+                elif op[q] == 1:  # left
+                    X[u : u + 3, v] = d
+                elif op[q] == 2:  # bottom
+                    X[u + 2, v : v + 3] = d
+                elif op[q] == 3:  # top
+                    X[u, v : v + 3] = d
+
+                if q == 0:
+                    f_X[i1:i2, j1:j2] = d
+                elif op[q] == 0:  # right
+                    f_X[i1:i2, j2] = d
+                    j2 += 1
+                elif op[q] == 1:  # left
+                    j1 -= 1
+                    f_X[i1:i2, j1] = d
+                elif op[q] == 2:  # bottom
+                    f_X[i2, j1:j2] = d
+                    i2 += 1
+                elif op[q] == 3:  # top
+                    i1 -= 1
+                    f_X[i1, j1:j2] = d
+
+    def randint(self, *m):
+        m = torch.tensor(m)
+        return (torch.rand(m.size()) * m).long()
+
+    def TOO_HARD_task_matrices(self, A, f_A, B, f_B):
+        N = 6
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            M1 = torch.randint(2, (5, 5))
+            M2 = torch.randint(2, (5, 5))
+            P = M1 @ M2
+            for i in range(5):
+                for j in range(5):
+                    X[i, j] = c[M1[i, j]]
+                    X[i, j + 5] = c[M2[i, j]]
+                    f_X[i, j] = c[M1[i, j]]
+                    f_X[i, j + 5] = c[M2[i, j]]
+                    f_X[i + 5, j + 5] = c[P[i, j]]
+
+    def TOO_HARD_task_compute(self, A, f_A, B, f_B):
+        N = 6
+        c = torch.randperm(self.nb_colors - 1)[:N] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            v = torch.randint((self.width - 1) // 2, (N,)) + 1
+            chain = torch.randperm(N)
+            eq = []
+            for i in range(chain.size(0) - 1):
+                i1, i2 = chain[i], chain[i + 1]
+                v1, v2 = v[i1], v[i2]
+                k = torch.arange(self.width // 2) + 1
+                d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+                d = d[torch.randint(d.size(0), (1,)).item()]
+                w1, w2 = d
+                eq.append((c[i1], w1, c[i2], w2))
+
+            ii = torch.randperm(self.height - 2)[: len(eq)]
+
+            for k, x in enumerate(eq):
+                i = ii[k]
+                c1, w1, c2, w2 = x
+                s = torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+                X[i, s : s + w1] = c1
+                X[i, s + w1 : s + w1 + w2] = c2
+                f_X[i, s : s + w1] = c1
+                f_X[i, s + w1 : s + w1 + w2] = c2
+
+            i1, i2 = torch.randperm(N)[:2]
+            v1, v2 = v[i1], v[i2]
+            k = torch.arange(self.width // 2) + 1
+            d = ((k[None, :] * v1 - k[:, None] * v2) == 0).nonzero() + 1
+            d = d[torch.randint(d.size(0), (1,)).item()]
+            w1, w2 = d
+            c1, c2 = c[i1], c[i2]
+            s = 0  # torch.randint(self.width - (w1 + w2) + 1, (1,)).item()
+            i = self.height - 1
+            X[i, s : s + w1] = c1
+            X[i, s + w1 : s + w1 + 1] = c2
+            f_X[i, s : s + w1] = c1
+            f_X[i, s + w1 : s + w1 + w2] = c2
+
+    # @torch.compile
+    # [ai1,ai2] [bi1,bi2]
+    def task_contact(self, A, f_A, B, f_B):
+        def rec_dist(a, b):
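+            # gap between two axis-aligned rectangles: 0 exactly when they
+            # touch (or overlap), positive when separated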
+            ai1, aj1, ai2, aj2 = a
+            bi1, bj1, bi2, bj2 = b
+            v = max(ai1 - bi2, bi1 - ai2)
+            h = max(aj1 - bj2, bj1 - aj2)
+            return min(max(v, 0) + max(h + 1, 0), max(v + 1, 0) + max(h, 0))
+
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                d = [rec_dist(r[0], r[k]) for k in range(nb_rec)]
+                if min(d[1:]) == 0:
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+                if d[n] == 0:
+                    f_X[i1, j1:j2] = c[0]
+                    f_X[i2 - 1, j1:j2] = c[0]
+                    f_X[i1:i2, j1] = c[0]
+                    f_X[i1:i2, j2 - 1] = c[0]
+
+    # @torch.compile
+    # [ai1,ai2] [bi1,bi2]
+    def task_corners(self, A, f_A, B, f_B):
+        polarity = torch.randint(2, (1,)).item()
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                for k in range(2):
+                    if polarity == 0:
+                        X[i1 + k, j1] = c[n]
+                        X[i2 - 1 - k, j2 - 1] = c[n]
+                        X[i1, j1 + k] = c[n]
+                        X[i2 - 1, j2 - 1 - k] = c[n]
+                    else:
+                        X[i1 + k, j2 - 1] = c[n]
+                        X[i2 - 1 - k, j1] = c[n]
+                        X[i1, j2 - 1 - k] = c[n]
+                        X[i2 - 1, j1 + k] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+    def compdist(self, X, i, j):
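+        # distance field from (i, j) computed by iterated relaxation; cells
+        # occupied by a shape (X > 0) are kept at height * width, so they
+        # act as walls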
+        dd = X.new_full((self.height + 2, self.width + 2), self.height * self.width)
+        d = dd[1:-1, 1:-1]
+        m = (X > 0).long()
+        d[i, j] = 0
+        e = d.clone()
+        while True:
+            e[...] = d
+            d[...] = (
+                d.min(dd[:-2, 1:-1] + 1)
+                .min(dd[2:, 1:-1] + 1)
+                .min(dd[1:-1, :-2] + 1)
+                .min(dd[1:-1, 2:] + 1)
+            )
+            d[...] = (1 - m) * d + m * self.height * self.width
+            if e.equal(d):
+                break
+
+        return d
+
+    # @torch.compile
+    def task_path(self, A, f_A, B, f_B):
+        nb_rec = 2
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 2] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+                i1, i2 = torch.randint(self.height, (2,))
+                j1, j2 = torch.randint(self.width, (2,))
+                if (
+                    abs(i1 - i2) + abs(j1 - j2) > 2
+                    and X[i1, j1] == 0
+                    and X[i2, j2] == 0
+                ):
+                    d2 = self.compdist(X, i2, j2)
+                    d = self.compdist(X, i1, j1)
+
+                    if d2[i1, j1] < 2 * self.width:
+                        break
+
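+            # A cell lies on a shortest path between the two endpoints
+            # iff its distances to them sum to the total distance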
+            m = ((d + d2) == d[i2, j2]).long()
+            f_X[...] = m * c[-1] + (1 - m) * f_X
+
+            X[i1, j1] = c[-2]
+            X[i2, j2] = c[-2]
+            f_X[i1, j1] = c[-2]
+            f_X[i2, j2] = c[-2]
+
+    # @torch.compile
+    def task_fill(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            accept_full = torch.rand(1) < 0.5
+
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+                while True:
+                    i, j = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i, j] == 0:
+                        break
+
+                d = self.compdist(X, i, j)
+                m = (d < self.height * self.width).long()
+                X[i, j] = c[-1]
+                f_X[...] = m * c[-1] + (1 - m) * f_X
+                f_X[i, j] = 0
+
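+                # Keep the configuration if full fillings are accepted,
+                # or if some empty cell is unreachable from the seed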
+                if accept_full or (d * (X == 0)).max() == self.height * self.width:
+                    break
+
+    def TOO_HARD_task_addition(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            N1 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+            N2 = torch.randint(2 ** (self.width - 1) - 1, (1,)).item()
+            S = N1 + N2
+            for j in range(self.width):
+                r1 = (N1 // (2**j)) % 2
+                X[0, -j - 1] = c[r1]
+                f_X[0, -j - 1] = c[r1]
+                r2 = (N2 // (2**j)) % 2
+                X[1, -j - 1] = c[r2]
+                f_X[1, -j - 1] = c[r2]
+                rs = (S // (2**j)) % 2
+                f_X[2, -j - 1] = c[2 + rs]
+
+    def task_science_implicit(self, A, f_A, B, f_B):
+        nb_rec = 5
+        c = torch.randperm(self.nb_colors - 1)[:nb_rec] + 1
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                i1, i2 = torch.randint(self.height, (2,)).sort().values
+                if i1 >= 1 and i2 < self.height and i1 + 3 < i2:
+                    break
+
+            while True:
+                j1, j2 = torch.randint(self.width, (2,)).sort().values
+                if j1 >= 1 and j2 < self.width and j1 + 3 < j2:
+                    break
+
+            f_X[i1:i2, j1:j2] = c[0]
+
+            # ---------------------
+
+            while True:
+                ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+                if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+                    break
+            jj = torch.randint(j1, (1,))
+            X[ii1:ii2, jj:j1] = c[1]
+            f_X[ii1:ii2, jj:j1] = c[1]
+
+            while True:
+                ii1, ii2 = torch.randint(self.height, (2,)).sort().values
+                if ii1 >= i1 and ii2 <= i2 and ii1 + 1 < ii2:
+                    break
+            jj = torch.randint(self.width - j2, (1,)) + j2 + 1
+            X[ii1:ii2, j2:jj] = c[2]
+            f_X[ii1:ii2, j2:jj] = c[2]
+
+            # ---------------------
+
+            while True:
+                jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+                if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+                    break
+            ii = torch.randint(i1, (1,))
+            X[ii:i1, jj1:jj2] = c[3]
+            f_X[ii:i1, jj1:jj2] = c[3]
+
+            while True:
+                jj1, jj2 = torch.randint(self.width, (2,)).sort().values
+                if jj1 >= j1 and jj2 <= j2 and jj1 + 1 < jj2:
+                    break
+            ii = torch.randint(self.height - i2, (1,)) + i2 + 1
+            X[i2:ii, jj1:jj2] = c[4]
+            f_X[i2:ii, jj1:jj2] = c[4]
+
+    def task_science_dot(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                X[...] = 0
+                f_X[...] = 0
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                i, j = (
+                    torch.randint(self.height, (1,)).item(),
+                    torch.randint(self.width, (1,)).item(),
+                )
+                q = 0
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+                    if i >= i1 and i < i2:
+                        q += 1
+                        f_X[i, j1:j2] = c[-1]
+                    if j >= j1 and j < j2:
+                        q += 1
+                        f_X[i1:i2, j] = c[-1]
+                X[i, j] = c[-1]
+                f_X[i, j] = c[-1]
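+                # Keep the configuration only if the dot's row and
+                # column cross the rectangles at least twice in total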
+                if q >= 2:
+                    break
+
+    def collide(self, s, r, rs):
+        i, j = r
+        for i2, j2 in rs:
+            if abs(i - i2) < s and abs(j - j2) < s:
+                return True
+        return False
+
+    def task_science_tag(self, A, f_A, B, f_B):
+        c = torch.randperm(self.nb_colors - 1)[:4] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            rs = []
+            while len(rs) < 4:
+                i, j = (
+                    torch.randint(self.height - 3, (1,)).item(),
+                    torch.randint(self.width - 3, (1,)).item(),
+                )
+                if not self.collide(s=3, r=(i, j), rs=rs):
+                    rs.append((i, j))
+
+            for k in range(len(rs)):
+                i, j = rs[k]
+                q = min(k, 2)
+                X[i, j : j + 3] = c[q]
+                X[i + 2, j : j + 3] = c[q]
+                X[i : i + 3, j] = c[q]
+                X[i : i + 3, j + 2] = c[q]
+
+                f_X[i, j : j + 3] = c[q]
+                f_X[i + 2, j : j + 3] = c[q]
+                f_X[i : i + 3, j] = c[q]
+                f_X[i : i + 3, j + 2] = c[q]
+                if q == 2:
+                    f_X[i + 1, j + 1] = c[-1]
+
+    # end_tasks
 
     ######################################################################
 
-    def trivial_prompts_and_answers(self, prompts, answers):
+    def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
-        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
-        f_Bs = answers
-        return (Bs == f_Bs).long().min(dim=-1).values > 0
+        quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
+        quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+        quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+        quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+        quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
 
-    def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
-        if tasks is None:
-            tasks = self.all_tasks
+        return quizzes
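+
+    # Layout: with S = height * width, a quiz is four blocks of S + 1
+    # tokens, each block starting with the structure token
+    # l2tok[quad_order[k]] followed by the S tokens of the
+    # corresponding grid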
 
+    def generate_w_quizzes_(self, nb, tasks=None, progress_bar=False):
         S = self.height * self.width
-        prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
-        answers = torch.zeros(nb, S, dtype=torch.int64)
 
-        bunch = zip(prompts, answers)
+        if tasks is None:
+            tasks = self.all_tasks
+
+        quizzes = torch.empty(nb, 4 * self.height * self.width, dtype=torch.int64)
 
         if progress_bar:
-            bunch = tqdm.tqdm(
-                bunch,
+            quizzes = tqdm.tqdm(
+                quizzes,
                 dynamic_ncols=True,
-                desc="world generation",
-                total=prompts.size(0),
+                desc="world quizzes generation",
+                total=quizzes.size(0),
             )
 
-        for prompt, answer in bunch:
-            A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
-            f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
-            B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
-            f_B = answer.view(self.height, self.width)
+        for quiz in quizzes:
+            q = quiz.reshape(4, self.height, self.width)
+            q[...] = 0
+            A, f_A, B, f_B = q
             task = tasks[torch.randint(len(tasks), (1,)).item()]
             task(A, f_A, B, f_B)
 
-        return prompts.flatten(1), answers.flatten(1)
+        return quizzes
 
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-        nrow=4,
-    ):
-        self.save_image(
-            result_dir,
-            filename_prefix + ".png",
-            prompts,
-            answers,
-            predicted_prompts,
-            predicted_answers,
-            nrow,
-        )
-
-    def save_some_examples(self, result_dir):
-        nb, nrow = 72, 4
+    def save_some_examples(self, result_dir, prefix=""):
+        nb, nrow = 256, 8
         for t in self.all_tasks:
             print(t.__name__)
-            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
-            self.save_quiz_illustrations(
-                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+            quizzes = self.generate_w_quizzes_(nb, tasks=[t])
+            self.save_quizzes_as_image(
+                result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
             )
 
 
@@ -1137,41 +1830,120 @@ if __name__ == "__main__":
     import time
 
     # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
+
     grids = Grids()
 
-    # nb = 1000
-    # grids = problem.MultiThreadProblem(
-    # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
-    # )
-    #    time.sleep(10)
-    # start_time = time.perf_counter()
-    # prompts, answers = grids.generate_prompts_and_answers(nb)
-    # delay = time.perf_counter() - start_time
-    # print(f"{prompts.size(0)/delay:02f} seq/s")
-    # exit(0)
-
-    # if True:
-    nb, nrow = 72, 4
+    nb, nrow = 64, 4
     # nb, nrow = 8, 2
 
     # for t in grids.all_tasks:
-    for t in [grids.task_distance]:
+
+    for t in [
+        grids.task_replace_color,
+        grids.task_translate,
+        grids.task_grow,
+        grids.task_frame,
+    ]:
         print(t.__name__)
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
-        grids.save_quiz_illustrations(
-            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+
+        # w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size())
+
+        grids.save_quizzes_as_image(
+            "/tmp",
+            t.__name__ + ".png",
+            w_quizzes,
+            delta=True,
+            # grids=False
+            # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
         )
 
-    # exit(0)
+    exit(0)
+
+    q = grids.text2quiz(
+        """
+
+# the original
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+vvvvaaaaa. rrrraaaaa. ...aaa.... ...aaa....
+....aaaaa. ....aaaaa. .vvvvv.... .rrrrr....
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......... ..........
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+vvvvaaaaa. rrrraaaaa. .......aaa .......aaa
+....aaaaa. ....aaaaa. .vvvvv.aaa .rrrrr.aaa
+.......... .......... .vvvvvvvvv .rrrrroooo
+.......... .......... .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+#
+# so what
+#
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+
+vvvv...... rrrr...... .......... ..........
+vvvv...... rrrr...... .......... ..........
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
+.....aaaaa .....aaaaa .vvvvv.aaa .rrrrr.aaa
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+.....aaaaa .....aaaaa .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+....vvvvv. ....ooooo. .vvvvvvvvv .rrrrroooo
+"""
+    )
+
+    grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False)
+
+    exit(0)
 
     nb = 1000
 
-    # for t in grids.all_tasks:
-    for t in [grids.task_distance]:
+    for t in [
+        # grids.task_bounce,
+        # grids.task_contact,
+        # grids.task_corners,
+        # grids.task_detect,
+        # grids.task_fill,
+        # grids.task_frame,
+        # grids.task_grow,
+        # grids.task_half_fill,
+        # grids.task_isometry,
+        # grids.task_path,
+        # grids.task_replace_color,
+        # grids.task_scale,
+        grids.task_symbols,
+        # grids.task_trajectory,
+        # grids.task_translate,
+    ]:
+        # for t in [grids.task_path]:
         start_time = time.perf_counter()
-        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
-        print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
+        print(f"{t.__name__} {w_quizzes.size(0)/delay:02f} seq/s")
+        grids.save_quizzes_as_image("/tmp", t.__name__ + ".png", w_quizzes[:128])
 
     exit(0)
 
@@ -1179,9 +1951,9 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    grids.save_quiz_illustrations(
+    grids.save_quizzes_as_image(
         "/tmp",
-        "test",
+        "test.png",
         prompts[:nb],
         answers[:nb],
         # You can add a bool to put a frame around the predicted parts
diff --git a/main.py b/main.py
index 6b00bbf..5dceefc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -5,20 +5,21 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, argparse, time, tqdm, os, datetime, warnings
+import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
 
 import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
-import ffutils
+import ffutils, grids, attae
 
-import mygpt
-import sky, grids, quiz_machine
+import threading, subprocess
 
-import threading
+# import torch.multiprocessing as mp
 
-import torch.multiprocessing as mp
+torch.set_float32_matmul_precision("high")
+
+# torch.set_default_dtype(torch.bfloat16)
 
 ######################################################################
 
@@ -34,29 +35,37 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--resume", action="store_true", default=False)
 
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
-
-########################################
+# ----------------------------------
 
 parser.add_argument("--nb_epochs", type=int, default=10000)
 
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
 
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--train_batch_size", type=int, default=None)
 
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--eval_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
 
-parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
 
-parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
+
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
-########################################
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+# ----------------------------------
+
+parser.add_argument("--model_type", type=str, default="standard")
 
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--model", type=str, default="37M")
 
 parser.add_argument("--dim_model", type=int, default=None)
 
@@ -68,29 +77,29 @@ parser.add_argument("--nb_heads", type=int, default=None)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
-parser.add_argument("--dropout", type=float, default=0.1)
+parser.add_argument("--dropout", type=float, default=0.5)
 
-########################################
-
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-
-parser.add_argument("--problem", type=str, default="grids")
+# ----------------------------------
 
 parser.add_argument("--nb_threads", type=int, default=1)
 
 parser.add_argument("--gpus", type=str, default="all")
 
-parser.add_argument("--nb_gpts", type=int, default=5)
+# ----------------------------------
+
+parser.add_argument("--nb_models", type=int, default=5)
+
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
 
-parser.add_argument("--proba_understands", type=float, default=0.9)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
 
-parser.add_argument("--generation_temperature", type=float, default=1.0)
+parser.add_argument("--proba_hint", type=float, default=0.25)
 
-parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument("--quizzes", type=str, default=None)
 
 ######################################################################
 
@@ -99,26 +108,14 @@ grids_tasks = ", ".join(
 )
 
 parser.add_argument(
-    "--grids_tasks",
+    "--grids_world_tasks",
     type=str,
-    default=None,
-    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+    default="replace_color,translate,grow,frame",
+    help="A comma-separated subset of: " + grids_tasks + ".",
 )
 
 ######################################################################
 
-parser.add_argument("--sky_height", type=int, default=6)
-
-parser.add_argument("--sky_width", type=int, default=8)
-
-parser.add_argument("--sky_nb_birds", type=int, default=3)
-
-parser.add_argument("--sky_nb_iterations", type=int, default=2)
-
-parser.add_argument("--sky_speed", type=int, default=3)
-
-######################################################################
-
 args = parser.parse_args()
 
 if args.result_dir is None:
@@ -126,19 +123,6 @@ if args.result_dir is None:
 
 ######################################################################
 
-default_args = {
-    "model": "37M",
-    "batch_size": 25,
-    "nb_train_samples": 100000,
-    "nb_test_samples": 10000,
-}
-
-for k, v in default_args.items():
-    if getattr(args, k) is None:
-        setattr(args, k, v)
-
-######################################################################
-
 default_model_args = {
     "17K": {
         "dim_model": 32,
@@ -187,8 +171,9 @@ else:
 ######################################################################
 
 if args.resume:
-    assert os.path.isdir(args.result_dir)
-
+    if not os.path.isdir(args.result_dir):
+        print(f"Trying to resume from a non-existing result dir {args.result_dir}.")
+        exit(1)
 else:
     try:
         os.mkdir(args.result_dir)
@@ -210,6 +195,9 @@ if args.seed >= 0:
 
 
 def log_string(s):
+    """print the given string prefixed with a time stamps, and log it
+    into log_file is not None"""
+
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
 
     if log_file is not None:
@@ -220,9 +208,17 @@ def log_string(s):
     sys.stdout.flush()
 
 
+######################################################################
+# Save the command line and create a time-stamped archive of the source code
+
+with open("this_run.sh", "w") as f:
+    f.write(f"{' '.join(sys.argv)}\n")
+
 now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
 
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py")
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+######################################################################
 
 log_string(f"argv {' '.join(sys.argv)}")
 
@@ -245,422 +241,760 @@ else:
     assert len(gpus) == 0
     main_device = torch.device("cpu")
 
-if args.dirty_debug:
-    args.nb_train_samples = 2500
-    args.nb_test_samples = 100
-
-if args.physical_batch_size is None:
-    args.physical_batch_size = args.batch_size
+if args.train_batch_size is None:
+    args.train_batch_size = args.batch_size
 else:
-    assert args.batch_size % args.physical_batch_size == 0
+    assert args.batch_size % args.train_batch_size == 0
 
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
-if args.problem == "sky":
-    problem = sky.Sky(
-        height=args.sky_height,
-        width=args.sky_width,
-        nb_birds=args.sky_nb_birds,
-        nb_iterations=args.sky_nb_iterations,
-        speed=args.sky_speed,
-        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-        chunk_size=100,
-        nb_threads=args.nb_threads,
-    )
-    back_accuracy = False
-elif args.problem == "grids":
-    problem = grids.Grids(
-        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-        chunk_size=100,
-        nb_threads=args.nb_threads,
-        tasks=args.grids_tasks,
-    )
-    back_accuracy = True
-else:
-    raise ValueError
-
-problem.save_some_examples(args.result_dir)
-
-quiz_machine = quiz_machine.QuizMachine(
-    problem=problem,
-    nb_train_samples=args.nb_train_samples,
-    nb_test_samples=args.nb_test_samples,
-    back_accuracy=back_accuracy,
-    batch_size=args.physical_batch_size,
-    result_dir=args.result_dir,
-    logger=log_string,
-    device=main_device,
-)
+######################################################################
+
+
+def optimizer_to(optim, device):
+    """Move the optimizer optim to the device"""
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
 
 ######################################################################
 
-log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
 
-vocabulary_size = quiz_machine.vocabulary_size()
+def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+    if c_quizzes is None:
+        quizzes = problem.generate_w_quizzes(nb_samples)
+        nb_w_quizzes = quizzes.size(0)
+        nb_c_quizzes = 0
+    else:
+        if c_quiz_multiplier > 1:
+            n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+            body = c_quizzes.repeat(n, 1)
+            if n < c_quiz_multiplier:
+                tail = c_quizzes[
+                    torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+                ]
+                c_quizzes = torch.cat([body, tail], dim=0)
+            else:
+                c_quizzes = body
+
+        if c_quizzes.size(0) > nb_samples // 2:
+            i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+            c_quizzes = c_quizzes[i]
+
+        w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+
+        quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+        nb_w_quizzes = w_quizzes.size(0)
+        nb_c_quizzes = c_quizzes.size(0)
+
+    i = torch.randperm(quizzes.size(0), device=quizzes.device)
+    quizzes = quizzes[i].contiguous()
+
+    log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+    return quizzes
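+
+# A worked example of the mixing policy above, with hypothetical
+# numbers: for nb_samples=100, 10 c_quizzes and c_quiz_multiplier=3,
+# the c_quizzes are repeated into 30 samples (c_quizzes are capped at
+# nb_samples // 2), and 70 w_quizzes are generated to complete the set.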
 
-log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
 
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
-    with torch.autograd.no_grad():
-        model.eval().to(local_device)
+def add_hints_imt(imt_set):
+    """Set every component of the mask to zero with probability
+    args.proba_hint, and for each component set to zero, copy the
+    corresponding value from the target into the input
+    """
+    input, masks, targets = imt_set.unbind(dim=1)
+    # h = torch.rand(masks.size(), device=masks.device) - masks
+    # t = h.sort(dim=1).values[:, args.nb_hints, None]
+    # mask_hints = (h < t).long()
+    mask_hints = (
+        torch.rand(input.size(), device=input.device) < args.proba_hint
+    ).long() * masks
+    masks = (1 - mask_hints) * masks
+    input = (1 - mask_hints) * input + mask_hints * targets
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def add_noise_imt(imt_set):
+    """Replace every component of the input by a random value with
+    probability args.proba_prompt_noise."""
+    input, masks, targets = imt_set.unbind(dim=1)
+    noise = problem.pure_noise(input.size(0), input.device)
+    change = (1 - masks) * (
+        torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+    ).long()
+    input = (1 - change) * input + change * noise
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
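+# Convention used by the imt helpers: an imt_set is a tensor of shape
+# (N, 3, T) stacking input, masks and targets along dim 1, so that
+#
+#   input, masks, targets = imt.unbind(dim=1)
+#
+# recovers three (N, T) views; masks is 1 at the positions the model
+# has to reconstruct.
+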
 
-        nb_test_samples, acc_test_loss = 0, 0.0
-        nb_samples_accumulated = 0
+######################################################################
+# Prediction
 
-        for input in quiz_machine.batches(model, split="test"):
-            input = input.to(local_device)
 
-            bs = model(mygpt.BracketedSequence(input))
-            output = bs.x
+def samples_for_prediction_imt(input):
+    nb = input.size(0)
+    masks = input.new_zeros(input.size())
+    u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+    masks.view(nb, 4, -1)[...] = u[:, :, None]
+    targets = input
+    input = (1 - masks) * targets
 
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
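+
+# In samples_for_prediction_imt above, F.one_hot picks one of the four
+# grids uniformly at random for each sample, and broadcasting marks all
+# the tokens of that grid in the mask; e.g.
+# u = F.one_hot(torch.tensor([0, 3]), num_classes=4) would mask the
+# first grid (A) of sample 0 and the fourth (f_B) of sample 1.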
 
-            acc_test_loss += loss.item() * input.size(0)
 
-            nb_test_samples += input.size(0)
+def ae_predict(model, imt_set, local_device=main_device):
+    model.eval().to(local_device)
 
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+    record = []
 
-        log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
+    src = tqdm.tqdm(
+        imt_set.split(args.eval_batch_size),
+        dynamic_ncols=True,
+        desc="predict",
+        total=imt_set.size(0) // args.eval_batch_size,
+        delay=10,
+    )
 
-        model.main_test_accuracy = quiz_machine.produce_results(
-            n_epoch=n_epoch,
-            model=model,
-            result_dir=args.result_dir,
-            deterministic_synthesis=deterministic_synthesis,
+    for imt in src:
+        # some paranoia
+        imt = imt.clone()
+        imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model(imt[:, 0] * 2 + imt[:, 1])
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
+        record.append(result)
+
+    return torch.cat(record)
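+
+# Note on model(imt[:, 0] * 2 + imt[:, 1]): the token and its mask bit
+# are packed into a single id in a doubled vocabulary (the models are
+# created with vocabulary_size * 2), so the network sees at every
+# position both the token and whether it has to be reconstructed.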
+
+
+def predict_the_four_grids(
+    model, input, with_noise=False, with_hints=False, local_device=main_device
+):
+    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+    nb = input.size(0)
+    masks = input.new_zeros(input.size())
+    u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+    masks.view(nb, 4, -1)[...] = u[:, :, None]
+    targets = input
+    input = (1 - masks) * targets
+    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+    if with_hints:
+        imt_set = add_hints_imt(imt_set)
+
+    if with_noise:
+        imt_set = add_noise_imt(imt_set)
+
+    result = ae_predict(model, imt_set, local_device=local_device)
+    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+    return result
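+
+# predict_the_four_grids replicates every quiz four times, masks a
+# different grid in each copy, and merges the predictions: since each
+# copy contributes only its own masked quadrant to result * masks,
+# summing over the four copies rebuilds a full sequence in which every
+# grid is predicted from the three others.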
+
+
+######################################################################
+
+
+def samples_for_generation_imt(input):
+    nb = input.size(0)
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.diffusion_nb_iterations, device=input.device
+    )
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(nb, -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+    t = dist.sample() + 1
+    r = torch.rand(input.size(), device=input.device)
+    proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+    mask_erased = (r <= proba_erased[:, None]).long()
+
+    noise = problem.pure_noise(nb, input.device)
+    targets = input
+    input = (1 - mask_erased) * input + mask_erased * noise
+    masks = input.new_full(input.size(), 1)
+
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
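+
+# The corruption schedule above, in equations: the iteration index t is
+# sampled in {1, ..., T} with P(t) proportional to 0.1^((t-1)/(T-1)),
+# favoring early iterations, and every token is then corrupted
+# independently with probability 1 - (1 - p)^t, i.e. the probability of
+# being hit at least once by t rounds of per-round corruption p, with
+# p = args.diffusion_proba_corruption and T = args.diffusion_nb_iterations.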
+
+
+def prioritized_rand(low):
+    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+    k = torch.rand(low.size(), device=low.device) + low.long()
+    k = k.sort(dim=1).indices
+    y = x.new(x.size())
+    y.scatter_(dim=1, index=k, src=x)
+    return y
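+
+# prioritized_rand(low) returns per-row uniform values rearranged so
+# that the smallest ones land preferentially where low is true. In
+# ae_generate below, this makes the tokens where the model disagrees
+# with the current input the most likely to be resampled.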
+
+
+def ae_generate(model, nb, local_device=main_device):
+    model.eval().to(local_device)
+
+    # We loop through the iterations first and through the
+    # mini-batches second so that we keep only the samples that have
+    # not stabilized
+
+    all_input = problem.pure_noise(nb, local_device)
+    all_masks = all_input.new_full(all_input.size(), 1)
+    all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
+
+    for it in range(args.diffusion_nb_iterations):
+        # log_string(f"nb_changed {all_changed.long().sum().item()}")
+
+        if not all_changed.any():
+            break
+
+        sub_input = all_input[all_changed].clone()
+        sub_masks = all_masks[all_changed].clone()
+        sub_changed = all_changed[all_changed].clone()
+
+        src = zip(
+            sub_input.split(args.eval_batch_size),
+            sub_masks.split(args.eval_batch_size),
+            sub_changed.split(args.eval_batch_size),
         )
 
+        for input, masks, changed in src:
+            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(input * 2 + masks)
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            output = dist.sample()
+            r = prioritized_rand(input != output)
+            mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+            update = (1 - mask_changes) * input + mask_changes * output
+            changed[...] = changed & (update != input).max(dim=1).values
+            input[...] = update
 
-def one_epoch(model, quiz_machine, local_device=main_device):
-    model.to(local_device).train()
+        a = all_changed.clone()
+        all_input[a] = sub_input
+        all_masks[a] = sub_masks
+        all_changed[a] = sub_changed
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    return all_input
 
-    nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in quiz_machine.batches(model, split="train"):
-        input = input.to(local_device)
+######################################################################
 
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
 
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        acc_train_loss += loss.item() * input.size(0)
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
+    quizzes = generate_quiz_set(
+        args.nb_train_samples if train else args.nb_test_samples,
+        c_quizzes,
+        args.c_quiz_multiplier,
+    )
+
+    q_p, q_g = quizzes.to(local_device).chunk(2)
+
+    # Half of the samples train the prediction; we inject noise into
+    # all of them, and hints into half
+    b_p = samples_for_prediction_imt(q_p)
+    b_p = add_noise_imt(b_p)
+    half = torch.rand(b_p.size(0)) < 0.5
+    b_p[half] = add_hints_imt(b_p[half])
+
+    # The other half are denoising examples for the generation
+    b_g = samples_for_generation_imt(q_g)
+
+    imt_set = torch.cat([b_p, b_g])
+    imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+
+    if train:
+        label = "train"
+        model.train().to(local_device)
+        optimizer_to(model.optimizer, local_device)
+        batch_size = args.train_batch_size
+    else:
+        label = "test"
+        model.eval().to(local_device)
+        batch_size = args.eval_batch_size
+
+    nb_samples, acc_loss = 0, 0.0
+
+    for imt in tqdm.tqdm(
+        imt_set.split(batch_size),
+        dynamic_ncols=True,
+        desc=label,
+        total=quizzes.size(0) // batch_size,
+        delay=10,
+    ):
+        input, masks, targets = imt.unbind(dim=1)
+        if train and nb_samples % args.batch_size == 0:
+            model.optimizer.zero_grad()
+
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model(input * 2 + masks)
+
+        loss_per_token = F.cross_entropy(
+            logits.transpose(1, 2), targets, reduction="none"
+        )
+        loss = (loss_per_token * masks).mean()
+        acc_loss += loss.item() * imt.size(0)
+        nb_samples += imt.size(0)
+
+        if train:
+            loss.backward()
 
-        nb_train_samples += input.size(0)
+            if nb_samples % args.batch_size == 0:
+                model.optimizer.step()
 
-        loss.backward()
+    log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
 
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
 
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+######################################################################
 
-    log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
 
-    run_tests(model, quiz_machine, deterministic_synthesis=False)
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
+    # Save some images of the prediction results
 
-    model.to(main_device)
+    quizzes = generate_quiz_set(150, c_quizzes, c_quiz_multiplier)
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+    masks = imt_set[:, 1].to("cpu")
+
+    correct = (quizzes == result).min(dim=1).values.long()
+    correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
+        :, :, 1
+    ]
+    predicted_parts = correct_parts.abs()
+
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_prediction_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+        predicted_parts=predicted_parts[:128],
+        correct_parts=correct_parts[:128],
+    )
+
+    # Save some images of the ex nihilo generation of the four grids
+
+    result = ae_generate(model, 150, local_device=local_device).to("cpu")
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_generation_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+    )
 
 
 ######################################################################
 
-# This is the key routine that decides what generated quizzes to keep
 
+def one_complete_epoch(
+    model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
+):
+    one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
+
+    one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
+
+    # Compute the test accuracy
+
+    quizzes = generate_quiz_set(
+        args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
+    )
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+    result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+    correct = (quizzes == result).min(dim=1).values.long()
+
+    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+    model.test_accuracy = nb_correct / nb_total
 
-# token_logprobas are NxMxT where M is the number of models
+    log_string(
+        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
+    )
 
+    save_inference_images(
+        model,
+        n_epoch,
+        test_c_quizzes,
+        args.c_quiz_multiplier,
+        local_device=local_device,
+    )
 
-def compute_valid_quizzes_(token_logprobas):
-    warnings.warn("validation with uniform constraints", RuntimeWarning)
-    l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
-    return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
 
+######################################################################
 
-def compute_valid_quizzes(token_logprobas):
-    l = token_logprobas.sum(dim=-1).sort(dim=-1).values
-    return (l[:, 0] < math.log(args.proba_not_understands)) & (
-        l[:, 1] > math.log(args.proba_understands)
+
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+    return (
+        (prediction != quizzes)
+        .long()
+        .reshape(quizzes.size(0), 4, -1)
+        .sum(dim=2)
+        .max(dim=1)
+        .values
     )
 
 
-def extract_valid_quizzes_and_logprobas(recorded):
-    validated_quizzes, validated_logprobas = [], []
-    for quizzes, token_logprobas in recorded:
-        validated_indices = compute_valid_quizzes(token_logprobas)
-        validated_quizzes.append(quizzes[validated_indices])
-        validated_logprobas.append(token_logprobas[validated_indices])
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
+    nb_correct, nb_wrong = 0, 0
 
-    if len(validated_quizzes) > 0:
-        return torch.cat(validated_quizzes, dim=0), torch.cat(
-            validated_logprobas, dim=0
+    for model in models:
+        model = copy.deepcopy(model).to(local_device).eval()
+        predicted = predict_the_four_grids(
+            model=model,
+            input=quizzes,
+            with_noise=False,
+            with_hints=with_hints,
+            local_device=local_device,
         )
-    else:
-        return None, None
+        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
+        nb_correct += (nb_mistakes == 0).long()
+        nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+    # print("\n\n", nb_correct, nb_wrong)
+
+    return nb_correct, nb_wrong
 
 
 ######################################################################
 
 
-def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
-    nb_to_create = nb_for_train + nb_for_test
+def identity_quizzes(quizzes):
+    quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
+    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
+        quizzes[:, 2] == quizzes[:, 3]
+    ).min(dim=1).values
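+
+
+# Flag the degenerate quizzes in which one of the transformations is
+# the identity, i.e. A == f_A or B == f_B; they are filtered out in
+# generate_c_quizzes below.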
 
-    recorded_quizzes_logprobas = []
 
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
+    record = []
     nb_validated = 0
 
-    while nb_validated < nb_to_create:
-        model_for_generation = models[torch.randint(len(models), (1,))]
+    start_time = time.perf_counter()
+    last_log = -1
+
+    while nb_validated < nb_to_generate:
+        # Generate new quizzes
+
+        model = models[torch.randint(len(models), (1,)).item()]
+        model = copy.deepcopy(model).to(local_device).eval()
+        generator_id = model.id
 
-        c_quizzes = quiz_machine.generate_quizzes(
-            nb_to_create,
-            model_for_generation=model_for_generation,
-            temperature=args.generation_temperature,
+        c_quizzes = ae_generate(
+            model=model, nb=args.eval_batch_size * 10, local_device=local_device
         )
 
-        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+        c_quizzes = c_quizzes[~identity_quizzes(c_quizzes)]
 
         if c_quizzes.size(0) > 0:
-            token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
-            recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
+            # Select the ones that are solved properly by some models and
+            # not understood by others
+
+            nb_correct, nb_wrong = evaluate_quizzes(
+                quizzes=c_quizzes,
+                models=models,
+                with_hints=True,
+                local_device=local_device,
+            )
+
+            to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+                nb_wrong >= args.nb_have_to_be_wrong
+            )
+
+            nb_validated += to_keep.long().sum().item()
+            record.append(c_quizzes[to_keep])
+
+        #####################
+
+        duration = time.perf_counter() - start_time
+
+        if last_log < 0 or duration > last_log + 10:
+            last_log = duration
+            if nb_validated > 0:
+                if nb_validated < nb_to_generate:
+                    d = (nb_to_generate - nb_validated) * duration / nb_validated
+                    e = (
+                        datetime.datetime.now() + datetime.timedelta(seconds=d)
+                    ).strftime("%a %H:%M")
+                else:
+                    e = "now!"
+            else:
+                e = "???"
+
+            log_string(
+                f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+            )
+
+        #####################
+
+    duration = time.perf_counter() - start_time
+
+    log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
+
+    return torch.cat(record).to("cpu")
+
+
+######################################################################
+
+
+def multithread_execution(fun, arguments):
+    # Single instance, no thread
+    if len(arguments) == 1:
+        return fun(*(arguments[0]))
 
-            (
-                validated_quizzes,
-                validated_logprobas,
-            ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
+    records, threads = [], []
 
-            if validated_quizzes is not None:
-                nb_validated = validated_quizzes.size(0)
+    def threadable_fun(*args):
+        r = fun(*args)
+        if type(r) is not tuple:
+            r = (r,)
+        records.append(r)
 
-        log_string(
-            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+    for thread_args in arguments:
+        # Consume one random value so that every thread starts at a
+        # different point of the random sequence
+        log_string(f"dummy_rand {torch.rand(1)}")
+        t = threading.Thread(target=threadable_fun, daemon=True, args=thread_args)
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    if records[0] == (None,):
+        return
+    else:
+        return [
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+        ]
+
+
+######################################################################
+
+
+def save_models(models, suffix=""):
+    if suffix != "":
+        suffix = "_" + suffix
+
+    for model in models:
+        filename = f"ae_{model.id:03d}{suffix}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+            },
+            os.path.join(args.result_dir, filename),
         )
 
-    # store the new c_quizzes which have been validated
+    log_string(f"wrote ae_*{suffix}.pth")
+
+
+######################################################################
 
-    quiz_machine.reverse_random_half_in_place(validated_quizzes)
-    quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
-    quiz_machine.store_c_quizzes(
-        validated_quizzes[nb_for_train:nb_to_create], for_train=False
+
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
+    c_quizzes = c_quizzes.to(local_device)
+
+    nb_correct, nb_wrong = evaluate_quizzes(
+        quizzes=c_quizzes,
+        models=models,
+        with_hints=False,
+        local_device=local_device,
     )
 
-    ######################################################################
-    # save images with their logprobas
+    comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
 
-    vq = validated_quizzes[:72]
-    vl = validated_logprobas[:72]
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes,
+        comments=comments,
+        delta=True,
+        nrow=8,
+    )
 
-    if vq.size(0) > 0:
-        prefix = f"culture_c_quiz_{n_epoch:04d}"
-        filename = os.path.join(args.result_dir, prefix + "_logp.pth")
-        torch.save(vl, filename)
-        # with open(file_name, "w") as logp_file:
-        # for l in vl:
-        # s = " ".join([str(x.item()) for x in l])
-        # logp_file.write(s + "\n")
+    log_string(f"wrote {filename}")
 
-        quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
 
+######################################################################
+
+problem = grids.Grids(
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+    chunk_size=100,
+    nb_threads=args.nb_threads,
+    tasks=args.grids_world_tasks,
+)
+
+if not args.resume:
+    problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
 models = []
 
-for k in range(args.nb_gpts):
-    log_string(f"creating model {k} and its w_quizzes")
-    model = mygpt.MyGPT(
-        vocabulary_size=vocabulary_size,
+if args.model_type == "standard":
+    model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+    model_constructor = attae.FunctionalAttentionAE
+else:
+    raise ValueError(f"Unknown model type {args.model_type}")
+
+
+for i in range(args.nb_models):
+    model = model_constructor(
+        vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        causal=True,
         dropout=args.dropout,
-    ).to(main_device)
+    )
 
-    model.main_test_accuracy = 0.0
-    model.id = k
+    # model = torch.compile(model)
 
-    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
-    quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
-    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
-    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+    model.id = i
+    model.test_accuracy = 0.0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     models.append(model)
 
 ######################################################################
 
-if args.resume:
-    try:
-        for model in models:
-            filename = f"gpt_{model.id:03d}.pth"
-
-            try:
-                d = torch.load(os.path.join(args.result_dir, filename))
-                model.load_state_dict(d[0])
-                model.main_test_accuracy = d[1]
-                log_string(f"successfully loaded {filename}")
-            except FileNotFoundError:
-                log_string(f"cannot find {filename}")
-                pass
-
-        try:
-            filename = "c_quizzes.pth"
-            quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
-            log_string(f"successfully loaded {filename}")
-        except FileNotFoundError:
-            log_string(f"cannot find {filename}")
-            pass
-
-    except:
-        log_string(f"error when loading {filename}.")
-        exit(1)
-
-######################################################################
+current_epoch = 0
 
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+if args.resume:
+    for model in models:
+        filename = f"ae_{model.id:03d}.pth"
 
-######################################################################
+        d = torch.load(
+            os.path.join(args.result_dir, filename),
+            map_location="cpu",
+            weights_only=False,
+        )
+        model.load_state_dict(d["state_dict"])
+        model.optimizer.load_state_dict(d["optimizer_state_dict"])
+        model.test_accuracy = d["test_accuracy"]
+        log_string(f"successfully loaded {filename}")
+
+    filename = "state.pth"
+    state = torch.load(
+        os.path.join(args.result_dir, filename),
+        map_location="cpu",
+        weights_only=False,
+    )
 
-# Compute the entropy of the training tokens
+    log_string(f"successfully loaded {filename}")
 
-token_count = 0
-for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
-    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
-        (0, 1)
-    )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
+    current_epoch = state["current_epoch"]
+    train_c_quizzes = state["train_c_quizzes"]
+    test_c_quizzes = state["test_c_quizzes"]
 
 ######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
-    def subsets_as_tuples(batches, cs):
-        s = set()
-        for batch in batches:
-            for x in batch:
-                s.add(tuple([v.item() for v in x]))
-                if len(s) == cs:
-                    yield s
-                    s = set()
-        yield s
-
-    nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(
-        quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
-    ):
-        in_train = set()
-        for train_subset in subsets_as_tuples(
-            quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
-        ):
-            in_train.update(test_subset.intersection(train_subset))
-        nb_in_train += len(in_train)
-        nb_test += len(test_subset)
 
-    log_string(
-        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
-    )
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
-    assert (
-        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
-    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
 
 ######################################################################
 
-if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-
-if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
-
-log_string(
-    f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
-)
+if not args.resume:
+    train_c_quizzes, test_c_quizzes = None, None
 
 ######################################################################
 
-if args.dirty_debug:
-    args.accuracy_to_make_c_quizzes = 0.0
-    args.nb_gpts = 2
-    args.nb_new_c_quizzes_for_train = 100
-    args.nb_new_c_quizzes_for_test = 10
+for n_epoch in range(current_epoch, args.nb_epochs):
+    start_time = time.perf_counter()
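+
+    # Save the state at the beginning of the epoch, so that resuming
+    # re-runs the epoch that was interrupted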
 
+    state = {
+        "current_epoch": n_epoch,
+        "train_c_quizzes": train_c_quizzes,
+        "test_c_quizzes": test_c_quizzes,
+    }
 
-######################################################################
+    filename = "state.pth"
+    torch.save(state, os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
 
-for n_epoch in range(args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
-    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
-    ##################################################
-    # If all the models are good enough, generate new quizzes and
-    # re-compute the test errors
+    # --------------------------------------------------------------------
 
-    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        create_c_quizzes(
-            models,
-            quiz_machine,
-            nb_for_train=args.nb_new_c_quizzes_for_train,
-            nb_for_test=args.nb_new_c_quizzes_for_test,
-        )
-
-        filename = "c_quizzes.pth"
-        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
-        log_string(f"wrote {filename}")
+    lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
 
-        # Force one epoch of training
-        for model in models:
-            model.main_test_accuracy = 0.0
+    if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
+        if train_c_quizzes is None:
+            save_models(models, "naive")
 
-    ##################################################
-    # Select, improve, and eval the worst model
+        nb_gpus = len(gpus)
+        nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
-    ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+        (new_c_quizzes,) = multithread_execution(
+            generate_c_quizzes,
+            [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+        )
 
-    weakest_models = ranked_models[: len(gpus)]
+        save_quiz_image(
+            models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
+        )
 
-    threads = []
+        log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
 
-    for gpu, model in zip(gpus, weakest_models):
-        log_string(f"training model {model.id}")
+        train_c_quizzes = (
+            new_c_quizzes
+            if train_c_quizzes is None
+            else torch.cat([train_c_quizzes, new_c_quizzes])
+        )
+        train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
 
-        t = threading.Thread(
-            target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+        nb_correct, _ = evaluate_quizzes(
+            quizzes=train_c_quizzes.to(main_device),
+            models=models,
+            with_hints=False,
+            local_device=main_device,
+        )
 
-        threads.append(t)
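+        # keep for test only the c-quizzes that enough models solve without hints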
+        test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
 
-        t.start()
+        for model in models:
+            model.test_accuracy = 0
 
-    for t in threads:
-        t.join()
+    if train_c_quizzes is None:
+        log_string("no_c_quiz")
+    else:
+        log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
 
-    # Save the models to disk
+    # --------------------------------------------------------------------
 
-    for model in weakest_models:
-        filename = f"gpt_{model.id:03d}.pth"
-        torch.save(
-            (model.state_dict(), model.main_test_accuracy),
-            os.path.join(args.result_dir, filename),
-        )
-        log_string(f"wrote {filename}")
+    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+    weakest_models = ranked_models[: len(gpus)]
+
+    log_string(
+        f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+    )
 
-    # Renew the training samples
+    multithread_execution(
+        one_complete_epoch,
+        [
+            (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
+            for model, gpu in zip(weakest_models, gpus)
+        ],
+    )
 
-    for model in weakest_models:
-        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
+    save_models(models)
 
+    # --------------------------------------------------------------------
 
-######################################################################
+    duration = time.perf_counter() - start_time
+    str_duration = ""
+    if duration >= 60:
+        str_duration += f"{int(duration)//60}min"
+    str_duration += f"{int(duration)%60}s"
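+    # estimate the next epoch's finish time, assuming a similar duration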
+    str_next = (
+        datetime.datetime.now() + datetime.timedelta(seconds=duration)
+    ).strftime("%H:%M:%S")
+    log_string(f"epoch_duration {str_duration} next_finish {str_next}")
diff --git a/mygpt.py b/mygpt.py
deleted file mode 100755 (executable)
index d0fda7e..0000000
--- a/mygpt.py
+++ /dev/null
@@ -1,339 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-# This is an implementation from scratch of a "GPT", that is a model
-# composed of several causal self-attention blocks. It is equipped
-# with a caching mechanism for keys and values to avoid an O(N^3) cost
-# for auto-regression.
-
-import math
-
-import torch
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-# A BracketedSequence is a BxTx... tensor together with the index
-# `first` and the number `nb` of time steps to compute.
-
-# Modules able to process it expect that they will have to process a
-# first bracket starting at t=0, followed by a succession of brackets
-# that move forward in time, do not overlap, and cover the axis T with
-# no holes.
-#
-# Although it is more general, for a classical prompt-conditioned
-# auto-regressive process it will be a first bracket starting at 0 and
-# of arbitrary length for the "prompt", followed by brackets of length
-# 1 for the successive tokens.
-#
-# Modules able to process brackets may implement a cache that is
-# reset when the input bracket starts at t=0.
-
-
-class BracketedSequence:
-    def __init__(self, x, first=None, nb=None):
-        self.x = x
-        self.first = 0 if first is None else first
-        self.nb = x.size(1) if nb is None else nb
-
-    def slice(self):
-        return self.x[:, self.first : self.first + self.nb]
-
-    def complete(self):
-        return self.first == 0 and self.nb == self.x.size(1)
-
-
-######################################################################
-
-
-class CacheWrapper(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        if bs.first == 0:
-            y = self.f(bs.slice())
-            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
-            self.cache_y[:, bs.first : bs.first + bs.nb] = y
-        else:
-            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
-
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
-
-
-##############################
-
-
-class AddPositionalEncoding(nn.Module):
-    def __init__(self, len_max):
-        super().__init__()
-        self.len_max = len_max
-
-    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
-    def forward(self, bs):
-        if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
-                None, :
-            ]
-            k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
-            )
-            self.cache_y = bs.x.new(bs.x.size())
-
-        self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
-        )
-
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class QKVAttention(nn.Module):
-    def __init__(
-        self,
-        dim_in,
-        dim_qk,
-        dim_v,
-        nb_heads=1,
-        causal=False,
-        attention_dropout=0.0,
-    ):
-        super().__init__()
-
-        def randw(*d):
-            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
-
-        self.causal = causal
-        self.attention_dropout = attention_dropout
-        self.record_attention = False
-
-        self.w_q = randw(nb_heads, dim_qk, dim_in)
-        self.w_k = randw(nb_heads, dim_qk, dim_in)
-        self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(dim_v * nb_heads, dim_in)
-
-    def forward(self, bs_q):
-        x_q = bs_q.x
-
-        assert (
-            self.causal or bs_q.complete()
-        ), "Partial evaluation is only possible for causal models"
-
-        if bs_q.first == 0:
-            self.cache_k = x_q.new_zeros(
-                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
-            )
-            self.cache_v = x_q.new_zeros(
-                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
-            )
-            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
-
-        q = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
-        )
-
-        self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
-        )
-        self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
-        )
-
-        a = torch.einsum(
-            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
-        ) / math.sqrt(self.w_q.size(1))
-
-        if self.causal:
-            if bs_q.first == 0:
-                self.cache_attzero = (
-                    torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
-                    < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
-                )
-            a = a.masked_fill(
-                self.cache_attzero[
-                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
-                ],
-                float("-inf"),
-            )
-
-        a = a.softmax(dim=3)
-
-        if self.record_attention:
-            self.a = a
-
-        a = F.dropout(a, self.attention_dropout, self.training)
-
-        y = torch.einsum(
-            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
-        ).flatten(2)
-
-        self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
-
-        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
-
-
-##############################
-
-
-class NoiseInjector(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.noise_std = 0.0
-
-    def forward(self, x):
-        if self.noise_std > 0:
-            x = x + torch.randn(x.size(), device=x.device) * self.noise_std
-        return x
-
-
-def set_noise_injection(model, noise_std):
-    for m in model.modules():
-        if isinstance(m, NoiseInjector):
-            m.noise_std = noise_std
-
-
-##############################
-
-
-class MyGPT(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        causal=False,
-        dropout=0.0,
-        len_max=1e5,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
-        )
-
-        trunk_blocks = []
-
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                WithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                        NoiseInjector(),
-                    ),
-                    QKVAttention(
-                        dim_in=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        causal=causal,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                WithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                        NoiseInjector(),
-                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                        nn.ReLU(),
-                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                        nn.Dropout(dropout),
-                    ),
-                ),
-            ]
-
-        self.trunk = nn.Sequential(*trunk_blocks)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
-        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-        bs = self.embedding(bs)
-        bs = self.trunk(bs)
-        bs = self.readout(bs)
-        return bs
-
-    def record_attention(self, v=True):
-        for m in self.modules():
-            if isinstance(m, QKVAttention):
-                m.record_attention = v
-
-    def retrieve_attention(self):
-        a = []
-        for m in self.modules():
-            if isinstance(m, QKVAttention):
-                a.append(m.a)
-        return a
-
-
-######################################################################
-
-if __name__ == "__main__":
-    print("Basic check.")
-
-    vocabulary_size = 3
-    x = torch.randint(vocabulary_size, (1, 5))
-
-    model = MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=4,
-        dim_keys=2,
-        dim_hidden=2,
-        nb_heads=2,
-        nb_blocks=2,
-        dropout=0.1,
-        causal=True,
-    )
-
-    model.eval()
-    y1 = model(BracketedSequence(x)).x
-    y2 = torch.randn_like(y1)
-    for s in range(x.size(1)):
-        z = model(BracketedSequence(x, s, 1))
-        y2[:, s] = z.slice()
-
-    print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
-
-######################################################################
index 05f3b20..8c1db63 100755 (executable)
@@ -25,14 +25,57 @@ class Problem:
         else:
             return self.queue.qsize() * self.chunk_size
 
-    def nb_token_values(self):
-        pass
+    def fill_cache(self):
+        while True:
+            quizzes = self.generate_w_quizzes_(self.chunk_size)
+            self.queue.put(quizzes.to("cpu"), block=True)
+
+    def generate_w_quizzes(self, nb, progress_bar=True):
+        if self.queue is None:
+            return self.generate_w_quizzes_(nb)
+
+        if self.rest is not None:
+            quizzes = [self.rest]  # leftover chunk from the previous call
+        else:
+            quizzes = []
+
+        self.rest = None
+
+        n = sum([q.size(0) for q in quizzes])
+
+        if progress_bar:
+            with tqdm.tqdm(
+                total=nb, dynamic_ncols=True, desc="world generation", delay=10
+            ) as pbar:
+                while n < nb:
+                    q = self.queue.get(block=True)
+                    quizzes.append(q)
+                    n += q.size(0)
+                    pbar.update(q.size(0))
+        else:
+            while n < nb:
+                q = self.queue.get(block=True)
+                quizzes.append(q)
+                n += q.size(0)
+
+        quizzes = torch.cat(quizzes, dim=0)
+        assert n == quizzes.size(0)
+
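+        # k is the surplus drained beyond the nb requested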
+        k = n - nb
+
+        if k > 0:
+            self.rest = quizzes[-k:]  # stash the surplus for the next call
+            quizzes = quizzes[:-k]
+
+        return quizzes
+
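+    # Typical wiring (a sketch, not necessarily the actual setup): the
+    # constructor creates a bounded queue and a background producer
+    # thread, e.g.
+    #
+    #   self.queue = queue.Queue(maxsize=nb_chunks)
+    #   threading.Thread(target=self.fill_cache, daemon=True).start()
+    #
+    # generate_w_quizzes() then drains whole chunks from the queue and
+    # stashes the surplus in self.rest for the next call.
+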
+    ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
     # The one to implement, returns a single tensor of nb quizzes
-    def generate_prompts_and_answers_(self, nb):
+    def generate_w_quizzes_(self, nb):
         pass
 
     # save a file to visualize quizzes, you can save a txt or png file
@@ -47,47 +90,7 @@ class Problem:
     ):
         pass
 
-    def fill_cache(self):
-        while True:
-            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
-
-            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
-
-    def generate_prompts_and_answers(self, nb):
-        if self.queue is None:
-            return self.generate_prompts_and_answers_(nb)
-
-        if self.rest is not None:
-            prompts, answers = rest
-        else:
-            prompts, answers = [], []
-
-        self.rest = None
-
-        n = sum([p.size(0) for p in prompts])
-
-        with tqdm.tqdm(
-            total=nb,
-            dynamic_ncols=True,
-            desc="world generation",
-        ) as pbar:
-            while n < nb:
-                p, s = self.queue.get(block=True)
-                prompts.append(p)
-                answers.append(s)
-                n += p.size(0)
-                pbar.update(p.size(0))
-
-        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
-        assert n == prompts.size(0)
-
-        k = n - nb
-
-        if k > 0:
-            rest = (prompts[-k:], answers[-k:])
-            prompts, answers = prompts[:-k], answers[:-k]
-
-        return prompts, answers
-
     def save_some_examples(self, result_dir):
         pass
+
+    ######################################################################
diff --git a/quiz_machine.py b/quiz_machine.py
deleted file mode 100755 (executable)
index bc468d3..0000000
+++ /dev/null
@@ -1,611 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings, sys
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import mygpt
-from mygpt import BracketedSequence
-
-import threading
-
-######################################################################
-# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
-# | X != Y)
-
-
-# output is NxCxT and target is NxT
-def confusion(output, target, reduction="mean"):
-    N, C, T = output.shape
-    output = output.permute(0, 2, 1).reshape(-1, C)
-    target = target.flatten()
-    all_t = torch.arange(N * T, device=output.device)
-    output = output.log_softmax(dim=-1)
-    result = -output[all_t, target]
-
-    output[all_t, target] = float("-inf")
-    output = output.log_softmax(dim=-1)
-    e = output.exp()
-    output[all_t, target] = 0
-    result = result - (output * e).sum(-1)
-
-    if reduction == "none":
-        return result.reshape(N, T)
-    elif reduction == "mean":
-        return result.reshape(N, T).mean()
-    elif reduction == "sum":
-        return result.reshape(N, T).sum()
-    else:
-        raise ValueError(f"unknown reduction '{reduction}'.")
-
-
-######################################################################
-
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
-
-
-def one_batch_masked_inplace_autoregression(
-    model,
-    input,
-    ar_mask,
-    seq_logproba,
-    temperature,
-    deterministic_synthesis,
-):
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
-
-    if to_generate.min() > 0:
-        model(
-            BracketedSequence(input, 0, to_generate.min())
-        )  # Needed to initialize the model's cache
-    for s in range(to_generate.min(), to_generate.max() + 1):
-        output = model(BracketedSequence(input, s, 1)).x
-
-        logits = output[:, s]
-
-        logits = (logits / temperature).log_softmax(dim=-1)
-
-        if deterministic_synthesis:
-            t_next = logits.argmax(-1)
-        else:
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-            t_next = dist.sample()
-
-        all_n = torch.arange(t_next.size(0))
-
-        seq_logproba += logits[all_n, t_next]
-
-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-
-
-def masked_inplace_autoregression(
-    model,
-    batch_size,
-    input,
-    ar_mask,
-    seq_logproba,
-    temperature,
-    deterministic_synthesis,
-    forbidden_tokens=None,
-    logit_biases=None,
-    progress_bar_desc=None,
-    device=torch.device("cpu"),
-):
-    assert input.size() == ar_mask.size()
-
-    batches = zip(
-        input.split(batch_size),
-        ar_mask.split(batch_size),
-        seq_logproba.split(batch_size),
-    )
-
-    if progress_bar_desc is not None:
-        batches = tqdm.tqdm(
-            batches,
-            dynamic_ncols=True,
-            desc=progress_bar_desc,
-            total=(input.size(0) + batch_size - 1) // batch_size,
-        )
-
-    with torch.autograd.no_grad():
-        t = model.training
-        model.eval()
-
-        for input, ar_mask, seq_logproba in batches:
-            one_batch_masked_inplace_autoregression(
-                model=model,
-                input=input,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=deterministic_synthesis,
-            )
-
-        model.train(t)
-
-
-######################################################################
-
-
-class QuizMachine:
-    def indices_forward_and_backward(self, quizzes):
-        i_forward = quizzes[:, 0] == self.token_forward
-        j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
-        i_backward = quizzes[:, 0] == self.token_backward
-        j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
-        assert torch.logical_or(
-            torch.logical_and(i_forward, j_forward),
-            torch.logical_and(i_backward, j_backward),
-        ).all()
-        return i_forward, i_backward
-
-    def non_trivial(self, quizzes):
-        quizzes = quizzes.clone()
-        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
-        n_backward = quizzes[:, 0] == self.token_backward
-        backward = quizzes[n_backward]
-        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-        return torch.logical_not(
-            self.problem.trivial_prompts_and_answers(
-                quizzes[:, 1 : 1 + self.prompt_len],
-                quizzes[:, 2 + self.prompt_len :],
-            )
-        )
-
-    def reverse_time(self, quizzes):
-        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
-        forward_to_backward = torch.cat(
-            [
-                quizzes[:, 0:1],
-                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
-                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
-                quizzes[:, 1 : 1 + self.prompt_len],
-            ],
-            dim=1,
-        )
-
-        forward_to_backward[:, 0] = self.token_backward
-        forward_to_backward[:, 1 + self.answer_len] = self.token_backward
-
-        backward_to_forward = torch.cat(
-            [
-                quizzes[:, 0:1],
-                quizzes[:, 2 + self.answer_len :],
-                quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
-                quizzes[:, 1 : 1 + self.answer_len],
-            ],
-            dim=1,
-        )
-
-        backward_to_forward[:, 0] = self.token_forward
-        backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
-
-        m = i_forward.long()[:, None]
-
-        return m * forward_to_backward + (1 - m) * backward_to_forward
-
-    def reverse_random_half_in_place(self, quizzes):
-        i = torch.rand(quizzes.size(0)) < 0.5
-        if i.any():
-            quizzes[i] = self.reverse_time(quizzes[i])
-
-    def make_ar_mask(self, quizzes, first=False):
-        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
-
-        t = torch.arange(quizzes.size(1), device=quizzes.device)
-
-        if first:
-            m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
-            m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
-        else:
-            m_forward = (t >= 2 + self.prompt_len).long()
-            m_backward = (t >= 2 + self.answer_len).long()
-
-        m = i_forward.long()[:, None]
-
-        return m * m_forward + (1 - m) * m_backward
-
-    def generate_token_sequences(self, nb):
-        prompts, answers = self.problem.generate_prompts_and_answers(nb)
-
-        if self.prompt_len is None:
-            self.prompt_len = prompts.size(1)
-
-        if self.answer_len is None:
-            self.answer_len = answers.size(1)
-
-        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
-
-        result = []
-
-        for prompt, answer in zip(prompts, answers):
-            a = [
-                torch.tensor([self.token_forward]),
-                prompt,
-                torch.tensor([self.token_forward]),
-                answer,
-            ]
-
-            result.append(torch.cat(a, dim=0)[None, :])
-
-        return torch.cat(result, dim=0)
-
-    def __init__(
-        self,
-        problem,
-        nb_train_samples,
-        nb_test_samples,
-        back_accuracy,
-        batch_size,
-        result_dir,
-        logger,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        v = problem.nb_token_values()
-        self.token_forward = v
-        self.token_backward = v + 1
-        self.nb_token_values = v + 2
-
-        self.problem = problem
-        self.back_accuracy = back_accuracy
-        self.batch_size = batch_size
-        self.device = device
-        self.logger = logger
-        self.prompt_len = None
-        self.answer_len = None
-
-        self.LOCK_C_QUIZZES = threading.Lock()
-        self.train_c_quizzes = []
-        self.test_c_quizzes = []
-
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        quizzes,
-        mistakes=None,
-    ):
-        quizzes = quizzes.clone().to("cpu")
-        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
-        n_backward = quizzes[:, 0] == self.token_backward
-        backward = quizzes[n_backward]
-        assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
-        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
-
-        predicted_prompts = n_backward.long()
-        predicted_answers = 1 - predicted_prompts
-        if mistakes is not None:
-            # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes.to("cpu")
-            predicted_answers *= mistakes.to("cpu")
-        else:
-            # 0/2 ~ not-to-predict / to predict
-            predicted_prompts *= 2
-            predicted_answers *= 2
-
-        self.problem.save_quiz_illustrations(
-            result_dir,
-            filename_prefix,
-            quizzes[:, 1 : 1 + self.prompt_len],
-            quizzes[:, 2 + self.prompt_len :],
-            predicted_prompts,
-            predicted_answers,
-        )
-
-    def vocabulary_size(self):
-        return self.nb_token_values
-
-    ######################################################################
-
-    def batches(self, model, split="train", desc=None):
-        assert split in {"train", "test"}
-
-        with self.LOCK_C_QUIZZES:
-            if split == "train":
-                w_quizzes = model.train_w_quizzes
-                c_quizzes = self.train_c_quizzes
-            else:
-                w_quizzes = model.test_w_quizzes
-                c_quizzes = self.test_c_quizzes
-
-            if len(c_quizzes) > 0:
-                c_quizzes = torch.cat(c_quizzes, dim=0)
-                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                    c_quizzes = c_quizzes[i]
-
-                i = torch.randperm(w_quizzes.size(0))[
-                    : w_quizzes.size(0) - c_quizzes.size(0)
-                ]
-                w_quizzes = w_quizzes[i]
-
-                self.nb_batch_w_quizzes = w_quizzes.size(0)
-                self.nb_batch_c_quizzes = c_quizzes.size(0)
-
-                input = torch.cat([w_quizzes, c_quizzes], dim=0)
-            else:
-                input = w_quizzes
-                self.nb_batch_w_quizzes = w_quizzes.size(0)
-                self.nb_batch_c_quizzes = 0
-
-        # Shuffle
-        input = input[torch.randperm(input.size(0))]
-
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    ######################################################################
-
-    def produce_results(
-        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
-    ):
-        def compute_accuracy(input, log_prefix=None):
-            input = input.to(self.device)
-            ar_mask = self.make_ar_mask(input)
-            result = input.clone() * (1 - ar_mask)
-            seq_logproba = torch.empty(input.size(0), device=self.device)
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
-
-            correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
-
-            n_forward = input[:, 0] == self.token_forward
-            n_backward = input[:, 0] == self.token_backward
-
-            correct[n_forward] = (
-                (input[n_forward] == result[n_forward]).long().min(dim=1).values
-            )
-
-            if self.back_accuracy and n_backward.any():
-                # accuracy of B->A*->B*=B instead of B->A*=A
-                back_input = self.reverse_time(result[n_backward])
-                back_input[:, 2 + self.prompt_len :] = input[
-                    n_backward, 1 : 1 + self.answer_len
-                ]
-                _, correct[n_backward] = compute_accuracy(back_input)
-
-            if log_prefix is not None:
-                forward_nb_correct = correct[n_forward].sum()
-                forward_nb_total = correct[n_forward].size(0)
-                backward_nb_correct = correct[n_backward].sum()
-                backward_nb_total = correct[n_backward].size(0)
-
-                self.logger(
-                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
-                )
-
-            return result, correct
-
-        # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
-
-        test_result, test_correct = compute_accuracy(
-            model.test_w_quizzes[:nmax], log_prefix="test"
-        )
-
-        main_test_accuracy = test_correct.sum() / test_correct.size(0)
-        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
-
-        ##############################
-
-        self.save_quiz_illustrations(
-            result_dir,
-            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            quizzes=test_result[:72],
-            mistakes=test_correct[:72] * 2 - 1,
-        )
-
-        return main_test_accuracy
-
-    ######################################################################
-
-    def renew_w_quizzes(self, model, nb, for_train=True):
-        input = model.train_w_quizzes if for_train else model.test_w_quizzes
-        nb = min(nb, input.size(0))
-        input[:-nb] = input[nb:].clone()
-        fresh_w_quizzes = self.generate_token_sequences(nb)
-        self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to("cpu")
-
-    ######################################################################
-
-    def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        with self.LOCK_C_QUIZZES:
-            if for_train:
-                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
-            else:
-                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
-
-    def save_c_quizzes(self, filename):
-        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
-
-    def load_c_quizzes(self, filename):
-        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
-
-    ######################################################################
-
-    def solution_token_logprobas(self, models, c_quizzes):
-        logproba = c_quizzes.new_zeros(
-            c_quizzes.size(0),
-            len(models),
-            c_quizzes.size(1),
-            device=self.device,
-            dtype=torch.float32,
-        )
-
-        for model in models:
-            with torch.autograd.no_grad():
-                t = model.training
-                model.eval()
-
-                for input, l in zip(
-                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
-                ):
-                    input = input.to(self.device)
-                    ar_mask = self.make_ar_mask(input)
-                    output = model(mygpt.BracketedSequence(input)).x
-                    l[:, model.id] = (
-                        -F.cross_entropy(
-                            output.transpose(1, 2), input, reduction="none"
-                        )
-                        * ar_mask
-                    )
-
-                model.train(t)
-
-        return logproba.to("cpu")
-
-    ###############################################################
-
-    def compute_correctness(
-        self,
-        c_quizzes,
-        models_for_validation,
-        bidirectional_validation=False,
-        deterministic_validation=True,
-    ):
-        if bidirectional_validation:
-            backward_c_quizzes = self.forward_to_backward(c_quizzes)
-
-        seq_logproba = torch.zeros(
-            c_quizzes.size(0),
-            max([m.id for m in models_for_validation]) + 1,
-            device=self.device,
-        )
-
-        nb_correct = 0
-
-        seq_logproba[...] = 0.0
-
-        for model in models_for_validation:
-            result = c_quizzes.clone()
-
-            ar_mask = self.make_ar_mask(result)
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba[:, model.id],
-                temperature=1.0,
-                deterministic_synthesis=deterministic_validation,
-                # progress_bar_desc="solving c_quizzes",
-                device=self.device,
-            )
-
-            correct = (c_quizzes == result).long().min(dim=-1).values
-
-            if bidirectional_validation:
-                backward_result = backward_c_quizzes.clone()
-
-                ar_mask = self.make_ar_mask(backward_result)
-
-                masked_inplace_autoregression(
-                    model=model,
-                    batch_size=self.batch_size,
-                    input=backward_result,
-                    ar_mask=ar_mask,
-                    seq_logproba=seq_logproba[:, model.id],
-                    temperature=1.0,
-                    deterministic_synthesis=deterministic_validation,
-                    # progress_bar_desc="solving backward c_quizzes",
-                    device=self.device,
-                )
-
-                backward_correct = (
-                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
-                )
-
-                correct *= backward_correct
-
-            # endif
-
-            nb_correct += correct
-
-        return nb_correct, seq_logproba
-
-    ###############################################################
-
-    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
-        c_quizzes = torch.empty(
-            nb,
-            self.prompt_len + self.answer_len + 2,
-            device=self.device,
-            dtype=torch.int64,
-        )
-
-        seq_logproba = torch.zeros(nb, device=self.device)
-
-        # First, we generate the answer at high temperature
-
-        c_quizzes[:, 0] = self.token_backward
-        c_quizzes[:, 1 + self.answer_len] = self.token_backward
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes, first=True),
-            seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        # Then, we generate the prompt at low temperature
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
-            temperature=1 / temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        # Then we reverse the quiz, and re-generate the solution, now
-        # at low temperature
-
-        c_quizzes = self.reverse_time(c_quizzes)
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=self.make_ar_mask(c_quizzes),
-            seq_logproba=seq_logproba,
-            temperature=1 / temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        return c_quizzes.to("cpu")
diff --git a/report/culture.tex b/report/culture.tex
new file mode 100644 (file)
index 0000000..d9f39e3
--- /dev/null
@@ -0,0 +1,596 @@
+%% -*- mode: latex; mode: reftex; mode: flyspell; coding: utf-8; tex-command: "pdflatex.sh" -*-
+
+%% Any copyright is dedicated to the Public Domain.
+%% https://creativecommons.org/publicdomain/zero/1.0/
+%% Written by Francois Fleuret <francois@fleuret.org>
+
+\documentclass[11pt,a4paper,oneside]{article}
+\usepackage[paperheight=15cm,paperwidth=8cm,top=2mm,bottom=15mm,right=5mm,left=5mm]{geometry}
+%\usepackage[a4paper,top=2.5cm,bottom=2cm,left=2.5cm,right=2.5cm]{geometry}
+\usepackage[utf8]{inputenc}
+\usepackage{amsmath,amssymb,dsfont}
+\usepackage[pdftex]{graphicx}
+\usepackage[colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue]{hyperref}
+\urlstyle{same}
+\usepackage{tikz}
+\usetikzlibrary{arrows,arrows.meta,calc}
+\usetikzlibrary{patterns,backgrounds}
+\usetikzlibrary{positioning,fit}
+\usetikzlibrary{shapes.geometric,shapes.multipart}
+\usetikzlibrary{patterns.meta,decorations.pathreplacing,calligraphy}
+\usetikzlibrary{tikzmark}
+\usetikzlibrary{decorations.pathmorphing}
+\usepackage[round]{natbib}
+\usepackage[osf]{libertine}
+\usepackage{microtype}
+
+\usepackage{mleftright}
+
+\usepackage{enumitem}
+\setlist[itemize]{leftmargin=0pt,itemindent=1em,itemsep=2ex}
+\setlist{nosep} % or \setlist{noitemsep} to leave space around whole list
+
+\newcommand{\setmuskip}[2]{#1=#2\relax}
+\setmuskip{\thinmuskip}{1.5mu} % by default it is equal to 3 mu
+\setmuskip{\medmuskip}{2mu} % by default it is equal to 4 mu
+\setmuskip{\thickmuskip}{3.5mu} % by default it is equal to 5 mu
+
+\setlength{\parindent}{0cm}
+\setlength{\parskip}{1ex}
+%\renewcommand{\baselinestretch}{1.3}
+%\setlength{\tabcolsep}{0pt}
+%\renewcommand{\arraystretch}{1.0}
+
+\def\argmax{\operatornamewithlimits{argmax}}
+\def\argmin{\operatornamewithlimits{argmin}}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\def\given{\,\middle\vert\,}
+\def\proba{\operatorname{P}}
+\newcommand{\seq}{{S}}
+\newcommand{\expect}{\mathds{E}}
+\newcommand{\variance}{\mathds{V}}
+\newcommand{\empexpect}{\hat{\mathds{E}}}
+\newcommand{\mutinf}{\mathds{I}}
+\newcommand{\empmutinf}{\hat{\mathds{I}}}
+\newcommand{\entropy}{\mathds{H}}
+\newcommand{\empentropy}{\hat{\mathds{H}}}
+\newcommand{\ganG}{\mathbf{G}}
+\newcommand{\ganD}{\mathbf{D}}
+\newcommand{\ganF}{\mathbf{F}}
+
+\newcommand{\dkl}{\mathds{D}_{\mathsf{KL}}}
+\newcommand{\djs}{\mathds{D}_{\mathsf{JS}}}
+
+\newcommand*{\vertbar}{\rule[-1ex]{0.5pt}{2.5ex}}
+\newcommand*{\horzbar}{\rule[.5ex]{2.5ex}{0.5pt}}
+
+\def\positionalencoding{\operatorname{pos-enc}}
+\def\concat{\operatorname{concat}}
+\def\crossentropy{\LL_{\operatorname{ce}}}
+
+\newcommand{\separator}{\begin{center}
+*
+\end{center}}
+
+\newcommand{\pic}[2]{%
+\hspace*{\stretch{1}}
+%
+\includegraphics[scale=0.25]{#1}
+%
+\hspace*{\stretch{1}}%
+}
+
+\newcommand{\birdpic}[2]{%
+\hspace*{\stretch{1}}
+%
+\includegraphics[scale=0.35]{#1}
+%
+\hspace*{\stretch{1}}%
+}
+
+\newenvironment{example}{%
+
+\vspace*{2ex}
+
+\begin{minipage}{\textwidth}
+
+\setlength{\parindent}{0cm}
+\setlength{\parskip}{1ex}
+}{%
+\end{minipage}
+}
+
+\begin{document}
+
+\vspace*{-3ex}
+
+\begin{center}
+
+{\Large Self-Generated Culture}
+
+Fran\c cois Fleuret
+
+\today
+
+\vspace*{2ex}
+
+\centerline{\color{red}(work in progress, to be updated)}
+
+\medskip
+
+\centerline{\url{https://fleuret.org/public/culture/culture.pdf}}
+
+\end{center}
+
+\section{Introduction}
+
+The hypothesis behind this experiment is that high-level abstract
+thinking is fueled by social competition.
+
+A group of communicating agents that try to demonstrate their
+cognitive superiority would end up developing a rich and consistent
+culture.
+
+\subsection{Setup}
+
+The experiment is designed with a group of GPTs that alternately
+learn to solve quizzes and generate new ones.
+
+A ``quiz'' is a pair composed of a prompt and a solution, both being
+sequences of tokens.
+
+We differentiate \textbf{world quizzes}, which follow pre-defined and
+fixed regularities, and mimic the world's physical and environmental
+patterns that an organism has to grasp to survive, from
+\textbf{culture quizzes}, which are generated by the GPTs, and mimic
+the knowledge one has to master to perform socially.
+
+
+We train five GPTs on a very large set of ``world quizzes''
+generated randomly. These models are trained to generate both the
+solution given the prompt, and the prompt given the solution.
+
+This is achieved by training on both ``forward sequences'', composed
+of a token \texttt{[fwd]}, followed by the prompt's tokens, followed
+by another token \texttt{[fwd]}, followed by the solution's tokens,
+and ``backward sequences'', composed of a token \texttt{[bck]},
+followed by the solution's tokens, followed by another token
+\texttt{[bck]}, followed by the prompt's tokens.
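+
+Schematically, with prompt tokens $p_1, \dots, p_n$ and solution
+tokens $s_1, \dots, s_m$, the two layouts are:
+%
+\begin{center}
+\texttt{[fwd]} $p_1 \dots p_n$ \texttt{[fwd]} $s_1 \dots s_m$
+\\[1ex]
+\texttt{[bck]} $s_1 \dots s_m$ \texttt{[bck]} $p_1 \dots p_n$
+\end{center}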
+
+\subsection{Generating Culture Quizzes}
+
+When their accuracy gets above $95\%$, we generate new quizzes as follows:
+%
+\begin{enumerate}
+
+\item generate a solution (without conditioning) at temperature $T=2$,
+  then generate a prompt for that solution at temperature $T=1/2$, and
+  then generate a solution for that prompt at temperature $T=1/2$.
+
+\item generate one solution for that prompt with each of the $5$ GPTs
+  at temperature $T=1$; if $4$ of them generate the correct solution,
+  validate that quiz and include it in the training data.
+
+\end{enumerate}
+
+This criterion ensures that the new quizzes are both solvable and
+sophisticated, and that they incrementally complexify the culture.
+Imposing both directions prevents the generation of quizzes that are
+non-trivial only because the prompt has been randomly degraded.
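+
+As a minimal sketch, this procedure could read as follows, where
+\texttt{gpts} is the list of models and \texttt{generate} is a
+hypothetical sampling helper, not the actual implementation:
+
+\begin{verbatim}
+import random
+
+def make_culture_quiz(gpts, generate, nb_required=4):
+    # generate(model, temperature, given=None)
+    # samples a token sequence, optionally
+    # conditioned on another one
+    g = random.choice(gpts)
+    solution = generate(g, temperature=2.0)
+    prompt = generate(g, temperature=0.5,
+                      given=solution)
+    solution = generate(g, temperature=0.5,
+                        given=prompt)
+    # validated if enough of the GPTs re-generate
+    # that same solution at temperature 1
+    nb_ok = sum(
+        generate(m, temperature=1.0,
+                 given=prompt) == solution
+        for m in gpts
+    )
+    if nb_ok >= nb_required:
+        return prompt, solution
+    return None
+\end{verbatim}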
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Grid Quizzes}
+
+\subsection{World Quizzes}
+
+We define several types of quizzes and implement algorithmic
+procedures to randomly generate as many examples of each as we
+need.
+
+In these quizzes, the prompt is made of three grids $A, f(A), B$ and
+the solution is a single grid $f(B)$.
+
+\subsubsection{Half Fill}
+
+\pic{pics/task_color_grow.png}{``half fill''}
+
+The first grid contains three rectangles, each with a vertical or a
+horizontal line of another color in its middle. The second grid is
+identical except that one of the rectangles has one half filled. The
+third grid contains three rectangles of the same colors as in the
+first grid, but of different sizes and locations. The solution is
+obtained by similarly filling one half of a rectangle in the third
+grid.
+
+\subsubsection{Detect}
+
+\pic{pics/task_detect.png}{``detect''}
+
+The first grid contains three rectangles, and the second has two
+pixels of the same colors located in the top-left corners of two of
+them. The solution is obtained by marking in the fourth grid the
+top-left corners of the same-colored rectangles of the third.
+
+\subsubsection{Frame}
+
+\pic{pics/task_frame.png}{``frame''}
+
+The first grid contains three rectangles, and the second is identical
+except that one rectangle has been replaced by its frame. The same
+should be done to the similarly colored rectangle of the third grid
+to obtain the solution.
+
+\subsubsection{Grow}
+
+\pic{pics/task_grow.png}{``grow''}
+
+The first grid contains three rectangles, one of which gets one pixel
+thicker or thinner in the second. The same should be done to the
+similarly colored rectangle of the third grid to get the solution.
+
+\subsubsection{Replace color}
+
+\pic{pics/task_replace_color.png}{``replace color''}
+
+The first grid contains three rectangles, and the second is obtained by
+changing one of the colors. The same should be done to the third grid
+to obtain the solution.
+
+\subsubsection{Translate}
+
+\pic{pics/task_translate.png}{``translate''}
+
+The first grid contains three rectangles. The second is obtained by
+displacing one of them by one pixel in both directions. The solution is
+obtained by applying the same motion to the similarly colored
+rectangle in the third grid.
+
+%% \subsubsection{Bounce}
+
+%% \pic{pics/task_bounce.png}{``bounce''}
+
+%% The solution should join the two pixels of same color, with a path of
+%% another color, starting in the direction indicated by a pixel of that
+%% color, and changing direction only when colliding with a pixel of a
+%% third color or one of the lattice border.
+
+%% \subsubsection{count}
+
+%% \pic{pics/task_count.png}{``count''}
+
+%% \subsubsection{scale}
+
+%% \pic{pics/task_scale.png}{``scale''}
+
+%% \subsubsection{trajectory}
+
+%% \pic{pics/task_trajectory.png}{``trajectory''}
+
+\subsection{Culture Quizzes}
+
+We list here some generated quizzes that exhibit features that were
+not present in the ``world quizzes'' used for training.
+
+\bigskip
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0110_N4_validated/quiz_63.png}{0110/63}
+
+\pic{pics/culture_c_quiz_0115_N4_validated/quiz_37.png}{0115/37}
+
+The quizzes ``Frame'' and ``Half Fill'' have been combined in a single
+quiz.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0120_N4_validated/quiz_05.png}{0120/05}
+
+The ``Frame'' quiz has been generalized to non-rectangular shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_01.png}{0078/01}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_02.png}{0078/02}
+
+More rectangles were added as distractors.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0087_N4_validated/quiz_62.png}{0087/62}
+
+\pic{pics/culture_c_quiz_0102_N4_validated/quiz_04.png}{0102/04}
+
+\pic{pics/culture_c_quiz_0102_N4_validated/quiz_11.png}{0102/11}
+
+\pic{pics/culture_c_quiz_0108_N4_validated/quiz_31.png}{0108/31}
+
+Variation of ``Detect'' with location markers colored according to the
+color of the rectangle they mark.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_16.png}{0078/16}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_21.png}{0084/21}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_42.png}{0078/42}
+
+\pic{pics/culture_c_quiz_0089_N4_validated/quiz_28.png}{0089/28}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_00.png}{0084/00}
+
+Variations of ``Half Fill'', ``Detect'', ``Translate'', ``Grow'', and
+``Frame'' with a number of rectangles not equal to three.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_27.png}{0078/27}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_18.png}{0078/18}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_45.png}{0086/45}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_37.png}{0078/37}
+
+Variations of ``Half Fill'' where the shapes to change have more
+complex coloring.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_30.png}{0078/30}
+
+Variation of ``Translate'' where the moving part is occluded, which
+was never the case in the world quizzes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_31.png}{0078/31}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_10.png}{0084/10}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_12.png}{0084/12}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_23.png}{0086/23}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_28.png}{0086/28}
+
+Variations of ``Half Fill'' with non-rectangular shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0078_N4_validated/quiz_60.png}{0078/60}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_41.png}{0084/41}
+
+\pic{pics/culture_c_quiz_0084_N4_validated/quiz_49.png}{0084/49}
+
+\pic{pics/culture_c_quiz_0086_N4_validated/quiz_04.png}{0086/04}
+
+Variations of ``Half Fill'' where two colors or two rectangles have
+to be modified.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\pic{pics/culture_c_quiz_0111_N4_validated/quiz_23.png}{0111/23}
+
+Variation of ``Frame'' with no rectangle of adequate size to be
+modified.
+
+\end{example}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Bird World}
+
+These results were obtained with a slightly different procedure. In
+particular, the quizzes were validated if the models could predict
+both the solution from the prompt and the prompt from the solution.
+We report them since they exhibit the same patterns of generalization
+although the setup is quite different.
+
+\subsection{World Quizzes}
+
+The initial set of quizzes consists of predicting the dynamics of a
+very simple world: a $6 \times 8$ grid with three colored ``birds''
+moving in a straight line, possibly bouncing off the grid's borders.
+There are ten different colors.
+%
+\birdpic{pics/examples_train.png}{}
+%
+
+In each of these quizzes, $A$ is the left image serialized in
+raster-scan order as a sequence of $6 \times 8 = 48$ tokens, $d$ is
+either the token ``forward'' or the token ``backward'', and $B$ is the
+right image, also serialized. The direction of prediction is chosen at
+random.
+
+\subsection{Culture Quizzes}
+
+This procedure results in the discovery of patterns which are not
+present in the original quizzes:
+
+\begin{example}
+
+\birdpic{pics/4_birds_1.png}{}
+
+\birdpic{pics/5_birds_1.png}{}
+
+\birdpic{pics/6_birds_1.png}{}
+
+More birds.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\birdpic{pics/other_shapes_2.png}{}
+
+\birdpic{pics/other_shapes_3.png}{}
+
+New bird shapes.
+
+\end{example}
+
+\separator
+
+\begin{example}
+
+\birdpic{pics/other_shapes_1.png}{}
+
+\birdpic{pics/occlusions_1.png}{}
+
+Occlusions.
+
+\end{example}
+
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+\pagebreak
+
+\section{Various thoughts}
+
+\begin{itemize}
+
+\item The whole process can be envisioned as natural selection of
+  quizzes in the representation landscape of GPTs. There probably is a
+  subtle relation between the temperature (mutation rate) and the
+  number of models used to validate with the ``all but one'' criterion
+  (survival criterion).
+
+\item The ``all but one'' criterion could be ``all but $K$'', and
+  there may be an information-theoretic interpretation, where the
+  goal is to maximize mutual information, with $K=N$ being total
+  randomness, hence high entropy but no structure, and $K=0$ being
+  total determinism, hence no information to share.
+
+\item The setup does not push toward any specific invariance or
+  property in the generated quizzes: their consistency is entirely
+  due to the statistics of the ``world quizzes'' that remain in the
+  training set, and to the GPTs' inductive biases.
+
+\item The GPTs obviously get a sense of objectness and 2d topology
+  early on, since they rapidly increase the number of birds and
+  ``discover'' occlusion even though it was never present in the
+  world quizzes.
+
+\item There may not be so many problems that can be cast as pairs of
+  patterns that are each a deterministic function of the other, which
+  is probably critical here.
+
+\item This overall process probably fights the ``simplicity bias'':
+  if a model lacks a ``cue'' that the others have, quizzes that
+  require this cue will rapidly appear and be added to the training
+  data, and that model will catch up.
+
+\item The randomness of the process probably allows going even beyond
+  just synchronizing the abilities of the models. There may be some
+  additional complexification of quizzes that get accepted by chance.
+
+\item It can be parallelized by dispatching the GPTs across multiple
+  nodes, and a quadratic cost can be avoided by limiting the
+  validation of the quizzes to a subset of the models.
+
+\item The current process to generate new quizzes, which simply
+  samples them at random, is very rudimentary and probably not
+  sufficient in a real-data setup. It can probably be supplemented
+  with an MCTS-type search.
+
+\item There may already be, in the generated quizzes, some structure
+  that \emph{we} do not pick up (e.g. certain color or motion
+  patterns).
+
+\end{itemize}
+
+\section*{Appendix}
+
+The code is available at
+
+\medskip
+
+\centerline{\url{https://fleuret.org/git/culture}}
+
+The experiments are done with an RTX 4090.
+
+The GPT used has 37M parameters and the following structure:
+
+\begin{center}
+\begin{tabular}{lc}
+    \texttt{dim\_model}  & 512  \\
+    \texttt{dim\_keys}   & 64   \\
+    \texttt{dim\_hidden} & 2048 \\
+    \texttt{nb\_heads}   & 8    \\
+    \texttt{nb\_blocks}  & 12
+\end{tabular}
+\end{center}
+
+Adam, $\eta = 10^{-4}$, no scheduling.
+
+There are $N_{\text{train}}=250\,000$ original quizzes for training
+and $N_{\text{test}} = 10\,000$ for test.
+
+At each epoch, for both train and test samples, we mix original
+quizzes and the generated ones.
+
+For training, for instance, if there are fewer than
+$N_{\text{train}}/2$ new quizzes, we take all of them; otherwise we
+sample $N_{\text{train}}/2$ of them without replacement, and then we
+sample without replacement enough original quizzes to get
+$N_{\text{train}}$ samples in total.
+
+We proceed similarly to get $N_{\text{test}}$ samples for test.
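+
+As a minimal sketch, assuming the quizzes are stored as the rows of
+two tensors, this mixing could read as follows (illustrative, not the
+actual implementation):
+
+\begin{verbatim}
+import torch
+
+def mix_quizzes(original, generated, n):
+    # at most n // 2 generated quizzes,
+    # sampled without replacement
+    k = min(generated.size(0), n // 2)
+    i = torch.randperm(generated.size(0))[:k]
+    # complete with original quizzes,
+    # also without replacement
+    j = torch.randperm(original.size(0))[:n - k]
+    return torch.cat([generated[i], original[j]])
+\end{verbatim}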
+
+\end{document}
diff --git a/report/pics/4_birds_1.png b/report/pics/4_birds_1.png
new file mode 100644 (file)
index 0000000..961b95d
Binary files /dev/null and b/report/pics/4_birds_1.png differ
diff --git a/report/pics/5_birds_1.png b/report/pics/5_birds_1.png
new file mode 100644 (file)
index 0000000..09870c7
Binary files /dev/null and b/report/pics/5_birds_1.png differ
diff --git a/report/pics/6_birds_1.png b/report/pics/6_birds_1.png
new file mode 100644 (file)
index 0000000..3717298
Binary files /dev/null and b/report/pics/6_birds_1.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png
new file mode 100644 (file)
index 0000000..23aeac0
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_01.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png
new file mode 100644 (file)
index 0000000..d17d796
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_02.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png
new file mode 100644 (file)
index 0000000..8727132
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_16.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png
new file mode 100644 (file)
index 0000000..7eef189
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_18.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png
new file mode 100644 (file)
index 0000000..3b8eb8d
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_27.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png
new file mode 100644 (file)
index 0000000..8f93b14
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_30.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png
new file mode 100644 (file)
index 0000000..7deacbf
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_31.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png
new file mode 100644 (file)
index 0000000..762bd89
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_37.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png
new file mode 100644 (file)
index 0000000..ef92823
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_42.png differ
diff --git a/report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png b/report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png
new file mode 100644 (file)
index 0000000..4ec6305
Binary files /dev/null and b/report/pics/culture_c_quiz_0078_N4_validated/quiz_60.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png
new file mode 100644 (file)
index 0000000..6272aeb
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_00.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png
new file mode 100644 (file)
index 0000000..72cd85e
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_10.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png
new file mode 100644 (file)
index 0000000..b7b386d
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_12.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png
new file mode 100644 (file)
index 0000000..a4eeb42
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_21.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png
new file mode 100644 (file)
index 0000000..3edad5a
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_41.png differ
diff --git a/report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png b/report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png
new file mode 100644 (file)
index 0000000..ee10aea
Binary files /dev/null and b/report/pics/culture_c_quiz_0084_N4_validated/quiz_49.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png
new file mode 100644 (file)
index 0000000..fa153ae
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_04.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png
new file mode 100644 (file)
index 0000000..aab6285
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_23.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png
new file mode 100644 (file)
index 0000000..f0edfb2
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_28.png differ
diff --git a/report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png b/report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png
new file mode 100644 (file)
index 0000000..bc9c9d5
Binary files /dev/null and b/report/pics/culture_c_quiz_0086_N4_validated/quiz_45.png differ
diff --git a/report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png b/report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png
new file mode 100644 (file)
index 0000000..6a3c50f
Binary files /dev/null and b/report/pics/culture_c_quiz_0087_N4_validated/quiz_62.png differ
diff --git a/report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png b/report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png
new file mode 100644 (file)
index 0000000..f460aba
Binary files /dev/null and b/report/pics/culture_c_quiz_0089_N4_validated/quiz_28.png differ
diff --git a/report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png b/report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png
new file mode 100644 (file)
index 0000000..ccfe8cc
Binary files /dev/null and b/report/pics/culture_c_quiz_0102_N4_validated/quiz_04.png differ
diff --git a/report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png b/report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png
new file mode 100644 (file)
index 0000000..c92f874
Binary files /dev/null and b/report/pics/culture_c_quiz_0102_N4_validated/quiz_11.png differ
diff --git a/report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png b/report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png
new file mode 100644 (file)
index 0000000..9a0c00e
Binary files /dev/null and b/report/pics/culture_c_quiz_0108_N4_validated/quiz_31.png differ
diff --git a/report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png b/report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png
new file mode 100644 (file)
index 0000000..00a0650
Binary files /dev/null and b/report/pics/culture_c_quiz_0110_N4_validated/quiz_63.png differ
diff --git a/report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png b/report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png
new file mode 100644 (file)
index 0000000..3ddd703
Binary files /dev/null and b/report/pics/culture_c_quiz_0111_N4_validated/quiz_23.png differ
diff --git a/report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png b/report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png
new file mode 100644 (file)
index 0000000..8507fa4
Binary files /dev/null and b/report/pics/culture_c_quiz_0115_N4_validated/quiz_37.png differ
diff --git a/report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png b/report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png
new file mode 100644 (file)
index 0000000..7a27afe
Binary files /dev/null and b/report/pics/culture_c_quiz_0120_N4_validated/quiz_05.png differ
diff --git a/report/pics/examples_train.png b/report/pics/examples_train.png
new file mode 100644 (file)
index 0000000..d1b349f
Binary files /dev/null and b/report/pics/examples_train.png differ
diff --git a/report/pics/occlusions_1.png b/report/pics/occlusions_1.png
new file mode 100644 (file)
index 0000000..28c39ba
Binary files /dev/null and b/report/pics/occlusions_1.png differ
diff --git a/report/pics/other_shapes_1.png b/report/pics/other_shapes_1.png
new file mode 100644 (file)
index 0000000..620fd45
Binary files /dev/null and b/report/pics/other_shapes_1.png differ
diff --git a/report/pics/other_shapes_2.png b/report/pics/other_shapes_2.png
new file mode 100644 (file)
index 0000000..fa1e3d4
Binary files /dev/null and b/report/pics/other_shapes_2.png differ
diff --git a/report/pics/other_shapes_3.png b/report/pics/other_shapes_3.png
new file mode 100644 (file)
index 0000000..5779ebb
Binary files /dev/null and b/report/pics/other_shapes_3.png differ
diff --git a/report/pics/task_bounce.png b/report/pics/task_bounce.png
new file mode 100644 (file)
index 0000000..d62d165
Binary files /dev/null and b/report/pics/task_bounce.png differ
diff --git a/report/pics/task_color_grow.png b/report/pics/task_color_grow.png
new file mode 100644 (file)
index 0000000..d872af2
Binary files /dev/null and b/report/pics/task_color_grow.png differ
diff --git a/report/pics/task_count.png b/report/pics/task_count.png
new file mode 100644 (file)
index 0000000..84da321
Binary files /dev/null and b/report/pics/task_count.png differ
diff --git a/report/pics/task_detect.png b/report/pics/task_detect.png
new file mode 100644 (file)
index 0000000..0beb8af
Binary files /dev/null and b/report/pics/task_detect.png differ
diff --git a/report/pics/task_frame.png b/report/pics/task_frame.png
new file mode 100644 (file)
index 0000000..8d3e015
Binary files /dev/null and b/report/pics/task_frame.png differ
diff --git a/report/pics/task_grow.png b/report/pics/task_grow.png
new file mode 100644 (file)
index 0000000..ce08e0d
Binary files /dev/null and b/report/pics/task_grow.png differ
diff --git a/report/pics/task_replace_color.png b/report/pics/task_replace_color.png
new file mode 100644 (file)
index 0000000..d6c9582
Binary files /dev/null and b/report/pics/task_replace_color.png differ
diff --git a/report/pics/task_scale.png b/report/pics/task_scale.png
new file mode 100644 (file)
index 0000000..b0e3820
Binary files /dev/null and b/report/pics/task_scale.png differ
diff --git a/report/pics/task_trajectory.png b/report/pics/task_trajectory.png
new file mode 100644 (file)
index 0000000..20b7c2b
Binary files /dev/null and b/report/pics/task_trajectory.png differ
diff --git a/report/pics/task_translate.png b/report/pics/task_translate.png
new file mode 100644 (file)
index 0000000..5f2bc2a
Binary files /dev/null and b/report/pics/task_translate.png differ
diff --git a/sky.py b/sky.py
deleted file mode 100755 (executable)
index cc5bd4f..0000000
--- a/sky.py
+++ /dev/null
@@ -1,364 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Sky(problem.Problem):
-    colors = torch.tensor(
-        [
-            [255, 255, 255],
-            [255, 0, 0],
-            [0, 192, 0],
-            [0, 0, 255],
-            [255, 192, 0],
-            [0, 255, 255],
-            [255, 0, 255],
-            [192, 255, 192],
-            [255, 192, 192],
-            [192, 192, 255],
-            [192, 192, 192],
-        ]
-    )
-
-    token_background = 0
-    first_bird_token = 1
-    nb_bird_tokens = colors.size(0) - 1
-
-    token2char = (
-        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
-    )
-
-    def __init__(
-        self,
-        height=6,
-        width=8,
-        nb_birds=3,
-        speed=2,
-        nb_iterations=2,
-        avoid_collision=True,
-        max_nb_cached_chunks=None,
-        chunk_size=None,
-        nb_threads=-1,
-    ):
-        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
-        self.height = height
-        self.width = width
-        self.nb_birds = nb_birds
-        self.speed = speed
-        self.nb_iterations = nb_iterations
-        self.avoid_collision = avoid_collision
-
-    def generate_frame_sequences(self, nb):
-        frame_sequences = []
-
-        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
-            i, j, vi, vj = (
-                torch.empty(self.nb_birds, dtype=torch.int64),
-                torch.empty(self.nb_birds, dtype=torch.int64),
-                torch.empty(self.nb_birds, dtype=torch.int64),
-                torch.empty(self.nb_birds, dtype=torch.int64),
-            )
-
-            # A configuration is valid if no grid cell is covered by
-            # more than one bird (head cell plus the two tail cells).
-            def collision_okay():
-                if not self.avoid_collision:
-                    return True
-
-                count = torch.zeros(self.height, self.width, dtype=torch.int64)
-
-                for n in range(self.nb_birds):
-                    count[i[n], j[n]] += 1
-                    count[i[n] - vi[n], j[n]] += 1
-                    count[i[n], j[n] - vj[n]] += 1
-
-                return count.max() <= 1
-
-            col = (
-                torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
-                + 1
-            )
-
-            while True:
-                while True:
-                    for n in range(self.nb_birds):
-                        while True:
-                            i[n] = torch.randint(self.height, (1,))
-                            j[n] = torch.randint(self.width, (1,))
-                            # Pick one of the four diagonal velocities,
-                            # with both components in {-1, +1}.
-                            vm = torch.randint(4, (1,))
-                            vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
-                            if (
-                                i[n] - vi[n] >= 0
-                                and i[n] - vi[n] < self.height
-                                and j[n] - vj[n] >= 0
-                                and j[n] - vj[n] < self.width
-                            ):
-                                break
-
-                    if collision_okay():
-                        break
-
-                result = torch.zeros(
-                    self.nb_iterations * self.speed,
-                    self.height,
-                    self.width,
-                    dtype=torch.int64,
-                )
-
-                fine = torch.empty(self.nb_iterations * self.speed)
-
-                # Keep only every speed-th frame of the simulation.
-                t_to_keep = (
-                    torch.arange(self.nb_iterations, device=result.device) * self.speed
-                )
-
-                for l in range(self.nb_iterations * self.speed):
-                    fine[l] = collision_okay()
-                    for n in range(self.nb_birds):
-                        c = col[n]
-                        result[l, i[n], j[n]] = c
-                        result[l, i[n] - vi[n], j[n]] = c
-                        result[l, i[n], j[n] - vj[n]] = c
-
-                        if (i[n] == 0 and vi[n] == -1) or (
-                            i[n] == self.height - 1 and vi[n] == 1
-                        ):
-                            vi[n] = -vi[n]
-
-                        if (j[n] == 0 and vj[n] == -1) or (
-                            j[n] == self.width - 1 and vj[n] == 1
-                        ):
-                            vj[n] = -vj[n]
-
-                        i[n] += vi[n]
-                        j[n] += vj[n]
-
-                result = result[t_to_keep]
-                fine = fine[t_to_keep]
-
-                if fine[-1]:
-                    break
-
-            frame_sequences.append(result)
-
-        return frame_sequences
-
-    ######################################################################
-
-    def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, -1)
-        m = torch.logical_and(
-            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
-        ).long()
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
-
-        for n in range(m.size(0)):
-            for i in range(m.size(1)):
-                for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
-                        for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
-
-        return x
-
-    def seq2str(self, seq):
-        result = []
-        for s in seq:
-            result.append("".join([self.token2char[v] for v in s]))
-        return result
-
-    def save_image(
-        self,
-        result_dir,
-        filename,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        if predicted_prompts is None:
-            predicted_prompts = 255
-
-        if predicted_answers is None:
-            predicted_answers = 255
-
-        def add_frame(x, c, margin, bottom=False):
-            if bottom:
-                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
-            else:
-                h, w, di, dj = (
-                    x.size(2) + 2 * margin,
-                    x.size(3) + 2 * margin,
-                    margin,
-                    margin,
-                )
-
-            y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
-            if type(c) is int:
-                y[...] = c
-            else:
-                c = c.long()[:, None]
-                c = (
-                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
-                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
-                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
-                )
-                y[...] = c[:, :, None, None]
-
-            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
-            return y
-
-        margin = 4
-
-        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
-        h = img_prompts.size(2)
-        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
-        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
-        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
-        img_prompts = add_frame(
-            img_prompts, c=predicted_prompts, margin=margin, bottom=True
-        )
-        img_answers = add_frame(
-            img_answers, c=predicted_answers, margin=margin, bottom=True
-        )
-
-        marker_size = 16
-
-        separator = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                marker_size,
-            ),
-            255,
-        )
-
-        separator[:, :, 0] = 0
-        separator[:, :, h - 1] = 0
-
-        for k in range(1, 2 * marker_size - 8):
-            i = k - (marker_size - 4)
-            j = marker_size - 5 - abs(i)
-            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
-            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
-
-        img = torch.cat([img_prompts, separator, img_answers], dim=3)
-
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
-        )
-
-    ######################################################################
-
-    def nb_token_values(self):
-        return len(self.colors)
-
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-
-        # warnings.warn("dirty test with longer answer", RuntimeWarning)
-        # answers = torch.cat(
-        # [
-        # frame_sequences[:, frame_sequences.size(1) // 2 :],
-        # frame_sequences[:, frame_sequences.size(1) // 2 :],
-        # ],
-        # dim=3,
-        # ).flatten(1)
-
-        return prompts, answers
-
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        self.save_image(
-            result_dir,
-            filename_prefix + ".png",
-            prompts,
-            answers,
-            predicted_prompts,
-            predicted_answers,
-        )
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import time
-
-    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
-
-    prompts, answers = sky.generate_prompts_and_answers(4)
-
-    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
-    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
-
-    sky.save_quiz_illustrations(
-        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
-    )
-
-    # start_time = time.perf_counter()
-    # token_sequences = sky.generate_token_sequences(nb=64)
-    # delay = time.perf_counter() - start_time
-    # print(f"{token_sequences.size(0)/delay:02f} seq/s")
-
-    # print(sky.seq2str(seq[:4]))
-
-    # for t in range(len(it[0])):
-    # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
-    # torchvision.utils.save_image(
-    # img.float() / 255.0,
-    # f"/tmp/frame_{t:03d}.png",
-    # nrow=8,
-    # padding=6,
-    # pad_value=0,
-    # )
-
-    # m = (torch.rand(seq.size()) < 0.05).long()
-    # seq = (1 - m) * seq + m * 23
-
-    # print(seq.size())
-    # img = sky.seq2img(token_sequences)
-    # print(img.size())
-
-    # torchvision.utils.save_image(
-    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
-    # )
diff --git a/wireworld.py b/wireworld.py
deleted file mode 100755 (executable)
index 8257cad..0000000
+++ /dev/null
@@ -1,357 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Wireworld(problem.Problem):
-    colors = torch.tensor(
-        [
-            [128, 128, 128],
-            [128, 128, 255],
-            [255, 0, 0],
-            [255, 255, 0],
-        ]
-    )
-
-    token_empty = 0
-    token_head = 1
-    token_tail = 2
-    token_conductor = 3
-    token_forward = 4
-    token_backward = 5
-
-    token2char = (
-        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
-    )
-
-    def __init__(
-        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
-    ):
-        self.height = height
-        self.width = width
-        self.nb_objects = nb_objects
-        self.nb_walls = nb_walls
-        self.speed = speed
-        self.nb_iterations = nb_iterations
-
-    def direction_tokens(self):
-        return self.token_forward, self.token_backward
-
-    def generate_frame_sequences(self, nb):
-        result = []
-        N = 100
-        for _ in tqdm.tqdm(
-            range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
-        ):
-            result.append(self.generate_frame_sequences_hard(N))
-        return torch.cat(result, dim=0)[:nb]
-
-    def generate_frame_sequences_hard(self, nb):
-        frame_sequences = []
-        nb_frames = (self.nb_iterations - 1) * self.speed + 1
-
-        result = torch.full(
-            (nb * 4, nb_frames, self.height, self.width),
-            self.token_empty,
-        )
-
-        for n in range(result.size(0)):
-            while True:
-                i = torch.randint(self.height, (1,))
-                j = torch.randint(self.width, (1,))
-                v = torch.randint(2, (2,))
-                vi = v[0] * (v[1] * 2 - 1)
-                vj = (1 - v[0]) * (v[1] * 2 - 1)
-                while True:
-                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
-                        break
-                    o = 0
-                    if i > 0:
-                        o += (result[n, 0, i - 1, j] == self.token_conductor).long()
-                    if i < self.height - 1:
-                        o += (result[n, 0, i + 1, j] == self.token_conductor).long()
-                    if j > 0:
-                        o += (result[n, 0, i, j - 1] == self.token_conductor).long()
-                    if j < self.width - 1:
-                        o += (result[n, 0, i, j + 1] == self.token_conductor).long()
-                    if o > 1:
-                        break
-                    result[n, 0, i, j] = self.token_conductor
-                    i += vi
-                    j += vj
-                if (
-                    result[n, 0] == self.token_conductor
-                ).long().sum() > self.width and torch.rand(1) < 0.5:
-                    break
-
-            while True:
-                for _ in range(self.height * self.width):
-                    i = torch.randint(self.height, (1,))
-                    j = torch.randint(self.width, (1,))
-                    v = torch.randint(2, (2,))
-                    vi = v[0] * (v[1] * 2 - 1)
-                    vj = (1 - v[0]) * (v[1] * 2 - 1)
-                    if (
-                        i + vi >= 0
-                        and i + vi < self.height
-                        and j + vj >= 0
-                        and j + vj < self.width
-                        and result[n, 0, i, j] == self.token_conductor
-                        and result[n, 0, i + vi, j + vj] == self.token_conductor
-                    ):
-                        result[n, 0, i, j] = self.token_head
-                        result[n, 0, i + vi, j + vj] = self.token_tail
-                        break
-
-                # if torch.rand(1) < 0.75:
-                break
-
-        weight = torch.full((1, 1, 3, 3), 1.0)
-
-        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
-        rand = torch.randint(4, mask.size())
-        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
-
-        # Wireworld update rule:
-        #   empty -> empty
-        #   head -> tail
-        #   tail -> conductor
-        #   conductor -> head if it has 1 or 2 heads in its neighborhood,
-        #                otherwise it remains a conductor
-
-        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
-        valid = nb_heads > 0
-
-        for l in range(nb_frames - 1):
-            nb_head_neighbors = (
-                F.conv2d(
-                    input=(result[:, l] == self.token_head).float()[:, None, :, :],
-                    weight=weight,
-                    padding=1,
-                )
-                .long()
-                .squeeze(1)
-            )
-            mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
-                nb_head_neighbors == 2
-            ).long()
-            result[:, l + 1] = (
-                (result[:, l] == self.token_empty).long() * self.token_empty
-                + (result[:, l] == self.token_head).long() * self.token_tail
-                + (result[:, l] == self.token_tail).long() * self.token_conductor
-                + (result[:, l] == self.token_conductor).long()
-                * (
-                    mask_1_or_2_heads * self.token_head
-                    + (1 - mask_1_or_2_heads) * self.token_conductor
-                )
-            )
-            prev_nb_heads = nb_heads
-            nb_heads = (
-                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
-            )
-            valid = torch.logical_and(valid, (nb_heads >= prev_nb_heads))
-
-        result = result[valid]
-
-        result = result[
-            :, torch.arange(self.nb_iterations, device=result.device) * self.speed
-        ]
-
-        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
-        result = result[i]
-
-        # print(f"{result.size(0)=} {nb=}")
-
-        if result.size(0) < nb:
-            # print(result.size(0))
-            result = torch.cat(
-                [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
-            )
-
-        return result[:nb]
-
-    def generate_token_sequences(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-
-        result = []
-
-        for frame_sequence in frame_sequences:
-            a = []
-            if torch.rand(1) < 0.5:
-                for frame in frame_sequence:
-                    if len(a) > 0:
-                        a.append(torch.tensor([self.token_forward]))
-                    a.append(frame.flatten())
-            else:
-                for frame in reversed(frame_sequence):
-                    if len(a) > 0:
-                        a.append(torch.tensor([self.token_backward]))
-                    a.append(frame.flatten())
-
-            result.append(torch.cat(a, dim=0)[None, :])
-
-        return torch.cat(result, dim=0)
-
-    ######################################################################
-
-    def frame2img(self, x, scale=15):
-        x = x.reshape(-1, self.height, self.width)
-        m = torch.logical_and(x >= 0, x < 4).long()
-
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
-
-        for n in range(m.size(0)):
-            for i in range(m.size(1)):
-                for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
-                        for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
-
-        return x
-
-    def seq2img(self, seq, scale=15):
-        panels = [
-            self.frame2img(
-                seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
-                scale,
-            )
-        ]
-
-        separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
-
-        t = self.height * self.width
-
-        while t < seq.size(1):
-            direction_tokens = seq[:, t]
-            t += 1
-
-            direction_images = self.colors[
-                torch.full(
-                    (direction_tokens.size(0), self.height * scale - 1, scale), 0
-                )
-            ].permute(0, 3, 1, 2)
-
-            for n in range(direction_tokens.size(0)):
-                if direction_tokens[n] == self.token_forward:
-                    for k in range(scale):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                3 + scale // 2 - abs(k - scale // 2),
-                            ] = 0
-                elif direction_tokens[n] == self.token_backward:
-                    for k in range(scale):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                3 + abs(k - scale // 2),
-                            ] = 0
-                else:
-                    for k in range(2, scale - 2):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                k,
-                            ] = 0
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                scale - 1 - k,
-                            ] = 0
-
-            panels += [
-                separator,
-                direction_images,
-                separator,
-                self.frame2img(
-                    seq[:, t : t + self.height * self.width].reshape(
-                        -1, self.height, self.width
-                    ),
-                    scale,
-                ),
-            ]
-
-            t += self.height * self.width
-
-        return torch.cat(panels, dim=3)
-
-    def seq2str(self, seq):
-        result = []
-        for s in seq:
-            result.append("".join([self.token2char[v] for v in s]))
-        return result
-
-    def save_image(self, input, result_dir, filename):
-        img = self.seq2img(input.to("cpu"))
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-
-    def save_quizzes(self, input, result_dir, filename_prefix):
-        self.save_image(input, result_dir, filename_prefix + ".png")
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import time
-
-    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
-
-    start_time = time.perf_counter()
-    frame_sequences = wireworld.generate_frame_sequences(nb=96)
-    delay = time.perf_counter() - start_time
-    print(f"{frame_sequences.size(0)/delay:02f} seq/s")
-
-    # print(wireworld.seq2str(seq[:4]))
-
-    for t in range(frame_sequences.size(1)):
-        img = wireworld.seq2img(frame_sequences[:, t])
-        torchvision.utils.save_image(
-            img.float() / 255.0,
-            f"/tmp/frame_{t:03d}.png",
-            nrow=8,
-            padding=6,
-            pad_value=0,
-        )
-
-    # m = (torch.rand(seq.size()) < 0.05).long()
-    # seq = (1 - m) * seq + m * 23
-
-    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
-    token_sequences = wireworld.generate_token_sequences(32)
-    wireworld.save_quizzes(token_sequences, "/tmp", "seq")
-    # img = wireworld.seq2img(frame_sequences[:60])
-
-    # torchvision.utils.save_image(
-    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
-    # )