Merge branch 'dev'
authorFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 08:48:14 +0000 (11:48 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 9 Jul 2024 08:48:14 +0000 (11:48 +0300)
Seems to work. Maybe the generated quizzes are a bit unstructured.

grids.py [new file with mode: 0755]
main.py
problem.py
quiz_machine.py [new file with mode: 0755]
quizz_machine.py [deleted file]
sky.py

diff --git a/grids.py b/grids.py
new file mode 100755 (executable)
index 0000000..47e5861
--- /dev/null
+++ b/grids.py
@@ -0,0 +1,852 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, sys, tqdm, os, warnings
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+import problem
+
+
class Grids(problem.Problem):
    """Generator of small colored-grid quizzes (A, f(A), B -> f(B)).

    Each quiz consists of four height x width grids of color tokens: an
    example pair (A, f_A) demonstrating a transformation, a new input B,
    and the expected answer f_B.  The task_* methods below each fill
    such grids in place for one specific transformation.
    """

    # (name, RGB) pairs; index 0 ("white") doubles as the background /
    # empty-cell token, so the task code draws with indices >= 1.
    named_colors = [
        ("white", [255, 255, 255]),
        ("red", [255, 0, 0]),
        ("green", [0, 192, 0]),
        ("blue", [0, 0, 255]),
        ("yellow", [255, 224, 0]),
        ("cyan", [0, 255, 255]),
        ("violet", [224, 128, 255]),
        ("lightgreen", [192, 255, 192]),
        ("brown", [165, 42, 42]),
        ("lightblue", [192, 192, 255]),
        ("gray", [128, 128, 128]),
    ]

    def __init__(self, device=torch.device("cpu")):
        # Color lookup table, shape (nb_colors, 3), indexed by token value.
        self.colors = torch.tensor([c for _, c in self.named_colors])
        self.height = 10
        self.width = 10
        # Used by rec_coo for its vectorized rejection sampling.
        self.device = device
+
+    ######################################################################
+
    def frame2img(self, x, scale=15):
        """Render a batch of token grids as RGB images.

        x: (N, height*width) tensor of token values.  Returns an
        (N, 3, height*scale - 1, width*scale - 1) image tensor where each
        cell is a scale x scale color patch separated by black grid
        lines; cells holding out-of-range tokens are overdrawn with a
        black X on the background color.
        """
        x = x.reshape(x.size(0), self.height, -1)
        # m == 1 for valid tokens, 0 for out-of-range ones; invalid
        # tokens are remapped to color 0 before the palette lookup.
        m = torch.logical_and(x >= 0, x < self.nb_token_values()).long()
        x = self.colors[x * m].permute(0, 3, 1, 2)
        s = x.shape
        # Upscale every cell into a scale x scale patch.
        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)

        # Black separation lines between cells, then trim the outer edge.
        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
        x = x[:, :, 1:, 1:]

        # Draw a 2-pixel-wide X across every invalid cell.
        for n in range(m.size(0)):
            for i in range(m.size(1)):
                for j in range(m.size(2)):
                    if m[n, i, j] == 0:
                        for k in range(2, scale - 2):
                            for l in [0, 1]:
                                x[n, :, i * scale + k, j * scale + k - l] = 0
                                x[
                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
                                ] = 0

        return x
+
+    def save_image(
+        self,
+        result_dir,
+        filename,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+        nrow=4,
+        margin=8,
+    ):
+        S = self.height * self.width
+        As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width)
+        f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view(
+            -1, self.height, self.width
+        )
+        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width)
+        prompts = torch.cat([As, f_As, Bs], dim=2)
+        answers = answers.reshape(answers.size(0), self.height, self.width)
+
+        if predicted_prompts is None:
+            predicted_prompts = 255
+
+        if predicted_answers is None:
+            predicted_answers = 255
+
+        def add_frame(x, c, margin, bottom=False):
+            if bottom:
+                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
+            else:
+                h, w, di, dj = (
+                    x.size(2) + 2 * margin,
+                    x.size(3) + 2 * margin,
+                    margin,
+                    margin,
+                )
+
+            y = x.new_full((x.size(0), x.size(1), h, w), 0)
+
+            if type(c) is int:
+                y[...] = c
+            else:
+                c = c.long()[:, None]
+                c = (
+                    (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
+                    * torch.tensor([64, 64, 64], device=c.device)
+                    + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
+                y[...] = c[:, :, None, None]
+
+            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+
+            return y
+
+        img_prompts = torch.cat(
+            [
+                add_frame(
+                    add_frame(self.frame2img(x), c=0, margin=1),
+                    c=predicted_prompts,
+                    margin=margin,
+                )
+                for x in prompts.to("cpu").split(split_size=self.width, dim=2)
+            ],
+            dim=3,
+        )
+
+        h = img_prompts.size(2)
+        img_answers = add_frame(
+            add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1),
+            c=predicted_answers,
+            margin=margin,
+        )
+
+        separator_size = 2 * margin
+
+        separator = img_prompts.new_full(
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                separator_size,
+            ),
+            255,
+        )
+
+        marker = img_prompts.new_full(
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                separator_size,
+            ),
+            255,
+        )
+
+        # marker[:, :, 0] = 0
+        # marker[:, :, h - 1] = 0
+
+        for k in range(1, 2 * separator_size - 8):
+            i = k - (separator_size - 4)
+            j = separator_size - 5 - abs(i)
+            marker[:, :, h // 2 - 1 + i, 2 + j] = 0
+            marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+        img = torch.cat(
+            [
+                img_prompts,
+                marker,
+                img_answers,
+            ],
+            dim=3,
+        )
+
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(
+            img.float() / 255.0,
+            image_name,
+            nrow=nrow,
+            padding=margin * 4,
+            pad_value=1.0,
+        )
+
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    # @torch.compile
    def rec_coo_(self, nb_rec, min_height=3, min_width=3):
        """Rejection-sample nb_rec pairwise non-overlapping rectangles.

        NOTE(review): this definition is shadowed by the later method of
        the same name further down the class, and it only handles
        nb_rec == 3 — for any other value it falls through and returns
        None.  It looks like dead code kept for reference; confirm
        before removing.
        """
        # @torch.compile
        def overlap(ia, ja, ib, jb):
            # Closed-interval overlap test on both axes.
            return (
                ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1]
            )

        if nb_rec == 3:
            while True:
                i = torch.randint(self.height + 1, (nb_rec, 2)).sort(dim=1).values
                j = torch.randint(self.width + 1, (nb_rec, 2)).sort(dim=1).values
                if (
                    not (
                        overlap(i[0], j[0], i[1], j[1])
                        or overlap(i[0], j[0], i[2], j[2])
                        or overlap(i[1], j[1], i[2], j[2])
                    )
                    and (i[:, 1] - i[:, 0]).min() >= min_height
                    and (j[:, 1] - j[:, 0]).min() >= min_width
                ):
                    break
            return (
                (i[0, 0], j[0, 0], i[0, 1], j[0, 1]),
                (i[1, 0], j[1, 0], i[1, 1], j[1, 1]),
                (i[2, 0], j[2, 0], i[2, 1], j[2, 1]),
            )
+
+    # That's quite a tensorial spaghetti mess to sample
+    # non-overlapping rectangles quickly, but made the generation of
+    # 100k samples go from 1h50 with a lame pure python code to 3min30s
+    # with this one.
+    # @torch.compile
    def rec_coo(self, nb_rec, min_height=3, min_width=3):
        """Sample nb_rec pairwise non-overlapping rectangles, vectorized.

        Returns a list of nb_rec tuples (i1, j1, i2, j2), with i2/j2
        exclusive.  Draws nb_trials groups of candidates at once on
        self.device and keeps the first group whose rectangles do not
        overlap; loops until one such group exists.
        """
        nb_trials = 200

        while True:
            # v[r, :] is a 0/1 indicator of a row interval: two random
            # cut points (the two smallest of height+1 random keys)
            # bound the segment of ones produced by the cumsum trick.
            v = (
                (
                    torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
                    .sort(dim=-1)
                    .indices
                    < 2
                )
                .long()
                .cumsum(dim=1)
                == 1
            ).long()

            # Same construction for column intervals.
            h = (
                (
                    torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
                    .sort(dim=-1)
                    .indices
                    < 2
                )
                .long()
                .cumsum(dim=1)
                == 1
            ).long()

            # Keep only candidates meeting the minimum size.
            i = torch.logical_and(
                v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
            )

            v, h = v[i], h[i]
            # Group the survivors by nb_rec.
            v = v[: v.size(0) - v.size(0) % nb_rec]
            h = h[: h.size(0) - h.size(0) % nb_rec]
            v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
            h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)

            # r[g, n, i, j] == 1 iff cell (i, j) lies in rectangle n of
            # group g; a group is valid when no cell is covered twice.
            r = v[:, :, :, None] * h[:, :, None, :]

            valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1

            v = v[valid]
            h = h[valid]

            if v.size(0) > 0:
                break

        av = torch.arange(v.size(2), device=self.device)[None, :]
        ah = torch.arange(h.size(2), device=self.device)[None, :]

        # Recover the interval endpoints of the first valid group from
        # the indicator masks.
        return [
            (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
            for i1, j1, i2, j2 in zip(
                v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
                h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
                (v[0] * av).max(dim=-1).values,
                (h[0] * ah).max(dim=-1).values,
            )
        ]
+
+    # @torch.compile
    def rec_coo_(self, x, n, min_height=3, min_width=3):
        """Pure-python rejection sampler for n non-overlapping rectangles in x.

        NOTE(review): this redefinition shadows the earlier rec_coo_
        (which has a different signature); neither appears to be called —
        the task methods all use rec_coo.  Presumably kept for
        reference/benchmarking; confirm before removing.
        """
        collision = x.new(x.size())
        while True:
            collision[...] = 0
            result = []
            for _ in range(n):
                # Draw corner pairs until the rectangle is tall enough...
                while True:
                    i1, i2 = torch.randint(x.size(0), (2,))
                    if i1 + min_height <= i2:
                        break
                # ...and wide enough.
                while True:
                    j1, j2 = torch.randint(x.size(1), (2,))
                    if j1 + min_width <= j2:
                        break
                collision[i1:i2, j1:j2] += 1
                # Restart the whole group on the first overlap.
                if collision.max() > 1:
                    break
                result.append((i1, j1, i2, j2))
            if collision.max() == 1:
                break
        return result
+
+    ######################################################################
+
+    # @torch.compile
+    def task_replace_color(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+
+    # @torch.compile
+    def task_translate(self, A, f_A, B, f_B):
+        di, dj = torch.randint(3, (2,)) - 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
+                if (
+                    i1 + di >= 0
+                    and i2 + di < X.size(0)
+                    and j1 + dj >= 0
+                    and j2 + dj < X.size(1)
+                ):
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                if n == nb_rec - 1:
+                    f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
+
+    # @torch.compile
+    def task_grow(self, A, f_A, B, f_B):
+        di, dj = torch.randint(2, (2,)) * 2 - 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
+        direction = torch.randint(2, (1,))
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                r = self.rec_coo(nb_rec)
+                i1, j1, i2, j2 = r[nb_rec - 1]
+                if i1 + 3 < i2 and j1 + 3 < j2:
+                    break
+
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                if n == nb_rec - 1:
+                    if direction == 0:
+                        X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+                        f_X[i1:i2, j1:j2] = c[n]
+                    else:
+                        X[i1:i2, j1:j2] = c[n]
+                        f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n]
+                else:
+                    X[i1:i2, j1:j2] = c[n]
+                    f_X[i1:i2, j1:j2] = c[n]
+
+    # @torch.compile
+    def task_color_grow(self, A, f_A, B, f_B):
+        di, dj = torch.randint(2, (2,)) * 2 - 1
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
+        direction = torch.randint(4, (1,))
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[2 * n]
+                f_X[i1:i2, j1:j2] = c[2 * n]
+                # Not my proudest moment
+                if direction == 0:
+                    i = (i1 + i2) // 2
+                    X[i : i + 1, j1:j2] = c[2 * n + 1]
+                    if n == nb_rec - 1:
+                        f_X[i:i2, j1:j2] = c[2 * n + 1]
+                    else:
+                        f_X[i : i + 1, j1:j2] = c[2 * n + 1]
+                elif direction == 1:
+                    i = (i1 + i2 - 1) // 2
+                    X[i : i + 1, j1:j2] = c[2 * n + 1]
+                    if n == nb_rec - 1:
+                        f_X[i1 : i + 1, j1:j2] = c[2 * n + 1]
+                    else:
+                        f_X[i : i + 1, j1:j2] = c[2 * n + 1]
+                elif direction == 2:
+                    j = (j1 + j2) // 2
+                    X[i1:i2, j : j + 1] = c[2 * n + 1]
+                    if n == nb_rec - 1:
+                        f_X[i1:i2, j:j2] = c[2 * n + 1]
+                    else:
+                        f_X[i1:i2, j : j + 1] = c[2 * n + 1]
+                elif direction == 3:
+                    j = (j1 + j2 - 1) // 2
+                    X[i1:i2, j : j + 1] = c[2 * n + 1]
+                    if n == nb_rec - 1:
+                        f_X[i1:i2, j1 : j + 1] = c[2 * n + 1]
+                    else:
+                        f_X[i1:i2, j : j + 1] = c[2 * n + 1]
+
+    # @torch.compile
+    def task_frame(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                f_X[i1:i2, j1:j2] = c[n]
+                if n == nb_rec - 1:
+                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+
+    # @torch.compile
+    def task_detect(self, A, f_A, B, f_B):
+        nb_rec = 3
+        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            r = self.rec_coo(nb_rec)
+            for n in range(nb_rec):
+                i1, j1, i2, j2 = r[n]
+                X[i1:i2, j1:j2] = c[n]
+                if n < nb_rec - 1:
+                    f_X[i1, j1] = c[-1]
+
+    # @torch.compile
+    def contact(self, X, i, j, q):
+        nq, nq_diag = 0, 0
+        no = 0
+
+        for ii, jj in [
+            (i - 1, j - 1),
+            (i - 1, j),
+            (i - 1, j + 1),
+            (i, j - 1),
+            (i, j + 1),
+            (i + 1, j - 1),
+            (i + 1, j),
+            (i + 1, j + 1),
+        ]:
+            if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                if X[ii, jj] != 0 and X[ii, jj] != q:
+                    no += 1
+
+        for ii, jj in [
+            (i - 1, j - 1),
+            (i - 1, j + 1),
+            (i + 1, j - 1),
+            (i + 1, j + 1),
+        ]:
+            if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q:
+                    nq_diag += 1
+
+        for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]:
+            if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width:
+                if X[ii, jj] == q:
+                    nq += 1
+
+        return no, nq, nq_diag
+
+    # @torch.compile
    def task_count(self, A, f_A, B, f_B):
        """Quiz: scatter blobs of N colors; the answer packs, on row n,
        one cell per blob-start of color n (a counting task)."""
        # N colors, uniformly in {2, 3, 4, 5}.
        N = (torch.randint(4, (1,)) + 2).item()
        c = torch.randperm(len(self.colors) - 1)[:N] + 1

        for X, f_X in [(A, f_A), (B, f_B)]:
            # nb[n] counts the blob starts of color n placed so far.
            nb = torch.zeros(N, dtype=torch.int64)
            q = torch.randint(N, (self.height * self.width,))
            # Visit every cell once, in random order.
            k = torch.randperm(self.height * self.width)
            for p in range(self.height * self.width):
                i, j = k[p] % self.height, k[p] // self.height
                no, nq, nq_diag = self.contact(X, i, j, c[q[p]])
                # Only place cells that neither touch another color nor
                # connect diagonally to their own color.
                if no == 0 and nq_diag == 0:
                    if nq == 0:
                        # Start a new blob, capped at self.width per color
                        # so the counts fit on one answer row.
                        if nb[q[p]] < self.width:
                            X[i, j] = c[q[p]]
                            nb[q[p]] += 1
                    if nq == 1:
                        # Extend an existing blob without counting it.
                        X[i, j] = c[q[p]]

            # Answer: row n holds nb[n] cells of color n.
            for n in range(N):
                for j in range(nb[n]):
                    f_X[n, j] = c[n]
+
+    # @torch.compile
+    def task_trajectory(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                di, dj = torch.randint(7, (2,)) - 3
+                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+                if (
+                    abs(di) + abs(dj) > 0
+                    and i + 2 * di >= 0
+                    and i + 2 * di < self.height
+                    and j + 2 * dj >= 0
+                    and j + 2 * dj < self.width
+                ):
+                    break
+
+            k = 0
+            while (
+                i + k * di >= 0
+                and i + k * di < self.height
+                and j + k * dj >= 0
+                and j + k * dj < self.width
+            ):
+                if k < 2:
+                    X[i + k * di, j + k * dj] = c[k]
+                f_X[i + k * di, j + k * dj] = c[min(k, 1)]
+                k += 1
+
+    # @torch.compile
    def task_bounce(self, A, f_A, B, f_B):
        """Quiz: a ball moves among obstacles, turning when blocked; the
        input shows obstacles and the endpoints, the answer adds the
        full trajectory."""
        c = torch.randperm(len(self.colors) - 1)[:3] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            # @torch.compile
            def free(i, j):
                # Cell is inside the grid and not yet written in f_X.
                return (
                    i >= 0
                    and i < self.height
                    and j >= 0
                    and j < self.width
                    and f_X[i, j] == 0
                )

            # Retry whole boards until the trajectory is long enough.
            while True:
                f_X[...] = 0
                X[...] = 0

                # Scatter obstacles (first palette color).
                for _ in range((self.height * self.width) // 10):
                    i, j = torch.randint(self.height, (1,)), torch.randint(
                        self.width, (1,)
                    )
                    X[i, j] = c[0]
                    f_X[i, j] = c[0]

                # Unit step along a single axis.
                while True:
                    di, dj = torch.randint(7, (2,)) - 3
                    if abs(di) + abs(dj) == 1:
                        break

                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))

                X[i, j] = c[1]
                f_X[i, j] = c[1]
                l = 0

                while True:
                    l += 1
                    if free(i + di, j + dj):
                        # Straight ahead is free: keep the direction.
                        pass
                    elif free(i - dj, j + di):
                        # Blocked: turn to one side; if the other side is
                        # also free, pick one of the two at random.
                        di, dj = -dj, di
                        if free(i + dj, j - di):
                            if torch.rand(1) < 0.5:
                                di, dj = -di, -dj
                    elif free(i + dj, j - di):
                        di, dj = dj, -di
                    else:
                        # Dead end: stop the trajectory here.
                        break

                    i, j = i + di, j + dj
                    f_X[i, j] = c[2]
                    # The input only shows the first trajectory cell.
                    if l <= 1:
                        X[i, j] = c[2]

                    if l >= self.width:
                        break

                # Mark the final position in both grids.
                f_X[i, j] = c[1]
                X[i, j] = c[1]

                if l > 3:
                    break
+
+    # @torch.compile
+    def task_scale(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+
+        i, j = torch.randint(self.height // 2, (1,)), torch.randint(
+            self.width // 2, (1,)
+        )
+
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            for _ in range(3):
+                while True:
+                    i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+                        self.width // 2 + 1, (1,)
+                    )
+                    i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+                        self.width // 2 + 1, (1,)
+                    )
+                    if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
+                        break
+                X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
+                f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
+
+            X[i, j] = c[1]
+            f_X[0:2, 0:2] = c[1]
+
+    # @torch.compile
    def task_symbols(self, A, f_A, B, f_B):
        """Quiz: a corner-dot symbol plus satellite squares; in the
        answer, the symbol's square takes the color of the satellite the
        symbol's extra dot points toward."""
        nb_rec = 4
        c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
        delta = 3
        for X, f_X in [(A, f_A), (B, f_B)]:
            # Sample nb_rec delta x delta square anchors that are pairwise
            # more than delta apart (Chebyshev distance), so no squares touch.
            while True:
                i, j = torch.randint(self.height - delta + 1, (nb_rec,)), torch.randint(
                    self.width - delta + 1, (nb_rec,)
                )
                d = (i[None, :] - i[:, None]).abs().max((j[None, :] - j[:, None]).abs())
                d.fill_diagonal_(delta + 1)
                if d.min() > delta:
                    break

            # Satellites: filled squares at anchors 1..nb_rec-1.
            for k in range(1, nb_rec):
                X[i[k] : i[k] + delta, j[k] : j[k] + delta] = c[k]

            # Centroid of the anchors, used to orient the pointer dot.
            ai, aj = i.float().mean(), j.float().mean()

            # The satellite the symbol will point toward.
            q = torch.randint(3, (1,)) + 1

            # Four corner dots of the symbol at anchor 0.
            X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
            X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
            X[i[0] + delta // 2 + 1, j[0] + delta // 2 - 1] = c[0]
            X[i[0] + delta // 2 + 1, j[0] + delta // 2 + 1] = c[0]

            # The pointer direction must be unambiguous on both axes.
            assert i[q] != ai and j[q] != aj

            # Pointer dot, offset one cell toward satellite q.
            X[
                i[0] + delta // 2 + (i[q] - ai).sign().long(),
                j[0] + delta // 2 + (j[q] - aj).sign().long(),
            ] = c[nb_rec]

            # Answer: the symbol's square filled with satellite q's color.
            f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
+
+    # @torch.compile
    def task_ortho(self, A, f_A, B, f_B):
        """Quiz: the whole scene is mapped by a random orthogonal
        transform (a multiple of 90° rotation, possibly mirrored) about
        the grid center, plus a small translation (di, dj)."""
        nb_rec = 3
        di, dj = torch.randint(3, (2,)) - 1
        # o is a 90° rotation matrix; compose 0..3 quarter turns, then
        # possibly reflect (negate the first row).
        o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
        m = torch.eye(2)
        for _ in range(torch.randint(4, (1,))):
            m = m @ o
        if torch.rand(1) < 0.5:
            m[0, :] = -m[0, :]

        # Center of the grid, the fixed point of the transform.
        ci, cj = (self.height - 1) / 2, (self.width - 1) / 2

        for X, f_X in [(A, f_A), (B, f_B)]:
            # Retry until the scene is big enough and all colors survive.
            while True:
                X[...] = 0
                f_X[...] = 0

                c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1

                for r in range(nb_rec):
                    while True:
                        i1, i2 = torch.randint(self.height - 2, (2,)) + 1
                        j1, j2 = torch.randint(self.width - 2, (2,)) + 1
                        if (
                            i2 >= i1
                            and j2 >= j1
                            and max(i2 - i1, j2 - j1) >= 2
                            and min(i2 - i1, j2 - j1) <= 3
                        ):
                            break
                    X[i1 : i2 + 1, j1 : j2 + 1] = c[r]

                    # Map both corners: center, rotate/reflect, un-center,
                    # then translate.
                    i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj

                    i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1
                    i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2

                    i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj
                    i1, i2 = i1.long() + di, i2.long() + di
                    j1, j2 = j1.long() + dj, j2.long() + dj
                    # Re-order the corners after the transform.
                    if i1 > i2:
                        i1, i2 = i2, i1
                    if j1 > j2:
                        j1, j2 = j2, j1

                    f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r]

                # Accept only scenes covering more than a quarter of the
                # grid with all nb_rec colors present in the input.
                n = F.one_hot(X.flatten()).sum(dim=0)[1:]
                if (
                    n.sum() > self.height * self.width // 4
                    and (n > 0).long().sum() == nb_rec
                ):
                    break
+
+    # @torch.compile
    def task_islands(self, A, f_A, B, f_B):
        # TODO: not implemented yet — deliberately excluded from
        # all_tasks().  The commented sketch below is a starting point.
        pass

    # for X, f_X in [(A, f_A), (B, f_B)]:
    # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
    # k = torch.randperm(self.height * self.width)
    # X[...]=-1
    # for q in k:
    # i,j=q%self.height,q//self.height
    # if
+
+    ######################################################################
+
+    def all_tasks(self):
+        return [
+            self.task_replace_color,
+            self.task_translate,
+            self.task_grow,
+            self.task_color_grow,
+            self.task_frame,
+            self.task_detect,
+            self.task_count,
+            self.task_trajectory,
+            self.task_bounce,
+            self.task_scale,
+            self.task_symbols,
+            self.task_ortho,
+            # self.task_islands,
+        ]
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        S = self.height * self.width
+        Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S]
+        f_Bs = answers
+        return (Bs == f_Bs).long().min(dim=-1).values > 0
+
    def generate_prompts_and_answers(
        self, nb, tasks=None, progress_bar=False, device="cpu"
    ):
        """Generate nb quizzes.

        Returns (prompts, answers): prompts is an (nb, 3*S + 2) int64
        tensor holding the flattened A, f_A and B grids, each followed
        by a zero separator token (S = height*width); answers is (nb, S)
        holding the flattened f_B grids.  tasks defaults to all_tasks();
        each quiz uses one task drawn uniformly at random.
        NOTE(review): the `device` argument is currently unused —
        generation always happens on CPU tensors.
        """
        if tasks is None:
            tasks = self.all_tasks()

        S = self.height * self.width
        prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
        answers = torch.zeros(nb, S, dtype=torch.int64)

        bunch = zip(prompts, answers)

        if progress_bar:
            bunch = tqdm.tqdm(
                bunch,
                dynamic_ncols=True,
                desc="world generation",
                total=prompts.size(0),
            )

        for prompt, answer in bunch:
            # Grid-shaped views into the flat rows: the task functions
            # fill them in place.
            A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width)
            f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
            B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
            f_B = answer.view(self.height, self.width)
            task = tasks[torch.randint(len(tasks), (1,))]
            task(A, f_A, B, f_B)

        return prompts.flatten(1), answers.flatten(1)
+
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+        nrow=4,
+    ):
+        self.save_image(
+            result_dir,
+            filename_prefix + ".png",
+            prompts,
+            answers,
+            predicted_prompts,
+            predicted_answers,
+            nrow,
+        )
+
+
+######################################################################
+
if __name__ == "__main__":
    # Ad-hoc demo / benchmark driver; the first enabled branch exits.
    import time

    grids = Grids()

    # nb = 1000
    # grids = problem.MultiThreadProblem(
    # grids, max_nb_cached_chunks=50, chunk_size=100, nb_threads=1
    # )
    #    time.sleep(10)
    # start_time = time.perf_counter()
    # prompts, answers = grids.generate_prompts_and_answers(nb)
    # delay = time.perf_counter() - start_time
    # print(f"{prompts.size(0)/delay:02f} seq/s")
    # exit(0)

    if True:
        nb = 72

        # Render a sample sheet of quizzes for every task to /tmp.
        for t in grids.all_tasks():
            # for t in [grids.task_ortho]:
            print(t.__name__)
            prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
            grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)

        exit(0)

    nb = 500

    # Benchmark generation speed per task (unreachable while the branch
    # above is enabled).
    for t in grids.all_tasks():
        start_time = time.perf_counter()
        prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
        delay = time.perf_counter() - start_time
        print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")

    exit(0)

    # Demo of the colored prediction frames (also unreachable).
    m = torch.randint(2, (prompts.size(0),))
    predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
    predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)

    grids.save_quizzes(
        "/tmp",
        "test",
        prompts[:nb],
        answers[:nb],
        # You can add a bool to put a frame around the predicted parts
        predicted_prompts[:nb],
        predicted_answers[:nb],
    )
diff --git a/main.py b/main.py
index d412e6c..9c3d7f1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,18 +12,15 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
+
 import mygpt
-import sky, wireworld, quizz_machine
+import sky, grids, quiz_machine
+from problem import MultiThreadProblem
 
 # world quizzes vs. culture quizzes
 
 ######################################################################
 
-nb_new_c_quizzes_for_train = 1000
-nb_new_c_quizzes_for_test = 100
-
-######################################################################
-
 if torch.cuda.is_available():
     device = torch.device("cuda")
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -57,7 +54,7 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
-parser.add_argument("--learning_rate", type=float, default=1e-3)
+parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 ########################################
 
@@ -79,22 +76,28 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-parser.add_argument("--reverse_cleanup", action="store_true", default=True)
+parser.add_argument("--problem", type=str, default="grids")
 
-parser.add_argument("--validation_forward_only", action="store_true", default=False)
-
-parser.add_argument("--problem", type=str, default="sky")
+parser.add_argument("--multi_thread_problem", action="store_true", default=False)
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--min_to_validate", type=int, default=4)
+parser.add_argument("--min_to_validate", type=int, default=None)
 
-parser.add_argument("--max_to_validate", type=int, default=4)
+parser.add_argument("--max_to_validate", type=int, default=None)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
 
+parser.add_argument("--generation_temperature", type=float, default=2.0)
+
+parser.add_argument("--deterministic_validation", action="store_true", default=False)
+
+parser.add_argument("--bidirectional_validation", action="store_true", default=False)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
+######################################################################
+
 parser.add_argument("--sky_height", type=int, default=6)
 
 parser.add_argument("--sky_width", type=int, default=8)
@@ -109,15 +112,14 @@ parser.add_argument("--sky_speed", type=int, default=3)
 
 args = parser.parse_args()
 
-if args.result_dir is None:
-    args.result_dir = f"results_culture"
+if args.min_to_validate is None:
+    args.min_to_validate = args.nb_gpts - 1
 
-######################################################################
+if args.max_to_validate is None:
+    args.max_to_validate = args.nb_gpts - 1
 
-if args.dirty_debug:
-    args.accuracy_to_make_c_quizzes = 0.0
-    nb_new_c_quizzes_for_train = 100
-    nb_new_c_quizzes_for_test = 10
+if args.result_dir is None:
+    args.result_dir = f"results_culture"
 
 ######################################################################
 
@@ -239,15 +241,21 @@ if args.problem == "sky":
         nb_iterations=args.sky_nb_iterations,
         speed=args.sky_speed,
     )
-elif args.problem == "wireworld":
-    problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5)
+    back_accuracy = False
+elif args.problem == "grids":
+    problem = grids.Grids(device=device)
+    back_accuracy = True
 else:
     raise ValueError
 
-quizz_machine = quizz_machine.QuizzMachine(
+if args.multi_thread_problem:
+    problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000)
+
+quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
+    back_accuracy=back_accuracy,
     batch_size=args.physical_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
@@ -258,7 +266,7 @@ quizz_machine = quizz_machine.QuizzMachine(
 
 log_string(f"device {device}")
 
-vocabulary_size = quizz_machine.vocabulary_size()
+vocabulary_size = quiz_machine.vocabulary_size()
 
 log_string(f"vocabulary_size {vocabulary_size}")
 
@@ -267,8 +275,8 @@ log_string(f"vocabulary_size {vocabulary_size}")
 # Compute the entropy of the training tokens
 
 token_count = 0
-for input in quizz_machine.batches(split="train", desc="train-entropy"):
-    token_count += F.one_hot(input, num_classes=quizz_machine.vocabulary_size()).sum(
+for input in quiz_machine.batches(split="train", desc="train-entropy"):
+    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
         (0, 1)
     )
 token_probas = token_count / token_count.sum()
@@ -292,11 +300,11 @@ if args.max_percents_of_test_in_train >= 0:
 
     nb_test, nb_in_train = 0, 0
     for test_subset in subsets_as_tuples(
-        quizz_machine.batches(split="test", desc="test-check"), 25000
+        quiz_machine.batches(split="test", desc="test-check"), 25000
     ):
         in_train = set()
         for train_subset in subsets_as_tuples(
-            quizz_machine.batches(split="train", desc="train-check"), 25000
+            quiz_machine.batches(split="train", desc="train-check"), 25000
         ):
             in_train.update(test_subset.intersection(train_subset))
         nb_in_train += len(in_train)
@@ -313,14 +321,14 @@ if args.max_percents_of_test_in_train >= 0:
 ##############################
 
 
-def one_epoch(model, quizz_machine):
+def one_epoch(model, quiz_machine):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     model.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in quizz_machine.batches(split="train"):
+    for input in quiz_machine.batches(split="train"):
         input = input.to(device)
 
         if nb_train_samples % args.batch_size == 0:
@@ -345,14 +353,14 @@ def one_epoch(model, quizz_machine):
 ######################################################################
 
 
-def run_tests(model, quizz_machine, deterministic_synthesis):
+def run_tests(model, quiz_machine, deterministic_synthesis):
     with torch.autograd.no_grad():
         model.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        for input in quizz_machine.batches(split="test"):
+        for input in quiz_machine.batches(split="test"):
             input = input.to(device)
 
             bs = model(mygpt.BracketedSequence(input))
@@ -364,17 +372,17 @@ def run_tests(model, quizz_machine, deterministic_synthesis):
 
             nb_test_samples += input.size(0)
 
-        model.main_test_accuracy = quizz_machine.produce_results(
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+
+        model.main_test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
             model=model,
             result_dir=args.result_dir,
             deterministic_synthesis=deterministic_synthesis,
         )
 
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
-
 
 ######################################################################
 
@@ -389,11 +397,11 @@ def valid_c_quizzes(recorded, criteria):
 
 def create_c_quizzes(
     models,
-    quizz_machine,
+    quiz_machine,
     nb_for_train=1000,
     nb_for_test=100,
 ):
-    recorded = []
+    quizzes_and_nb_correct_records = []
 
     nb_to_create = nb_for_train + nb_for_test
 
@@ -403,41 +411,63 @@ def create_c_quizzes(
         nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
     )
 
-    while valid_c_quizzes(recorded, standard_validity).size(0) < nb_to_create:
-        model_for_generation = models[torch.randint(len(models), (1,))]
+    file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
 
-        c_quizzes, ave_seq_logproba = quizz_machine.generate_quizzes(
-            nb_to_create,
-            model_for_generation=model_for_generation,
-            reverse_cleanup=args.reverse_cleanup,
-        )
+    with open(file_name, "w") as logp_file:
+        while (
+            valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
+            < nb_to_create
+        ):
+            # Select a model at random to generate the new quizzes
 
-        nb_correct = quizz_machine.compute_correctness(
-            c_quizzes, models, both_directions=not args.validation_forward_only
-        )
+            model_for_generation = models[torch.randint(len(models), (1,))]
 
-        if args.dirty_debug:
-            nb_correct = torch.randint(
-                len(models) + 1, nb_correct.size(), device=c_quizzes.device
+            c_quizzes = quiz_machine.generate_quizzes(
+                nb_to_create,
+                model_for_generation=model_for_generation,
+                temperature=args.generation_temperature,
             )
 
-        recorded.append((c_quizzes, nb_correct))
+            c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
-        nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
-        nv = " ".join([str(x.item()) for x in nv])
+            if c_quizzes.size(0) > 0:
+                nb_correct, seq_logproba = quiz_machine.compute_correctness(
+                    c_quizzes,
+                    models,
+                    bidirectional_validation=args.bidirectional_validation,
+                    deterministic_validation=args.deterministic_validation,
+                )
 
-        nb_validated = valid_c_quizzes(recorded, standard_validity).size(0)
+                for n, l in zip(nb_correct, seq_logproba):
+                    s = " ".join([str(x.item()) for x in l])
+                    logp_file.write(f"{n} {s}\n")
 
-        log_string(
-            f"keep c_quizzes kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
-        )
+                if args.dirty_debug:
+                    nb_correct = torch.randint(
+                        len(models) + 1, nb_correct.size(), device=c_quizzes.device
+                    )
+
+                quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
+
+                nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
+                nv = " ".join([str(x.item()) for x in nv])
+
+                nb_validated = valid_c_quizzes(
+                    quizzes_and_nb_correct_records, standard_validity
+                ).size(0)
+
+                log_string(
+                    f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
+                )
 
     # store the new c_quizzes which have been validated
 
-    new_c_quizzes = valid_c_quizzes(recorded, standard_validity)
+    new_c_quizzes = valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity)
 
-    quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
-    quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+    quiz_machine.reverse_random_half_in_place(new_c_quizzes)
+
+    quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+    quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
     # save a bunch of images to investigate what quizzes with a
     # certain nb of correct predictions look like
@@ -449,13 +479,15 @@ def create_c_quizzes(
             else ""
         )
 
-        q = valid_c_quizzes(recorded, criteria=lambda nb_correct: nb_correct == n)[:72]
+        q = valid_c_quizzes(
+            quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n
+        )[:72]
+
+        quiz_machine.reverse_random_half_in_place(q)
 
         if q.size(0) > 0:
-            quizz_machine.problem.save_quizzes(
-                q,
-                args.result_dir,
-                f"culture_c_quiz_{n_epoch:04d}_N{n}{s}",
+            quiz_machine.save_quizzes(
+                args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q
             )
 
 
@@ -486,46 +518,69 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
+nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+
+log_string(
+    f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+)
+
+######################################################################
+
+if args.dirty_debug:
+    args.accuracy_to_make_c_quizzes = 0.0
+    args.nb_gpts = 2
+    nb_new_c_quizzes_for_train = 100
+    nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
 for n_epoch in range(args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
+    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+    log_string(f"current_test_accuracies {cta}")
+
+    ##################################################
+    # Select, improve, and eval the worst model
+
     weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
 
     log_string(
         f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
     )
 
-    # improve it
-    one_epoch(weakest_model, quizz_machine)
+    one_epoch(weakest_model, quiz_machine)
 
     log_string(
-        f"train_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
+        f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
     )
 
-    # test it
-    run_tests(weakest_model, quizz_machine, deterministic_synthesis=False)
+    run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
 
     log_string(
-        f"test_set_composition w_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c_quizzes}"
+        f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
     )
 
-    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
-    log_string(f"current_test_accuracies {cta}")
+    ##################################################
+    # Replace a fraction of the w_quizzes with fresh ones
+
+    quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
 
-    # replace a fraction of the w_quizzes with fresh ones
-    quizz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+    ##################################################
+    # If all the models are good enough, generate new quizzes and
+    # re-compute the test errors
 
     if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
         create_c_quizzes(
             models,
-            quizz_machine,
+            quiz_machine,
             nb_for_train=nb_new_c_quizzes_for_train,
             nb_for_test=nb_new_c_quizzes_for_test,
         )
 
-        # We update everyone
         for model in models:
-            run_tests(model, quizz_machine, deterministic_synthesis=False)
+            run_tests(model, quiz_machine, deterministic_synthesis=False)
 
 
 ######################################################################
index 0795de1..a49634d 100755 (executable)
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+import threading, queue, torch, tqdm
+
 
 class Problem:
-    # returns a nb x (L+1+L) long tensor where L is the length of one
-    # of the two states of a quizz
-    def generate_token_sequences(self, nb):
+    def nb_token_values(self):
         pass
 
-    # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(self, input, result_dir, filename_prefix):
+    def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
-    # returns a pair (forward_tokens, backward_token)
-    def direction_tokens(self):
+    # returns two tensors nb x D and nb x D'
+    def generate_prompts_and_answers(self, nb):
+        pass
+
+    # save a file to vizualize quizzes, you can save a txt or png file
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
         pass
+
+
+class MultiThreadProblem:
+    def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
+        self.problem = problem
+        self.chunk_size = chunk_size
+        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+        for _ in range(nb_threads):
+            threading.Thread(target=self.fill_cache, daemon=True).start()
+        self.rest = None
+
+    def nb_token_values(self):
+        return self.problem.nb_token_values()
+
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        self.problem.save_quizzes(
+            result_dir,
+            filename_prefix,
+            prompts,
+            answers,
+            predicted_prompts=predicted_prompts,
+            predicted_answers=predicted_answers,
+        )
+
+    def fill_cache(self):
+        while True:
+            prompts, answers = self.problem.generate_prompts_and_answers(
+                self.chunk_size
+            )
+
+            self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
+
+    def trivial_prompts_and_answers(self, prompts, answers):
+        return self.problem.trivial_prompts_and_answers(prompts, answers)
+
+    def generate_prompts_and_answers(self, nb):
+        if self.rest is not None:
+            prompts, answers = self.rest
+        else:
+            prompts, answers = [], []
+
+        self.rest = None
+
+        n = sum([p.size(0) for p in prompts])
+
+        with tqdm.tqdm(
+            total=nb,
+            dynamic_ncols=True,
+            desc="world generation",
+        ) as pbar:
+            while n < nb:
+                p, s = self.queue.get(block=True)
+                prompts.append(p)
+                answers.append(s)
+                n += p.size(0)
+                pbar.update(p.size(0))
+
+        prompts, answers = torch.cat(prompts, dim=0), torch.cat(answers, dim=0)
+        assert n == prompts.size(0)
+
+        k = n - nb
+
+        if k > 0:
+            self.rest = (prompts[-k:], answers[-k:])
+            prompts, answers = prompts[:-k], answers[:-k]
+
+        return prompts, answers
diff --git a/quiz_machine.py b/quiz_machine.py
new file mode 100755 (executable)
index 0000000..f0fb408
--- /dev/null
@@ -0,0 +1,541 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import math, os, tqdm, warnings
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+import mygpt
+from mygpt import BracketedSequence
+
+######################################################################
+
+# ar_mask is a tensor with 0s and 1s, of same shape as input, with
+# 1s where tokens should be generated. The others are kept
+# unchanged.
+
+
+def one_batch_masked_inplace_autoregression(
+    model,
+    input,
+    ar_mask,
+    seq_logproba,
+    temperature,
+    deterministic_synthesis,
+):
+    to_generate = (ar_mask.sum(0) > 0).nonzero()
+
+    if to_generate.min() > 0:
+        model(
+            BracketedSequence(input, 0, to_generate.min())
+        )  # Needed to initialize the model's cache
+    for s in range(to_generate.min(), to_generate.max() + 1):
+        output = model(BracketedSequence(input, s, 1)).x
+
+        logits = output[:, s]
+
+        logits = (logits / temperature).log_softmax(dim=-1)
+
+        if deterministic_synthesis:
+            t_next = logits.argmax(-1)
+        else:
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            t_next = dist.sample()
+
+        all_n = torch.arange(t_next.size(0))
+
+        seq_logproba += logits[all_n, t_next]
+
+        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+
+
+def masked_inplace_autoregression(
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    seq_logproba,
+    temperature,
+    deterministic_synthesis,
+    forbidden_tokens=None,
+    logit_biases=None,
+    progress_bar_desc=None,
+    device=torch.device("cpu"),
+):
+    assert input.size() == ar_mask.size()
+
+    batches = zip(
+        input.split(batch_size),
+        ar_mask.split(batch_size),
+        seq_logproba.split(batch_size),
+    )
+
+    if progress_bar_desc is not None:
+        batches = tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=(input.size(0) + batch_size - 1) // batch_size,
+        )
+
+    with torch.autograd.no_grad():
+        t = model.training
+        model.eval()
+
+        for input, ar_mask, seq_logproba in batches:
+            one_batch_masked_inplace_autoregression(
+                model=model,
+                input=input,
+                ar_mask=ar_mask,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=deterministic_synthesis,
+            )
+
+        model.train(t)
+
+
+######################################################################
+
+
+class QuizMachine:
+    def indices_forward_and_backward(self, quizzes):
+        i_forward = quizzes[:, 0] == self.token_forward
+        j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward
+        i_backward = quizzes[:, 0] == self.token_backward
+        j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward
+        assert torch.logical_or(
+            torch.logical_and(i_forward, j_forward),
+            torch.logical_and(i_backward, j_backward),
+        ).all()
+        return i_forward, i_backward
+
+    def non_trivial(self, quizzes):
+        quizzes = quizzes.clone()
+        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+        n_backward = quizzes[:, 0] == self.token_backward
+        backward = quizzes[n_backward]
+        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+        return torch.logical_not(
+            self.problem.trivial_prompts_and_answers(
+                quizzes[:, 1 : 1 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len :],
+            )
+        )
+
+    def reverse_time(self, quizzes):
+        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+
+        forward_to_backward = torch.cat(
+            [
+                quizzes[:, 0:1],
+                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
+                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
+                quizzes[:, 1 : 1 + self.prompt_len],
+            ],
+            dim=1,
+        )
+
+        forward_to_backward[:, 0] = self.token_backward
+        forward_to_backward[:, 1 + self.answer_len] = self.token_backward
+
+        backward_to_forward = torch.cat(
+            [
+                quizzes[:, 0:1],
+                quizzes[:, 2 + self.answer_len :],
+                quizzes[:, 1 + self.answer_len : 2 + self.answer_len],
+                quizzes[:, 1 : 1 + self.answer_len],
+            ],
+            dim=1,
+        )
+
+        backward_to_forward[:, 0] = self.token_forward
+        backward_to_forward[:, 1 + self.prompt_len] = self.token_forward
+
+        m = i_forward.long()[:, None]
+
+        return m * forward_to_backward + (1 - m) * backward_to_forward
+
+    def reverse_random_half_in_place(self, quizzes):
+        i = torch.rand(quizzes.size(0)) < 0.5
+        if i.any():
+            quizzes[i] = self.reverse_time(quizzes[i])
+
+    def make_ar_mask(self, quizzes, first=False):
+        i_forward, i_backward = self.indices_forward_and_backward(quizzes)
+
+        t = torch.arange(quizzes.size(1), device=quizzes.device)
+
+        if first:
+            m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long()
+            m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long()
+        else:
+            m_forward = (t >= 2 + self.prompt_len).long()
+            m_backward = (t >= 2 + self.answer_len).long()
+
+        m = i_forward.long()[:, None]
+
+        return m * m_forward + (1 - m) * m_backward
+
+    def generate_token_sequences(self, nb):
+        prompts, answers = self.problem.generate_prompts_and_answers(nb)
+
+        if self.prompt_len is None:
+            self.prompt_len = prompts.size(1)
+
+        if self.answer_len is None:
+            self.answer_len = answers.size(1)
+
+        assert prompts.size(1) == self.prompt_len and answers.size(1) == self.answer_len
+
+        result = []
+
+        for prompt, answer in zip(prompts, answers):
+            a = [
+                torch.tensor([self.token_forward]),
+                prompt,
+                torch.tensor([self.token_forward]),
+                answer,
+            ]
+
+            result.append(torch.cat(a, dim=0)[None, :])
+
+        return torch.cat(result, dim=0)
+
+    def __init__(
+        self,
+        problem,
+        nb_train_samples,
+        nb_test_samples,
+        back_accuracy,
+        batch_size,
+        result_dir,
+        logger,
+        device=torch.device("cpu"),
+    ):
+        super().__init__()
+
+        v = problem.nb_token_values()
+        self.token_forward = v
+        self.token_backward = v + 1
+        self.nb_token_values = v + 2
+
+        self.problem = problem
+        self.back_accuracy = back_accuracy
+        self.batch_size = batch_size
+        self.device = device
+        self.logger = logger
+        self.prompt_len = None
+        self.answer_len = None
+
+        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
+        self.reverse_random_half_in_place(self.train_w_quizzes)
+        self.train_w_quizzes = self.train_w_quizzes.to(device)
+
+        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
+        self.reverse_random_half_in_place(self.test_w_quizzes)
+        self.test_w_quizzes = self.test_w_quizzes.to(device)
+
+        self.train_c_quizzes = []
+        self.test_c_quizzes = []
+
+        if result_dir is not None:
+            self.save_quizzes(
+                result_dir,
+                "culture_w_quizzes",
+                self.train_w_quizzes[:72],
+            )
+
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        quizzes,
+        mistakes=None,
+    ):
+        quizzes = quizzes.clone()
+        n_forward = quizzes[quizzes[:, 0] == self.token_forward]
+        n_backward = quizzes[:, 0] == self.token_backward
+        backward = quizzes[n_backward]
+        assert n_forward.size(0) + backward.size(0) == quizzes.size(0)
+        quizzes[n_backward] = self.reverse_time(quizzes[n_backward])
+
+        predicted_prompts = n_backward.long()
+        predicted_answers = 1 - predicted_prompts
+        if mistakes is not None:
+            # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
+            predicted_prompts *= mistakes
+            predicted_answers *= mistakes
+        else:
+            # 0/2 ~ not-to-predict / to predict
+            predicted_prompts *= 2
+            predicted_answers *= 2
+
+        self.problem.save_quizzes(
+            result_dir,
+            filename_prefix,
+            quizzes[:, 1 : 1 + self.prompt_len],
+            quizzes[:, 2 + self.prompt_len :],
+            predicted_prompts,
+            predicted_answers,
+        )
+
+    def batches(self, split="train", desc=None):
+        assert split in {"train", "test"}
+        if split == "train":
+            w_quizzes = self.train_w_quizzes
+            c_quizzes = self.train_c_quizzes
+        else:
+            w_quizzes = self.test_w_quizzes
+            c_quizzes = self.test_c_quizzes
+
+        if len(c_quizzes) > 0:
+            c_quizzes = torch.cat(c_quizzes, dim=0)
+            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                c_quizzes = c_quizzes[i]
+
+            i = torch.randperm(w_quizzes.size(0))[
+                : w_quizzes.size(0) - c_quizzes.size(0)
+            ]
+            w_quizzes = w_quizzes[i]
+
+            self.nb_batch_w_quizzes = w_quizzes.size(0)
+            self.nb_batch_c_quizzes = c_quizzes.size(0)
+
+            input = torch.cat([w_quizzes, c_quizzes], dim=0)
+        else:
+            input = w_quizzes
+            self.nb_batch_w_quizzes = w_quizzes.size(0)
+            self.nb_batch_c_quizzes = 0
+
+        # Shuffle
+        input = input[torch.randperm(input.size(0))]
+
+        if desc is None:
+            desc = f"epoch-{split}"
+        for batch in tqdm.tqdm(
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
+        ):
+            yield batch
+
+    def vocabulary_size(self):
+        return self.nb_token_values
+
+    def produce_results(
+        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
+    ):
+        def compute_accuracy(input, log_prefix=None):
+            ar_mask = self.make_ar_mask(input)
+            result = input.clone() * (1 - ar_mask)
+            seq_logproba = torch.empty(input.size(0), device=self.device)
+
+            masked_inplace_autoregression(
+                model=model,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                seq_logproba=seq_logproba,
+                temperature=1.0,
+                deterministic_synthesis=deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device)
+
+            n_forward = input[:, 0] == self.token_forward
+            n_backward = input[:, 0] == self.token_backward
+
+            correct[n_forward] = (
+                (input[n_forward] == result[n_forward]).long().min(dim=1).values
+            )
+
+            if self.back_accuracy and n_backward.any():
+                # accuracy of B->A*->B*=B instead of B->A*=A
+                back_input = self.reverse_time(result[n_backward])
+                back_input[:, 2 + self.prompt_len :] = input[
+                    n_backward, 1 : 1 + self.answer_len
+                ]
+                _, correct[n_backward] = compute_accuracy(back_input)
+
+            if log_prefix is not None:
+                forward_nb_correct = correct[n_forward].sum()
+                forward_nb_total = correct[n_forward].size(0)
+                backward_nb_correct = correct[n_backward].sum()
+                backward_nb_total = correct[n_backward].size(0)
+
+                self.logger(
+                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
+                )
+
+                self.logger(
+                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+                )
+
+            return result, correct
+
+        compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+
+        test_result, test_correct = compute_accuracy(
+            self.test_w_quizzes[:nmax], log_prefix="test"
+        )
+
+        main_test_accuracy = test_correct.sum() / test_correct.size(0)
+        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
+
+        ##############################
+
+        self.save_quizzes(
+            result_dir,
+            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
+            quizzes=test_result[:72],
+            mistakes=test_correct[:72] * 2 - 1,
+        )
+
+        return main_test_accuracy
+
+    def renew_w_quizzes(self, nb, for_train=True):
+        input = self.train_w_quizzes if for_train else self.test_w_quizzes
+        nb = min(nb, input.size(0))
+        input[:-nb] = input[nb:].clone()
+        fresh_w_quizzes = self.generate_token_sequences(nb)
+        self.reverse_random_half_in_place(fresh_w_quizzes)
+        input[-nb:] = fresh_w_quizzes.to(self.device)
+
+    def store_c_quizzes(self, new_c_quizzes, for_train=True):
+        if for_train:
+            self.train_c_quizzes.append(new_c_quizzes)
+        else:
+            self.test_c_quizzes.append(new_c_quizzes)
+
+    def compute_correctness(
+        self,
+        c_quizzes,
+        models_for_validation,
+        bidirectional_validation=False,
+        deterministic_validation=True,
+    ):
+        if bidirectional_validation:
+            backward_c_quizzes = self.forward_to_backward(c_quizzes)
+
+        seq_logproba = torch.zeros(
+            c_quizzes.size(0),
+            max([m.id for m in models_for_validation]) + 1,
+            device=self.device,
+        )
+
+        nb_correct = 0
+
+        seq_logproba[...] = 0.0
+
+        for model in models_for_validation:
+            result = c_quizzes.clone()
+
+            ar_mask = self.make_ar_mask(result)
+
+            masked_inplace_autoregression(
+                model=model,
+                batch_size=self.batch_size,
+                input=result,
+                ar_mask=ar_mask,
+                seq_logproba=seq_logproba[:, model.id],
+                temperature=1.0,
+                deterministic_synthesis=deterministic_validation,
+                # progress_bar_desc="solving c_quizzes",
+                device=self.device,
+            )
+
+            correct = (c_quizzes == result).long().min(dim=-1).values
+
+            if bidirectional_validation:
+                backward_result = backward_c_quizzes.clone()
+
+                ar_mask = self.make_ar_mask(backward_result)
+
+                masked_inplace_autoregression(
+                    model=model,
+                    batch_size=self.batch_size,
+                    input=backward_result,
+                    ar_mask=ar_mask,
+                    seq_logproba=seq_logproba[:, model.id],
+                    temperature=1.0,
+                    deterministic_synthesis=deterministic_validation,
+                    # progress_bar_desc="solving backward c_quizzes",
+                    device=self.device,
+                )
+
+                backward_correct = (
+                    (backward_c_quizzes == backward_result).long().min(dim=-1).values
+                )
+
+                correct *= backward_correct
+
+            # endif
+
+            nb_correct += correct
+
+        return nb_correct, seq_logproba
+
+    ###############################################################
+
+    def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
+        c_quizzes = torch.empty(
+            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+        )
+
+        seq_logproba = torch.zeros(nb, device=self.device)
+
+        # First, we generate the answer at high temperature
+
+        c_quizzes[:, 0] = self.token_backward
+        c_quizzes[:, 1 + self.answer_len] = self.token_backward
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.make_ar_mask(c_quizzes, first=True),
+            seq_logproba=seq_logproba,
+            temperature=temperature,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        # Then, we generate the prompt at low temperature
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.make_ar_mask(c_quizzes),
+            seq_logproba=seq_logproba,
+            temperature=1 / temperature,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        # Then we return the quizz, and re-generate the response, now
+        # at low temperature
+
+        c_quizzes = self.reverse_time(c_quizzes)
+
+        masked_inplace_autoregression(
+            model=model_for_generation,
+            batch_size=self.batch_size,
+            input=c_quizzes,
+            ar_mask=self.make_ar_mask(c_quizzes),
+            seq_logproba=seq_logproba,
+            temperature=1 / temperature,
+            deterministic_synthesis=False,
+            device=self.device,
+        )
+
+        return c_quizzes
diff --git a/quizz_machine.py b/quizz_machine.py
deleted file mode 100755 (executable)
index 697f27e..0000000
+++ /dev/null
@@ -1,421 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import mygpt
-from mygpt import BracketedSequence
-
-######################################################################
-
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
-
-
-def one_batch_masked_inplace_autoregression(
-    model,
-    input,
-    ar_mask,
-    seq_logproba,
-    temperature=1.0,
-    deterministic_synthesis=False,
-    forbidden_tokens=None,
-    forced_biases=None,
-):
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
-
-    if to_generate.min() > 0:
-        model(
-            BracketedSequence(input, 0, to_generate.min())
-        )  # Needed to initialize the model's cache
-    for s in range(to_generate.min(), to_generate.max() + 1):
-        output = model(BracketedSequence(input, s, 1)).x
-
-        logits = output[:, s]
-
-        logits = (logits / temperature).log_softmax(dim=-1)
-
-        if forbidden_tokens is not None:
-            logits = logits.masked_fill(forbidden_tokens, float("-inf"))
-
-        if forced_biases is not None:
-            logits = logits + forced_biases[None, :]
-
-        if deterministic_synthesis:
-            t_next = logits.argmax(-1)
-        else:
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-            t_next = dist.sample()
-
-        all_n = torch.arange(t_next.size(0))
-        seq_logproba += logits[all_n, t_next].sum(dim=-1)
-
-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-
-
-def masked_inplace_autoregression(
-    model,
-    batch_size,
-    input,
-    ar_mask,
-    seq_logproba,
-    temperature,
-    deterministic_synthesis,
-    forbidden_tokens=None,
-    logit_biases=None,
-    progress_bar_desc=None,
-    device=torch.device("cpu"),
-):
-    assert input.size() == ar_mask.size()
-
-    batches = zip(
-        input.split(batch_size),
-        ar_mask.split(batch_size),
-        seq_logproba.split(batch_size),
-    )
-
-    if progress_bar_desc is not None:
-        batches = tqdm.tqdm(
-            batches,
-            dynamic_ncols=True,
-            desc=progress_bar_desc,
-            total=(input.size(0) + batch_size - 1) // batch_size,
-        )
-
-    with torch.autograd.no_grad():
-        t = model.training
-        model.eval()
-
-        for input, ar_mask, seq_logproba in batches:
-            one_batch_masked_inplace_autoregression(
-                model=model,
-                input=input,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=deterministic_synthesis,
-                forbidden_tokens=forbidden_tokens,
-                forced_biases=logit_biases,
-            )
-
-        model.train(t)
-
-
-######################################################################
-
-
-class QuizzMachine:
-    def make_ar_mask(self, input):
-        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
-        return b.long()[None, :].expand_as(input)
-
-    def __init__(
-        self,
-        problem,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        result_dir,
-        logger,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.problem = problem
-        self.batch_size = batch_size
-        self.device = device
-        self.logger = logger
-
-        self.train_w_quizzes = self.problem.generate_token_sequences(
-            nb_train_samples
-        ).to(device)
-
-        self.test_w_quizzes = self.problem.generate_token_sequences(nb_test_samples).to(
-            device
-        )
-
-        self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
-
-        self.train_c_quizzes = []
-        self.test_c_quizzes = []
-
-        if result_dir is not None:
-            self.problem.save_quizzes(
-                self.train_w_quizzes[:72], result_dir, "culture_w_quizzes"
-            )
-
-    def batches(self, split="train", desc=None):
-        assert split in {"train", "test"}
-        if split == "train":
-            w_quizzes = self.train_w_quizzes
-            c_quizzes = self.train_c_quizzes
-        else:
-            w_quizzes = self.test_w_quizzes
-            c_quizzes = self.test_c_quizzes
-
-        if len(c_quizzes) > 0:
-            c_quizzes = torch.cat(c_quizzes, dim=0)
-            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                c_quizzes = c_quizzes[i]
-
-            i = torch.randperm(w_quizzes.size(0))[
-                : w_quizzes.size(0) - c_quizzes.size(0)
-            ]
-            w_quizzes = w_quizzes[i]
-
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = c_quizzes.size(0)
-
-            input = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            input = w_quizzes
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = 0
-
-        # Shuffle
-        input = input[torch.randperm(input.size(0))]
-
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
-    ):
-        def compute_accuracy(input):
-            input = input[:nmax]
-            ar_mask = self.make_ar_mask(input)
-            result = input.clone() * (1 - ar_mask)
-            seq_logproba = torch.empty(input.size(0), device=self.device)
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
-
-            nb_total, nb_correct = (
-                input.size(0),
-                (input == result).long().min(dim=1).values.sum(),
-            )
-
-            return nb_total, nb_correct
-
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
-
-        self.logger(
-            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-        )
-
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes)
-
-        self.logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        main_test_accuracy = test_nb_correct / test_nb_total
-        self.logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
-
-        ##############################
-
-        input = self.test_w_quizzes[:96]
-        ar_mask = self.make_ar_mask(input)
-        result = input.clone() * (1 - ar_mask)
-        seq_logproba = torch.empty(input.size(0), device=self.device)
-
-        masked_inplace_autoregression(
-            model=model,
-            batch_size=self.batch_size,
-            input=result,
-            ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
-            temperature=1.0,
-            deterministic_synthesis=deterministic_synthesis,
-            progress_bar_desc=None,
-            device=self.device,
-        )
-
-        self.problem.save_quizzes(
-            result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}"
-        )
-
-        return main_test_accuracy
-
-    def renew_w_quizzes(self, nb, for_train=True):
-        input = self.train_w_quizzes if for_train else self.test_w_quizzes
-        nb = min(nb, input.size(0))
-        input[:-nb] = input[nb:].clone()
-        input[-nb:] = self.problem.generate_token_sequences(nb).to(self.device)
-
-    def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        if for_train:
-            self.train_c_quizzes.append(new_c_quizzes)
-        else:
-            self.test_c_quizzes.append(new_c_quizzes)
-
-    def reverse_time(self, c_quizzes):
-        token_forward, token_backward = self.problem.direction_tokens()
-
-        l = (c_quizzes.size(1) - 1) // 2
-        direction = c_quizzes[:, l : l + 1]
-        direction = self.problem.token_forward * (
-            direction == self.problem.token_backward
-        ) + self.problem.token_backward * (direction == self.problem.token_forward)
-
-        return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
-
-    def compute_correctness(
-        self, c_quizzes, models_for_validation, both_directions=True
-    ):
-        reversed_c_quizzes = self.reverse_time(c_quizzes)
-
-        ar_mask = self.make_ar_mask(c_quizzes)
-        seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
-
-        # Check how many of models can solve the quizzes in both directions
-
-        nb_correct = 0
-
-        for model in models_for_validation:
-            result = c_quizzes.clone()
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                seq_logproba=seq_logproba,
-                temperature=1.0,
-                deterministic_synthesis=True,
-                # progress_bar_desc="solving c_quizzes",
-                device=self.device,
-            )
-
-            correct = (c_quizzes == result).long().min(dim=-1).values
-
-            if both_directions:
-                reversed_result = reversed_c_quizzes.clone()
-
-                masked_inplace_autoregression(
-                    model=model,
-                    batch_size=self.batch_size,
-                    input=reversed_result,
-                    ar_mask=ar_mask,
-                    seq_logproba=seq_logproba,
-                    temperature=1.0,
-                    deterministic_synthesis=True,
-                    # progress_bar_desc="solving reversed c_quizzes",
-                    device=self.device,
-                )
-
-                reversed_correct = (
-                    (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
-                )
-
-                correct *= reversed_correct
-
-            # endif
-
-            nb_correct += correct
-
-        return nb_correct
-
-    ###############################################################
-
-    def generate_quizzes(self, nb, model_for_generation, reverse_cleanup=False):
-        c_quizzes = torch.empty(
-            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
-        )
-
-        ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
-        ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
-        ar_mask_solve = 1 - ar_mask_prompt
-        seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
-
-        if reverse_cleanup:
-            warnings.warn("very high temperature with reversed cleanup", RuntimeWarning)
-            temperature = 10.0
-        else:
-            temperature = 1.0
-
-        # warnings.warn("noise injection", RuntimeWarning)
-        # noise_std = torch.rand(1).item()
-        # self.logger(f"{noise_std=}")
-
-        # mygpt.set_noise_injection(model_for_generation, noise_std)
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=ar_mask_prompt,
-            seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=False,
-            device=self.device,
-        )
-
-        # mygpt.set_noise_injection(model_for_generation, 0.0)
-
-        ave_seq_logproba = seq_logproba.mean()
-
-        masked_inplace_autoregression(
-            model=model_for_generation,
-            batch_size=self.batch_size,
-            input=c_quizzes,
-            ar_mask=ar_mask_solve,
-            seq_logproba=seq_logproba,
-            temperature=temperature,
-            deterministic_synthesis=True,
-            device=self.device,
-        )
-
-        if reverse_cleanup:
-            c_quizzes = self.reverse_time(c_quizzes)
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
-
-            c_quizzes = self.reverse_time(c_quizzes)
-            masked_inplace_autoregression(
-                model=model_for_generation,
-                batch_size=self.batch_size,
-                input=c_quizzes,
-                ar_mask=ar_mask_solve,
-                seq_logproba=seq_logproba,
-                temperature=temperature,
-                deterministic_synthesis=True,
-                device=self.device,
-            )
-
-        return c_quizzes, seq_logproba.mean()
diff --git a/sky.py b/sky.py
index 4ca4ba7..ed440d3 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm, os
+import math, sys, tqdm, os, warnings
 
 import torch, torchvision
 
@@ -37,8 +37,6 @@ class Sky(problem.Problem):
     token_background = 0
     first_bird_token = 1
     nb_bird_tokens = colors.size(0) - 1
-    token_forward = first_bird_token + nb_bird_tokens
-    token_backward = token_forward + 1
 
     token2char = (
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
@@ -60,9 +58,6 @@ class Sky(problem.Problem):
         self.nb_iterations = nb_iterations
         self.avoid_collision = avoid_collision
 
-    def direction_tokens(self):
-        return self.token_forward, self.token_backward
-
     def generate_frame_sequences(self, nb):
         frame_sequences = []
 
@@ -157,38 +152,8 @@ class Sky(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        prompts = frame_sequences[:, : frame_sequences.size(0) // 2].flatten(1)
-        answers = frame_sequences[:, frame_sequences.size(0) // 2 :].flatten(1)
-        return prompts, answers
-
-    def generate_token_sequences(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-
-        result = []
-
-        for frame_sequence in frame_sequences:
-            a = []
-            if torch.rand(1) < 0.5:
-                for frame in frame_sequence:
-                    if len(a) > 0:
-                        a.append(torch.tensor([self.token_forward]))
-                    a.append(frame.flatten())
-            else:
-                for frame in reversed(frame_sequence):
-                    if len(a) > 0:
-                        a.append(torch.tensor([self.token_backward]))
-                    a.append(frame.flatten())
-
-            result.append(torch.cat(a, dim=0)[None, :])
-
-        return torch.cat(result, dim=0)
-
-    ######################################################################
-
     def frame2img(self, x, scale=15):
-        x = x.reshape(-1, self.height, self.width)
+        x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
@@ -214,92 +179,140 @@ class Sky(problem.Problem):
 
         return x
 
-    def seq2img(self, seq, scale=15):
-        all = [
-            self.frame2img(
-                seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
-                scale,
-            )
-        ]
+    def seq2str(self, seq):
+        result = []
+        for s in seq:
+            result.append("".join([self.token2char[v] for v in s]))
+        return result
+
+    def save_image(
+        self,
+        result_dir,
+        filename,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        if predicted_prompts is None:
+            predicted_prompts = 255
 
-        separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
+        if predicted_answers is None:
+            predicted_answers = 255
 
-        t = self.height * self.width
+        def add_frame(x, c, margin, bottom=False):
+            if bottom:
+                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
+            else:
+                h, w, di, dj = (
+                    x.size(2) + 2 * margin,
+                    x.size(3) + 2 * margin,
+                    margin,
+                    margin,
+                )
 
-        while t < seq.size(1):
-            direction_tokens = seq[:, t]
-            t += 1
+            y = x.new_full((x.size(0), x.size(1), h, w), 0)
 
-            direction_images = self.colors[
-                torch.full(
-                    (direction_tokens.size(0), self.height * scale - 1, scale), 0
+            if type(c) is int:
+                y[...] = c
+            else:
+                c = c.long()[:, None]
+                c = (
+                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
                 )
-            ].permute(0, 3, 1, 2)
-
-            for n in range(direction_tokens.size(0)):
-                if direction_tokens[n] == self.token_forward:
-                    for k in range(scale):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                3 + scale // 2 - abs(k - scale // 2),
-                            ] = 0
-                elif direction_tokens[n] == self.token_backward:
-                    for k in range(scale):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                3 + abs(k - scale // 2),
-                            ] = 0
-                else:
-                    for k in range(2, scale - 2):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                k,
-                            ] = 0
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                scale - 1 - k,
-                            ] = 0
-
-            all += [
-                separator,
-                direction_images,
-                separator,
-                self.frame2img(
-                    seq[:, t : t + self.height * self.width].reshape(
-                        -1, self.height, self.width
-                    ),
-                    scale,
-                ),
-            ]
-
-            t += self.height * self.width
-
-        return torch.cat(all, dim=3)
+                y[...] = c[:, :, None, None]
 
-    def seq2str(self, seq):
-        result = []
-        for s in seq:
-            result.append("".join([self.token2char[v] for v in s]))
-        return result
+            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+
+            return y
+
+        margin = 4
+
+        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+        h = img_prompts.size(2)
+        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
+
+        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
+        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
+
+        img_prompts = add_frame(
+            img_prompts, c=predicted_prompts, margin=margin, bottom=True
+        )
+        img_answers = add_frame(
+            img_answers, c=predicted_answers, margin=margin, bottom=True
+        )
+
+        marker_size = 16
+
+        separator = img_prompts.new_full(
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                marker_size,
+            ),
+            255,
+        )
+
+        separator[:, :, 0] = 0
+        separator[:, :, h - 1] = 0
+
+        for k in range(1, 2 * marker_size - 8):
+            i = k - (marker_size - 4)
+            j = marker_size - 5 - abs(i)
+            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+        img = torch.cat([img_prompts, separator, img_answers], dim=3)
 
-    def save_image(self, input, result_dir, filename):
-        img = self.seq2img(input.to("cpu"))
         image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+        torchvision.utils.save_image(
+            img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
+        )
+
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
 
-    def save_quizzes(self, input, result_dir, filename_prefix):
-        self.save_image(input, result_dir, filename_prefix + ".png")
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+
+        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+
+        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+
+        # warnings.warn("dirty test with longer answer", RuntimeWarning)
+        # answers = torch.cat(
+        # [
+        # frame_sequences[:, frame_sequences.size(1) // 2 :],
+        # frame_sequences[:, frame_sequences.size(1) // 2 :],
+        # ],
+        # dim=3,
+        # ).flatten(1)
+
+        return prompts, answers
+
+    def save_quizzes(
+        self,
+        result_dir,
+        filename_prefix,
+        prompts,
+        answers,
+        predicted_prompts=None,
+        predicted_answers=None,
+    ):
+        self.save_image(
+            result_dir,
+            filename_prefix + ".png",
+            prompts,
+            answers,
+            predicted_prompts,
+            predicted_answers,
+        )
 
 
 ######################################################################
@@ -307,12 +320,21 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
 
-    start_time = time.perf_counter()
-    token_sequences = sky.generate_token_sequences(nb=64)
-    delay = time.perf_counter() - start_time
-    print(f"{token_sequences.size(0)/delay:02f} seq/s")
+    prompts, answers = sky.generate_prompts_and_answers(4)
+
+    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
+    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
+
+    sky.save_quizzes(
+        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
+    )
+
+    # start_time = time.perf_counter()
+    # token_sequences = sky.generate_token_sequences(nb=64)
+    # delay = time.perf_counter() - start_time
+    # print(f"{token_sequences.size(0)/delay:02f} seq/s")
 
     # print(sky.seq2str(seq[:4]))
 
@@ -330,9 +352,9 @@ if __name__ == "__main__":
     # seq = (1 - m) * seq + m * 23
 
     # print(seq.size())
-    img = sky.seq2img(token_sequences)
+    # img = sky.seq2img(token_sequences)
     # print(img.size())
 
-    torchvision.utils.save_image(
-        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
-    )
+    # torchvision.utils.save_image(
+    #     img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
+    # )