Merge branch 'dev' master
authorFrançois Fleuret <francois@fleuret.org>
Sun, 14 Jul 2024 22:20:00 +0000 (00:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 14 Jul 2024 22:20:00 +0000 (00:20 +0200)
grids.py
main.py
problem.py
quiz_machine.py
sky.py

index 47e5861..eea8c6c 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -17,6 +17,92 @@ from torch.nn import functional as F
 import problem
 
 
+def grow_islands(nb, height, width, nb_seeds, nb_iterations):
+    w = torch.empty(5, 1, 3, 3)
+
+    w[0, 0] = torch.tensor(
+        [
+            [1.0, 1.0, 1.0],
+            [1.0, 0.0, 1.0],
+            [1.0, 1.0, 1.0],
+        ]
+    )
+
+    w[1, 0] = torch.tensor(
+        [
+            [-1.0, 1.0, 0.0],
+            [1.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0],
+        ]
+    )
+
+    w[2, 0] = torch.tensor(
+        [
+            [0.0, 1.0, -1.0],
+            [0.0, 0.0, 1.0],
+            [0.0, 0.0, 0.0],
+        ]
+    )
+
+    w[3, 0] = torch.tensor(
+        [
+            [0.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0],
+            [0.0, 1.0, -1.0],
+        ]
+    )
+
+    w[4, 0] = torch.tensor(
+        [
+            [0.0, 0.0, 0.0],
+            [1.0, 0.0, 0.0],
+            [-1.0, 1.0, 0.0],
+        ]
+    )
+
+    Z = torch.zeros(nb, height, width)
+    U = Z.flatten(1)
+
+    for _ in range(nb_seeds):
+        M = F.conv2d(Z[:, None, :, :], w, padding=1)
+        M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
+        M = ((M[:, 0] == 0) & (Z == 0)).long()
+        Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
+        M = M * torch.rand(M.size())
+        M = M.flatten(1)
+        M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
+        U += M * Q
+
+    for _ in range(nb_iterations):
+        M = F.conv2d(Z[:, None, :, :], w, padding=1)
+        M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
+        M = ((M[:, 1] >= 0) & (Z == 0)).long()
+        Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
+        M = M * torch.rand(M.size())
+        M = M.flatten(1)
+        M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
+        U = Z.flatten(1)
+        U += M * Q
+
+    M = Z.clone()
+    Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))
+
+    while True:
+        W = Z.clone()
+        Z = F.max_pool2d(Z, 3, 1, 1) * M
+        if Z.equal(W):
+            break
+
+    Z = Z.long()
+    U = Z.flatten(1)
+    V = F.one_hot(U).max(dim=1).values
+    W = V.cumsum(dim=1) - V
+    N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
+    Z = W[N, Z]
+
+    return Z
+
+
 class Grids(problem.Problem):
     named_colors = [
         ("white", [255, 255, 255]),
@@ -32,11 +118,40 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
-    def __init__(self, device=torch.device("cpu")):
+    def __init__(
+        self,
+        max_nb_cached_chunks=None,
+        chunk_size=None,
+        nb_threads=-1,
+        tasks=None,
+    ):
         self.colors = torch.tensor([c for _, c in self.named_colors])
         self.height = 10
         self.width = 10
-        self.device = device
+        self.cache_rec_coo = {}
+
+        all_tasks = [
+            self.task_replace_color,
+            self.task_translate,
+            self.task_grow,
+            self.task_half_fill,
+            self.task_frame,
+            self.task_detect,
+            self.task_count,
+            self.task_trajectory,
+            self.task_bounce,
+            self.task_scale,
+            self.task_symbols,
+            self.task_isometry,
+            #            self.task_islands,
+        ]
+
+        if tasks is None:
+            self.all_tasks = all_tasks
+        else:
+            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]
+
+        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
 
     ######################################################################
 
@@ -110,10 +225,10 @@ class Grids(problem.Problem):
                 c = c.long()[:, None]
                 c = (
                     (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
-                    * torch.tensor([64, 64, 64], device=c.device)
-                    + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
-                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
-                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                    * torch.tensor([64, 64, 64])
+                    + (c == 1).long() * torch.tensor([0, 255, 0])
+                    + (c == 0).long() * torch.tensor([255, 255, 255])
+                    + (c == -1).long() * torch.tensor([255, 0, 0])
                 )
                 y[...] = c[:, :, None, None]
 
@@ -195,121 +310,133 @@ class Grids(problem.Problem):
         return len(self.colors)
 
     # @torch.compile
-    def rec_coo_(self, nb_rec, min_height=3, min_width=3):
-        # @torch.compile
-        def overlap(ia, ja, ib, jb):
-            return (
-                ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1]
-            )
+    def rec_coo(
+        self,
+        nb_rec,
+        min_height=3,
+        min_width=3,
+        surface_max=None,
+        prevent_overlap=False,
+    ):
+        if surface_max is None:
+            surface_max = self.height * self.width // 2
 
-        if nb_rec == 3:
-            while True:
-                i = torch.randint(self.height + 1, (nb_rec, 2)).sort(dim=1).values
-                j = torch.randint(self.width + 1, (nb_rec, 2)).sort(dim=1).values
-                if (
-                    not (
-                        overlap(i[0], j[0], i[1], j[1])
-                        or overlap(i[0], j[0], i[2], j[2])
-                        or overlap(i[1], j[1], i[2], j[2])
-                    )
-                    and (i[:, 1] - i[:, 0]).min() >= min_height
-                    and (j[:, 1] - j[:, 0]).min() >= min_width
-                ):
-                    break
-            return (
-                (i[0, 0], j[0, 0], i[0, 1], j[0, 1]),
-                (i[1, 0], j[1, 0], i[1, 1], j[1, 1]),
-                (i[2, 0], j[2, 0], i[2, 1], j[2, 1]),
-            )
+        signature = (nb_rec, min_height, min_width, surface_max)
 
-    # That's quite a tensorial spaghetti mess to sample
-    # non-overlapping rectangles quickly, but made the generation of
-    # 100k samples go from 1h50 with a lame pure python code to 3min30s
-    # with this one.
-    # @torch.compile
-    def rec_coo(self, nb_rec, min_height=3, min_width=3):
-        nb_trials = 200
+        try:
+            return self.cache_rec_coo[signature].pop()
+        except IndexError:
+            pass
+        except KeyError:
+            pass
 
+        N = 10000
         while True:
-            v = (
-                (
-                    torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
-                    .sort(dim=-1)
-                    .indices
-                    < 2
-                )
-                .long()
-                .cumsum(dim=1)
-                == 1
-            ).long()
+            while True:
+                i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
+                j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
 
-            h = (
-                (
-                    torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
-                    .sort(dim=-1)
-                    .indices
-                    < 2
+                big_enough = (
+                    (i[:, 1] >= i[:, 0] + min_height)
+                    & (j[:, 1] >= j[:, 0] + min_height)
+                    & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
                 )
-                .long()
-                .cumsum(dim=1)
-                == 1
-            ).long()
-
-            i = torch.logical_and(
-                v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
-            )
 
-            v, h = v[i], h[i]
-            v = v[: v.size(0) - v.size(0) % nb_rec]
-            h = h[: h.size(0) - h.size(0) % nb_rec]
-            v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
-            h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)
+                i, j = i[big_enough], j[big_enough]
 
-            r = v[:, :, :, None] * h[:, :, None, :]
+                n = i.size(0) - i.size(0) % nb_rec
 
-            valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
+                if n > 0:
+                    break
 
-            v = v[valid]
-            h = h[valid]
+            i = i[:n].reshape(n // nb_rec, nb_rec, -1)
+            j = j[:n].reshape(n // nb_rec, nb_rec, -1)
+
+            if prevent_overlap:
+                can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
+                    dim=-1
+                ) <= self.height * self.width
+                i, j = i[can_fit], j[can_fit]
+                if nb_rec == 2:
+                    A_i1, A_i2, A_j1, A_j2 = (
+                        i[:, 0, 0],
+                        i[:, 0, 1],
+                        j[:, 0, 0],
+                        j[:, 0, 1],
+                    )
+                    B_i1, B_i2, B_j1, B_j2 = (
+                        i[:, 1, 0],
+                        i[:, 1, 1],
+                        j[:, 1, 0],
+                        j[:, 1, 1],
+                    )
+                    no_overlap = torch.logical_not(
+                        (A_i1 >= B_i2)
+                        & (A_i2 <= B_i1)
+                        & (A_j1 >= B_j1)
+                        & (A_j2 <= B_j1)
+                    )
+                    i, j = i[no_overlap], j[no_overlap]
+                elif nb_rec == 3:
+                    A_i1, A_i2, A_j1, A_j2 = (
+                        i[:, 0, 0],
+                        i[:, 0, 1],
+                        j[:, 0, 0],
+                        j[:, 0, 1],
+                    )
+                    B_i1, B_i2, B_j1, B_j2 = (
+                        i[:, 1, 0],
+                        i[:, 1, 1],
+                        j[:, 1, 0],
+                        j[:, 1, 1],
+                    )
+                    C_i1, C_i2, C_j1, C_j2 = (
+                        i[:, 2, 0],
+                        i[:, 2, 1],
+                        j[:, 2, 0],
+                        j[:, 2, 1],
+                    )
+                    no_overlap = (
+                        (
+                            (A_i1 >= B_i2)
+                            | (A_i2 <= B_i1)
+                            | (A_j1 >= B_j2)
+                            | (A_j2 <= B_j1)
+                        )
+                        & (
+                            (A_i1 >= C_i2)
+                            | (A_i2 <= C_i1)
+                            | (A_j1 >= C_j2)
+                            | (A_j2 <= C_j1)
+                        )
+                        & (
+                            (B_i1 >= C_i2)
+                            | (B_i2 <= C_i1)
+                            | (B_j1 >= C_j2)
+                            | (B_j2 <= C_j1)
+                        )
+                    )
+                    i, j = (i[no_overlap], j[no_overlap])
+                else:
+                    assert nb_rec == 1
 
-            if v.size(0) > 0:
+            if i.size(0) > 1:
                 break
 
-        av = torch.arange(v.size(2), device=self.device)[None, :]
-        ah = torch.arange(h.size(2), device=self.device)[None, :]
-
-        return [
-            (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
-            for i1, j1, i2, j2 in zip(
-                v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
-                h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
-                (v[0] * av).max(dim=-1).values,
-                (h[0] * ah).max(dim=-1).values,
-            )
+        self.cache_rec_coo[signature] = [
+            [
+                (
+                    i[n, k, 0].item(),
+                    j[n, k, 0].item(),
+                    i[n, k, 1].item(),
+                    j[n, k, 1].item(),
+                )
+                for k in range(nb_rec)
+            ]
+            for n in range(i.size(0))
         ]
 
-    # @torch.compile
-    def rec_coo_(self, x, n, min_height=3, min_width=3):
-        collision = x.new(x.size())
-        while True:
-            collision[...] = 0
-            result = []
-            for _ in range(n):
-                while True:
-                    i1, i2 = torch.randint(x.size(0), (2,))
-                    if i1 + min_height <= i2:
-                        break
-                while True:
-                    j1, j2 = torch.randint(x.size(1), (2,))
-                    if j1 + min_width <= j2:
-                        break
-                collision[i1:i2, j1:j2] += 1
-                if collision.max() > 1:
-                    break
-                result.append((i1, j1, i2, j2))
-            if collision.max() == 1:
-                break
-        return result
+        return self.cache_rec_coo[signature].pop()
 
     ######################################################################
 
@@ -318,7 +445,7 @@ class Grids(problem.Problem):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
@@ -326,12 +453,16 @@ class Grids(problem.Problem):
 
     # @torch.compile
     def task_translate(self, A, f_A, B, f_B):
-        di, dj = torch.randint(3, (2,)) - 1
+        while True:
+            di, dj = torch.randint(3, (2,)) - 1
+            if di.abs() + dj.abs() > 0:
+                break
+
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(nb_rec)
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
                 i1, j1, i2, j2 = r[nb_rec - 1]
                 if (
                     i1 + di >= 0
@@ -354,10 +485,10 @@ class Grids(problem.Problem):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
-        direction = torch.randint(2, (1,))
+        direction = torch.randint(2, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
-                r = self.rec_coo(nb_rec)
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
                 i1, j1, i2, j2 = r[nb_rec - 1]
                 if i1 + 3 < i2 and j1 + 3 < j2:
                     break
@@ -376,13 +507,13 @@ class Grids(problem.Problem):
                     f_X[i1:i2, j1:j2] = c[n]
 
     # @torch.compile
-    def task_color_grow(self, A, f_A, B, f_B):
+    def task_half_fill(self, A, f_A, B, f_B):
         di, dj = torch.randint(2, (2,)) * 2 - 1
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
-        direction = torch.randint(4, (1,))
+        direction = torch.randint(4, (1,)).item()
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[2 * n]
@@ -422,20 +553,24 @@ class Grids(problem.Problem):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
-                f_X[i1:i2, j1:j2] = c[n]
                 if n == nb_rec - 1:
-                    f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+                    f_X[i1:i2, j1] = c[n]
+                    f_X[i1:i2, j2 - 1] = c[n]
+                    f_X[i1, j1:j2] = c[n]
+                    f_X[i2 - 1, j1:j2] = c[n]
+                else:
+                    f_X[i1:i2, j1:j2] = c[n]
 
     # @torch.compile
     def task_detect(self, A, f_A, B, f_B):
         nb_rec = 3
         c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec)
+            r = self.rec_coo(nb_rec, prevent_overlap=True)
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
@@ -478,29 +613,51 @@ class Grids(problem.Problem):
 
         return no, nq, nq_diag
 
-    # @torch.compile
     def task_count(self, A, f_A, B, f_B):
-        N = (torch.randint(4, (1,)) + 2).item()
-        c = torch.randperm(len(self.colors) - 1)[:N] + 1
+        while True:
+            error = False
+
+            N = torch.randint(5, (1,)).item() + 1
+            c = torch.zeros(N + 1)
+            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1
+
+            for X, f_X in [(A, f_A), (B, f_B)]:
+                if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
+                    self.cache_count = list(
+                        grow_islands(
+                            1000,
+                            self.height,
+                            self.width,
+                            nb_seeds=self.height * self.width // 8,
+                            nb_iterations=self.height * self.width // 10,
+                        )
+                    )
 
-        for X, f_X in [(A, f_A), (B, f_B)]:
-            nb = torch.zeros(N, dtype=torch.int64)
-            q = torch.randint(N, (self.height * self.width,))
-            k = torch.randperm(self.height * self.width)
-            for p in range(self.height * self.width):
-                i, j = k[p] % self.height, k[p] // self.height
-                no, nq, nq_diag = self.contact(X, i, j, c[q[p]])
-                if no == 0 and nq_diag == 0:
-                    if nq == 0:
-                        if nb[q[p]] < self.width:
-                            X[i, j] = c[q[p]]
-                            nb[q[p]] += 1
-                    if nq == 1:
-                        X[i, j] = c[q[p]]
-
-            for n in range(N):
-                for j in range(nb[n]):
-                    f_X[n, j] = c[n]
+                X[...] = self.cache_count.pop()
+
+                k = (X.max() + 1 + (c.size(0) - 1)).item()
+                V = torch.arange(k) // (c.size(0) - 1)
+                V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
+                    c.size(0) - 1
+                ) + 1
+                V[0] = 0
+                X[...] = c[V[X]]
+
+                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
+                    f_X[...] = 0
+                    for e in range(1, N + 1):
+                        for j in range((X == c[e]).sum() + 1):
+                            if j < self.width:
+                                f_X[e - 1, j] = c[e]
+                            else:
+                                error = True
+                                break
+                else:
+                    error = True
+                    break
+
+            if not error:
+                break
 
     # @torch.compile
     def task_trajectory(self, A, f_A, B, f_B):
@@ -508,7 +665,10 @@ class Grids(problem.Problem):
         for X, f_X in [(A, f_A), (B, f_B)]:
             while True:
                 di, dj = torch.randint(7, (2,)) - 3
-                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+                i, j = (
+                    torch.randint(self.height, (1,)).item(),
+                    torch.randint(self.width, (1,)).item(),
+                )
                 if (
                     abs(di) + abs(dj) > 0
                     and i + 2 * di >= 0
@@ -549,8 +709,9 @@ class Grids(problem.Problem):
                 X[...] = 0
 
                 for _ in range((self.height * self.width) // 10):
-                    i, j = torch.randint(self.height, (1,)), torch.randint(
-                        self.width, (1,)
+                    i, j = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
                     )
                     X[i, j] = c[0]
                     f_X[i, j] = c[0]
@@ -560,7 +721,10 @@ class Grids(problem.Problem):
                     if abs(di) + abs(dj) == 1:
                         break
 
-                i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+                i, j = (
+                    torch.randint(self.height, (1,)).item(),
+                    torch.randint(self.width, (1,)).item(),
+                )
 
                 X[i, j] = c[1]
                 f_X[i, j] = c[1]
@@ -598,18 +762,21 @@ class Grids(problem.Problem):
     def task_scale(self, A, f_A, B, f_B):
         c = torch.randperm(len(self.colors) - 1)[:2] + 1
 
-        i, j = torch.randint(self.height // 2, (1,)), torch.randint(
-            self.width // 2, (1,)
+        i, j = (
+            torch.randint(self.height // 2, (1,)).item(),
+            torch.randint(self.width // 2, (1,)).item(),
         )
 
         for X, f_X in [(A, f_A), (B, f_B)]:
             for _ in range(3):
                 while True:
-                    i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
-                        self.width // 2 + 1, (1,)
+                    i1, j1 = (
+                        torch.randint(self.height // 2 + 1, (1,)).item(),
+                        torch.randint(self.width // 2 + 1, (1,)).item(),
                     )
-                    i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
-                        self.width // 2 + 1, (1,)
+                    i2, j2 = (
+                        torch.randint(self.height // 2 + 1, (1,)).item(),
+                        torch.randint(self.width // 2 + 1, (1,)).item(),
                     )
                     if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
                         break
@@ -639,7 +806,7 @@ class Grids(problem.Problem):
 
             ai, aj = i.float().mean(), j.float().mean()
 
-            q = torch.randint(3, (1,)) + 1
+            q = torch.randint(3, (1,)).item() + 1
 
             X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
             X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
@@ -656,12 +823,12 @@ class Grids(problem.Problem):
             f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
 
     # @torch.compile
-    def task_ortho(self, A, f_A, B, f_B):
+    def task_isometry(self, A, f_A, B, f_B):
         nb_rec = 3
         di, dj = torch.randint(3, (2,)) - 1
         o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
         m = torch.eye(2)
-        for _ in range(torch.randint(4, (1,))):
+        for _ in range(torch.randint(4, (1,)).item()):
             m = m @ o
         if torch.rand(1) < 0.5:
             m[0, :] = -m[0, :]
@@ -710,9 +877,88 @@ class Grids(problem.Problem):
                 ):
                     break
 
+    def compute_distance(self, walls, goal_i, goal_j):
+        max_length = walls.numel()
+        dist = torch.full_like(walls, max_length)
+
+        dist[goal_i, goal_j] = 0
+        pred_dist = torch.empty_like(dist)
+
+        while True:
+            pred_dist.copy_(dist)
+            dist[1:-1, 1:-1] = (
+                torch.cat(
+                    (
+                        dist[None, 1:-1, 1:-1],
+                        dist[None, 1:-1, 0:-2],
+                        dist[None, 2:, 1:-1],
+                        dist[None, 1:-1, 2:],
+                        dist[None, 0:-2, 1:-1],
+                    ),
+                    0,
+                ).min(dim=0)[0]
+                + 1
+            )
+
+            dist = walls * max_length + (1 - walls) * dist
+
+            if dist.equal(pred_dist):
+                return dist * (1 - walls)
+
     # @torch.compile
-    def task_islands(self, A, f_A, B, f_B):
-        pass
+    def task_distance(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:3] + 1
+        dist0 = torch.empty(self.height + 2, self.width + 2)
+        dist1 = torch.empty(self.height + 2, self.width + 2)
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            nb_rec = torch.randint(3, (1,)).item() + 1
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                X[...] = 0
+                f_X[...] = 0
+                for n in range(nb_rec):
+                    i1, j1, i2, j2 = r[n]
+                    X[i1:i2, j1:j2] = c[0]
+                    f_X[i1:i2, j1:j2] = c[0]
+                while True:
+                    i0, j0 = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i0, j0] == 0:
+                        break
+                while True:
+                    i1, j1 = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    if X[i1, j1] == 0:
+                        break
+                dist1[...] = 1
+                dist1[1:-1, 1:-1] = (X != 0).long()
+                dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1)
+                if (
+                    dist1[i0 + 1, j0 + 1] >= 1
+                    and dist1[i0 + 1, j0 + 1] < self.height * 4
+                ):
+                    break
+
+            dist0[...] = 1
+            dist0[1:-1, 1:-1] = (X != 0).long()
+            dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)
+
+            dist0 = dist0[1:-1, 1:-1]
+            dist1 = dist1[1:-1, 1:-1]
+
+            D = dist1[i0, j0]
+            for d in range(1, D):
+                M = (dist0 == d) & (dist1 == D - d)
+                f_X[...] = (1 - M) * f_X + M * c[1]
+
+            X[i0, j0] = c[2]
+            f_X[i0, j0] = c[2]
+            X[i1, j1] = c[2]
+            f_X[i1, j1] = c[2]
 
     # for X, f_X in [(A, f_A), (B, f_B)]:
     # n = torch.arange(self.height * self.width).reshape(self.height, self.width)
@@ -722,24 +968,104 @@ class Grids(problem.Problem):
     # i,j=q%self.height,q//self.height
     # if
 
-    ######################################################################
+    # @torch.compile
+    def task_puzzle(self, A, f_A, B, f_B):
+        S = 4
+        i0, j0 = (self.height - S) // 2, (self.width - S) // 2
+        c = torch.randperm(len(self.colors) - 1)[:4] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            while True:
+                f_X[...] = 0
+                h = list(torch.randperm(c.size(0)))
+                n = torch.zeros(c.max() + 1)
+                for _ in range(2):
+                    k = torch.randperm(S * S)
+                    for q in k:
+                        i, j = q % S + i0, q // S + j0
+                        if f_X[i, j] == 0:
+                            r, s, t, u = (
+                                f_X[i - 1, j],
+                                f_X[i, j - 1],
+                                f_X[i + 1, j],
+                                f_X[i, j + 1],
+                            )
+                            r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
+                            if r > 0 and n[r] < 6:
+                                n[r] += 1
+                                f_X[i, j] = r
+                            elif s > 0 and n[s] < 6:
+                                n[s] += 1
+                                f_X[i, j] = s
+                            elif t > 0 and n[t] < 6:
+                                n[t] += 1
+                                f_X[i, j] = t
+                            elif u > 0 and n[u] < 6:
+                                n[u] += 1
+                                f_X[i, j] = u
+                            else:
+                                if len(h) > 0:
+                                    d = c[h.pop()]
+                                    n[d] += 1
+                                    f_X[i, j] = d
+
+                if n.sum() == S * S:
+                    break
 
-    def all_tasks(self):
-        return [
-            self.task_replace_color,
-            self.task_translate,
-            self.task_grow,
-            self.task_color_grow,
-            self.task_frame,
-            self.task_detect,
-            self.task_count,
-            self.task_trajectory,
-            self.task_bounce,
-            self.task_scale,
-            self.task_symbols,
-            self.task_ortho,
-            # self.task_islands,
-        ]
+            k = 0
+            for d in range(4):
+                while True:
+                    ii, jj = (
+                        torch.randint(self.height, (1,)).item(),
+                        torch.randint(self.width, (1,)).item(),
+                    )
+                    e = 0
+                    for i in range(S):
+                        for j in range(S):
+                            if (
+                                ii + i >= self.height
+                                or jj + j >= self.width
+                                or (
+                                    f_X[i + i0, j + j0] == c[d]
+                                    and X[ii + i, jj + j] > 0
+                                )
+                            ):
+                                e = 1
+                    if e == 0:
+                        break
+                for i in range(S):
+                    for j in range(S):
+                        if f_X[i + i0, j + j0] == c[d]:
+                            X[ii + i, jj + j] = c[d]
+
+    def task_islands(self, A, f_A, B, f_B):
+        c = torch.randperm(len(self.colors) - 1)[:2] + 1
+        for X, f_X in [(A, f_A), (B, f_B)]:
+            if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
+                self.cache_islands = list(
+                    grow_islands(
+                        1000,
+                        self.height,
+                        self.width,
+                        nb_seeds=self.height * self.width // 20,
+                        nb_iterations=self.height * self.width // 2,
+                    )
+                )
+
+            A = self.cache_islands.pop()
+
+            while True:
+                i, j = (
+                    torch.randint(self.height // 2, (1,)).item(),
+                    torch.randint(self.width // 2, (1,)).item(),
+                )
+                if A[i, j] > 0:
+                    break
+
+            X[...] = (A > 0) * c[0]
+            X[i, j] = c[1]
+            f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+
+    ######################################################################
 
     def trivial_prompts_and_answers(self, prompts, answers):
         S = self.height * self.width
@@ -747,11 +1073,9 @@ class Grids(problem.Problem):
         f_Bs = answers
         return (Bs == f_Bs).long().min(dim=-1).values > 0
 
-    def generate_prompts_and_answers(
-        self, nb, tasks=None, progress_bar=False, device="cpu"
-    ):
+    def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
         if tasks is None:
-            tasks = self.all_tasks()
+            tasks = self.all_tasks
 
         S = self.height * self.width
         prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
@@ -772,12 +1096,12 @@ class Grids(problem.Problem):
             f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
             B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
             f_B = answer.view(self.height, self.width)
-            task = tasks[torch.randint(len(tasks), (1,))]
+            task = tasks[torch.randint(len(tasks), (1,)).item()]
             task(A, f_A, B, f_B)
 
         return prompts.flatten(1), answers.flatten(1)
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -797,12 +1121,22 @@ class Grids(problem.Problem):
             nrow,
         )
 
+    def save_some_examples(self, result_dir):
+        nb, nrow = 72, 4
+        for t in self.all_tasks:
+            print(t.__name__)
+            prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
+            self.save_quiz_illustrations(
+                result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+            )
+
 
 ######################################################################
 
 if __name__ == "__main__":
     import time
 
+    # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
     grids = Grids()
 
     # nb = 1000
@@ -816,22 +1150,26 @@ if __name__ == "__main__":
     # print(f"{prompts.size(0)/delay:02f} seq/s")
     # exit(0)
 
-    if True:
-        nb = 72
+    # if True:
+    nb, nrow = 72, 4
+    # nb, nrow = 8, 2
 
-        for t in grids.all_tasks():
-            # for t in [grids.task_ortho]:
-            print(t.__name__)
-            prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
-            grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+    # for t in grids.all_tasks:
+    for t in [grids.task_distance]:
+        print(t.__name__)
+        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+        grids.save_quiz_illustrations(
+            "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+        )
 
-        exit(0)
+    # exit(0)
 
-    nb = 500
+    nb = 1000
 
-    for t in grids.all_tasks():
+    # for t in grids.all_tasks:
+    for t in [grids.task_distance]:
         start_time = time.perf_counter()
-        prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
+        prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
         delay = time.perf_counter() - start_time
         print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
 
@@ -841,7 +1179,7 @@ if __name__ == "__main__":
     predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
     predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
 
-    grids.save_quizzes(
+    grids.save_quiz_illustrations(
         "/tmp",
         "test",
         prompts[:nb],
diff --git a/main.py b/main.py
index 9c3d7f1..6b00bbf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -15,22 +15,14 @@ import ffutils
 
 import mygpt
 import sky, grids, quiz_machine
-from problem import MultiThreadProblem
 
-# world quizzes vs. culture quizzes
+import threading
 
-######################################################################
-
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
+import torch.multiprocessing as mp
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
@@ -40,7 +32,9 @@ parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--resume", action="store_true", default=False)
+
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
 ########################################
 
@@ -54,6 +48,10 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
+parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+
+parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 ########################################
@@ -78,23 +76,34 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--problem", type=str, default="grids")
 
-parser.add_argument("--multi_thread_problem", action="store_true", default=False)
+parser.add_argument("--nb_threads", type=int, default=1)
+
+parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
-parser.add_argument("--min_to_validate", type=int, default=None)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+
+parser.add_argument("--proba_understands", type=float, default=0.9)
 
-parser.add_argument("--max_to_validate", type=int, default=None)
+parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+parser.add_argument("--generation_temperature", type=float, default=1.0)
 
-parser.add_argument("--generation_temperature", type=float, default=2.0)
+parser.add_argument("--dirty_debug", action="store_true", default=False)
 
-parser.add_argument("--deterministic_validation", action="store_true", default=False)
+######################################################################
 
-parser.add_argument("--bidirectional_validation", action="store_true", default=False)
+grids_tasks = ", ".join(
+    [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
 
-parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument(
+    "--grids_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
 
 ######################################################################
 
@@ -112,12 +121,6 @@ parser.add_argument("--sky_speed", type=int, default=3)
 
 args = parser.parse_args()
 
-if args.min_to_validate is None:
-    args.min_to_validate = args.nb_gpts - 1
-
-if args.max_to_validate is None:
-    args.max_to_validate = args.nb_gpts - 1
-
 if args.result_dir is None:
     args.result_dir = f"results_culture"
 
@@ -125,7 +128,7 @@ if args.result_dir is None:
 
 default_args = {
     "model": "37M",
-    "batch_size": 100,
+    "batch_size": 25,
     "nb_train_samples": 100000,
     "nb_test_samples": 10000,
 }
@@ -183,11 +186,15 @@ else:
 
 ######################################################################
 
-try:
-    os.mkdir(args.result_dir)
-except FileExistsError:
-    print(f"result directory {args.result_dir} already exists")
-    exit(1)
+if args.resume:
+    assert os.path.isdir(args.result_dir)
+
+else:
+    try:
+        os.mkdir(args.result_dir)
+    except FileExistsError:
+        print(f"result directory {args.result_dir} already exists")
+        exit(1)
 
 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
@@ -213,6 +220,10 @@ def log_string(s):
     sys.stdout.flush()
 
 
+now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
+
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py")
+
 log_string(f"argv {' '.join(sys.argv)}")
 
 for n in vars(args):
@@ -221,6 +232,19 @@ for n in vars(args):
 
 ######################################################################
 
+if args.gpus == "all":
+    gpus_idx = range(torch.cuda.device_count())
+else:
+    gpus_idx = [int(k) for k in args.gpus.split(",")]
+
+gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+
+if torch.cuda.is_available():
+    main_device = gpus[0]
+else:
+    assert len(gpus) == 0
+    main_device = torch.device("cpu")
+
 if args.dirty_debug:
     args.nb_train_samples = 2500
     args.nb_test_samples = 100
@@ -240,16 +264,23 @@ if args.problem == "sky":
         nb_birds=args.sky_nb_birds,
         nb_iterations=args.sky_nb_iterations,
         speed=args.sky_speed,
+        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+        chunk_size=100,
+        nb_threads=args.nb_threads,
     )
     back_accuracy = False
 elif args.problem == "grids":
-    problem = grids.Grids(device=device)
+    problem = grids.Grids(
+        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+        chunk_size=100,
+        nb_threads=args.nb_threads,
+        tasks=args.grids_tasks,
+    )
     back_accuracy = True
 else:
     raise ValueError
 
-if args.multi_thread_problem:
-    problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000)
+problem.save_some_examples(args.result_dir)
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
@@ -259,12 +290,12 @@ quiz_machine = quiz_machine.QuizMachine(
     batch_size=args.physical_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
-    device=device,
+    device=main_device,
 )
 
 ######################################################################
 
-log_string(f"device {device}")
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
 
 vocabulary_size = quiz_machine.vocabulary_size()
 
@@ -272,64 +303,47 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
-# Compute the entropy of the training tokens
 
-token_count = 0
-for input in quiz_machine.batches(split="train", desc="train-entropy"):
-    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
-        (0, 1)
-    )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
+def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
+    with torch.autograd.no_grad():
+        model.eval().to(local_device)
 
-######################################################################
-# A bit of paranoia never hurts
+        nb_test_samples, acc_test_loss = 0, 0.0
+        nb_samples_accumulated = 0
 
-if args.max_percents_of_test_in_train >= 0:
+        for input in quiz_machine.batches(model, split="test"):
+            input = input.to(local_device)
 
-    def subsets_as_tuples(batches, cs):
-        s = set()
-        for batch in batches:
-            for x in batch:
-                s.add(tuple([v.item() for v in x]))
-                if len(s) == cs:
-                    yield s
-                    s = set()
-        yield s
+            bs = model(mygpt.BracketedSequence(input))
+            output = bs.x
 
-    nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(
-        quiz_machine.batches(split="test", desc="test-check"), 25000
-    ):
-        in_train = set()
-        for train_subset in subsets_as_tuples(
-            quiz_machine.batches(split="train", desc="train-check"), 25000
-        ):
-            in_train.update(test_subset.intersection(train_subset))
-        nb_in_train += len(in_train)
-        nb_test += len(test_subset)
+            loss = F.cross_entropy(output.transpose(1, 2), input)
 
-    log_string(
-        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
-    )
+            acc_test_loss += loss.item() * input.size(0)
 
-    assert (
-        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
-    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+            nb_test_samples += input.size(0)
 
-##############################
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
 
+        log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-def one_epoch(model, quiz_machine):
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+        model.main_test_accuracy = quiz_machine.produce_results(
+            n_epoch=n_epoch,
+            model=model,
+            result_dir=args.result_dir,
+            deterministic_synthesis=deterministic_synthesis,
+        )
 
-    model.train()
+
+def one_epoch(model, quiz_machine, local_device=main_device):
+    model.to(local_device).train()
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in quiz_machine.batches(split="train"):
-        input = input.to(device)
+    for input in quiz_machine.batches(model, split="train"):
+        input = input.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
             optimizer.zero_grad()
@@ -347,148 +361,110 @@ def one_epoch(model, quiz_machine):
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
-    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+    log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
 
+    run_tests(model, quiz_machine, deterministic_synthesis=False)
 
-######################################################################
+    model.to(main_device)
 
 
-def run_tests(model, quiz_machine, deterministic_synthesis):
-    with torch.autograd.no_grad():
-        model.eval()
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-        nb_samples_accumulated = 0
-
-        for input in quiz_machine.batches(split="test"):
-            input = input.to(device)
-
-            bs = model(mygpt.BracketedSequence(input))
-            output = bs.x
+######################################################################
 
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+# This is the key routine that decides what generated quizzes to keep
 
-            acc_test_loss += loss.item() * input.size(0)
 
-            nb_test_samples += input.size(0)
+# token_logprobas are NxMxT where M is the number of models
 
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
 
-        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+def compute_valid_quizzes_(token_logprobas):
+    warnings.warn("validation with uniform constraints", RuntimeWarning)
+    l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
+    return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
 
-        model.main_test_accuracy = quiz_machine.produce_results(
-            n_epoch=n_epoch,
-            model=model,
-            result_dir=args.result_dir,
-            deterministic_synthesis=deterministic_synthesis,
-        )
 
+def compute_valid_quizzes(token_logprobas):
+    l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+    return (l[:, 0] < math.log(args.proba_not_understands)) & (
+        l[:, 1] > math.log(args.proba_understands)
+    )
 
-######################################################################
 
+def extract_valid_quizzes_and_logprobas(recorded):
+    validated_quizzes, validated_logprobas = [], []
+    for quizzes, token_logprobas in recorded:
+        validated_indices = compute_valid_quizzes(token_logprobas)
+        validated_quizzes.append(quizzes[validated_indices])
+        validated_logprobas.append(token_logprobas[validated_indices])
 
-def valid_c_quizzes(recorded, criteria):
-    result = [q[criteria(c)] for q, c in recorded]
-    return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
+    if len(validated_quizzes) > 0:
+        return torch.cat(validated_quizzes, dim=0), torch.cat(
+            validated_logprobas, dim=0
+        )
+    else:
+        return None, None
 
 
 ######################################################################
 
 
-def create_c_quizzes(
-    models,
-    quiz_machine,
-    nb_for_train=1000,
-    nb_for_test=100,
-):
-    quizzes_and_nb_correct_records = []
-
+def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
     nb_to_create = nb_for_train + nb_for_test
 
-    # ------------------------------------------------------------
+    recorded_quizzes_logprobas = []
 
-    standard_validity = lambda nb_correct: torch.logical_and(
-        nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
-    )
-
-    file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+    nb_validated = 0
 
-    with open(file_name, "w") as logp_file:
-        while (
-            valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
-            < nb_to_create
-        ):
-            # Select a model at random to generate the new quizzes
+    while nb_validated < nb_to_create:
+        model_for_generation = models[torch.randint(len(models), (1,))]
 
-            model_for_generation = models[torch.randint(len(models), (1,))]
-
-            c_quizzes = quiz_machine.generate_quizzes(
-                nb_to_create,
-                model_for_generation=model_for_generation,
-                temperature=args.generation_temperature,
-            )
-
-            c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
-            if c_quizzes.size(0) > 0:
-                nb_correct, seq_logproba = quiz_machine.compute_correctness(
-                    c_quizzes,
-                    models,
-                    bidirectional_validation=args.bidirectional_validation,
-                    deterministic_validation=args.deterministic_validation,
-                )
-
-                for n, l in zip(nb_correct, seq_logproba):
-                    s = " ".join([str(x.item()) for x in l])
-                    logp_file.write(f"{n} {s}\n")
+        c_quizzes = quiz_machine.generate_quizzes(
+            nb_to_create,
+            model_for_generation=model_for_generation,
+            temperature=args.generation_temperature,
+        )
 
-                if args.dirty_debug:
-                    nb_correct = torch.randint(
-                        len(models) + 1, nb_correct.size(), device=c_quizzes.device
-                    )
+        c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
-                quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
+        if c_quizzes.size(0) > 0:
+            token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
+            recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
 
-            nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
-            nv = " ".join([str(x.item()) for x in nv])
+            (
+                validated_quizzes,
+                validated_logprobas,
+            ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
 
-            nb_validated = valid_c_quizzes(
-                quizzes_and_nb_correct_records, standard_validity
-            ).size(0)
+            if validated_quizzes is not None:
+                nb_validated = validated_quizzes.size(0)
 
-            log_string(
-                f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
-            )
+        log_string(
+            f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+        )
 
     # store the new c_quizzes which have been validated
 
-    new_c_quizzes = valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity)
-
-    quiz_machine.reverse_random_half_in_place(new_c_quizzes)
-
-    quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
-    quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+    quiz_machine.reverse_random_half_in_place(validated_quizzes)
+    quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
+    quiz_machine.store_c_quizzes(
+        validated_quizzes[nb_for_train:nb_to_create], for_train=False
+    )
 
-    # save a bunch of images to investigate what quizzes with a
-    # certain nb of correct predictions look like
+    ######################################################################
+    # save images with their logprobas
 
-    for n in range(len(models) + 1):
-        s = (
-            "_validated"
-            if n >= args.min_to_validate and n <= args.max_to_validate
-            else ""
-        )
+    vq = validated_quizzes[:72]
+    vl = validated_logprobas[:72]
 
-        q = valid_c_quizzes(
-            quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n
-        )[:72]
+    if vq.size(0) > 0:
+        prefix = f"culture_c_quiz_{n_epoch:04d}"
+        filename = os.path.join(args.result_dir, prefix + "_logp.pth")
+        torch.save(vl, filename)
+        # with open(file_name, "w") as logp_file:
+        # for l in vl:
+        # s = " ".join([str(x.item()) for x in l])
+        # logp_file.write(s + "\n")
 
-        quiz_machine.reverse_random_half_in_place(q)
-
-        if q.size(0) > 0:
-            quiz_machine.save_quizzes(
-                args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q
-            )
+        quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
 
 
 ######################################################################
@@ -496,6 +472,7 @@ def create_c_quizzes(
 models = []
 
 for k in range(args.nb_gpts):
+    log_string(f"creating model {k} and its w_quizzes")
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -505,24 +482,109 @@ for k in range(args.nb_gpts):
         nb_blocks=args.nb_blocks,
         causal=True,
         dropout=args.dropout,
-    ).to(device)
+    ).to(main_device)
 
     model.main_test_accuracy = 0.0
     model.id = k
 
+    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
+    quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
+    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
+    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+
     models.append(model)
 
+######################################################################
+
+if args.resume:
+    try:
+        for model in models:
+            filename = f"gpt_{model.id:03d}.pth"
+
+            try:
+                d = torch.load(os.path.join(args.result_dir, filename))
+                model.load_state_dict(d[0])
+                model.main_test_accuracy = d[1]
+                log_string(f"successfully loaded {filename}")
+            except FileNotFoundError:
+                log_string(f"cannot find {filename}")
+                pass
+
+        try:
+            filename = "c_quizzes.pth"
+            quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+            log_string(f"successfully loaded {filename}")
+        except FileNotFoundError:
+            log_string(f"cannot find {filename}")
+            pass
+
+    except:
+        log_string(f"error when loading {filename}.")
+        exit(1)
+
+######################################################################
 
 nb_parameters = sum(p.numel() for p in models[0].parameters())
 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+# Compute the entropy of the training tokens
+
+token_count = 0
+for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
+    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
+        (0, 1)
+    )
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+######################################################################
+# A bit of paranoia never hurts
+
+if args.max_percents_of_test_in_train >= 0:
+
+    def subsets_as_tuples(batches, cs):
+        s = set()
+        for batch in batches:
+            for x in batch:
+                s.add(tuple([v.item() for v in x]))
+                if len(s) == cs:
+                    yield s
+                    s = set()
+        yield s
+
+    nb_test, nb_in_train = 0, 0
+    for test_subset in subsets_as_tuples(
+        quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
+    ):
+        in_train = set()
+        for train_subset in subsets_as_tuples(
+            quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
+        ):
+            in_train.update(test_subset.intersection(train_subset))
+        nb_in_train += len(in_train)
+        nb_test += len(test_subset)
+
+    log_string(
+        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
+    )
+
+    assert (
+        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+######################################################################
+
+if args.nb_new_c_quizzes_for_train is None:
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+
+if args.nb_new_c_quizzes_for_test is None:
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
 
 log_string(
-    f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+    f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
 )
 
 ######################################################################
@@ -530,8 +592,9 @@ log_string(
 if args.dirty_debug:
     args.accuracy_to_make_c_quizzes = 0.0
     args.nb_gpts = 2
-    nb_new_c_quizzes_for_train = 100
-    nb_new_c_quizzes_for_test = 10
+    args.nb_new_c_quizzes_for_train = 100
+    args.nb_new_c_quizzes_for_test = 10
+
 
 ######################################################################
 
@@ -541,46 +604,63 @@ for n_epoch in range(args.nb_epochs):
     cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
+    ##################################################
+    # If all the models are good enough, generate new quizzes and
+    # re-compute the test errors
+
+    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+        create_c_quizzes(
+            models,
+            quiz_machine,
+            nb_for_train=args.nb_new_c_quizzes_for_train,
+            nb_for_test=args.nb_new_c_quizzes_for_test,
+        )
+
+        filename = "c_quizzes.pth"
+        quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
+        log_string(f"wrote {filename}")
+
+        # Force one epoch of training
+        for model in models:
+            model.main_test_accuracy = 0.0
+
     ##################################################
     # Select, improve, and eval the worst model
 
-    weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
+    ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
 
-    log_string(
-        f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
-    )
+    weakest_models = ranked_models[: len(gpus)]
 
-    one_epoch(weakest_model, quiz_machine)
+    threads = []
 
-    log_string(
-        f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
-    )
+    for gpu, model in zip(gpus, weakest_models):
+        log_string(f"training model {model.id}")
 
-    run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+        )
 
-    log_string(
-        f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
-    )
+        threads.append(t)
 
-    ##################################################
-    # Replace a fraction of the w_quizzes with fresh ones
+        t.start()
 
-    quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+    for t in threads:
+        t.join()
 
-    ##################################################
-    # If all the models are good enough, generate new quizzes and
-    # re-compute the test errors
+    # Save the models to disk
 
-    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        create_c_quizzes(
-            models,
-            quiz_machine,
-            nb_for_train=nb_new_c_quizzes_for_train,
-            nb_for_test=nb_new_c_quizzes_for_test,
+    for model in weakest_models:
+        filename = f"gpt_{model.id:03d}.pth"
+        torch.save(
+            (model.state_dict(), model.main_test_accuracy),
+            os.path.join(args.result_dir, filename),
         )
+        log_string(f"wrote {filename}")
 
-        for model in models:
-            run_tests(model, quiz_machine, deterministic_synthesis=False)
+    # Renew the training samples
+
+    for model in weakest_models:
+        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
 
 
 ######################################################################
index a49634d..05f3b20 100755 (executable)
@@ -9,18 +9,34 @@ import threading, queue, torch, tqdm
 
 
 class Problem:
+    def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+        if nb_threads > 0:
+            self.chunk_size = chunk_size
+            self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+            for _ in range(nb_threads):
+                threading.Thread(target=self.fill_cache, daemon=True).start()
+            self.rest = None
+        else:
+            self.queue = None
+
+    def nb_cached_quizzes(self):
+        if self.queue is None:
+            return None
+        else:
+            return self.queue.qsize() * self.chunk_size
+
     def nb_token_values(self):
         pass
 
     def trivial_prompts_and_answers(self, prompts, answers):
         pass
 
-    # returns two tensors nb x D and nb x D'
-    def generate_prompts_and_answers(self, nb):
+    # The one to implement, returns two tensors nb x D and nb x D'
+    def generate_prompts_and_answers_(self, nb):
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -31,49 +47,16 @@ class Problem:
     ):
         pass
 
-
-class MultiThreadProblem:
-    def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
-        self.problem = problem
-        self.chunk_size = chunk_size
-        self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
-        for _ in range(nb_threads):
-            threading.Thread(target=self.fill_cache, daemon=True).start()
-        self.rest = None
-
-    def nb_token_values(self):
-        return self.problem.nb_token_values()
-
-    def save_quizzes(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        self.problem.save_quizzes(
-            result_dir,
-            filename_prefix,
-            prompts,
-            answers,
-            predicted_prompts=None,
-            predicted_answers=None,
-        )
-
     def fill_cache(self):
         while True:
-            prompts, answers = self.problem.generate_prompts_and_answers(
-                self.chunk_size
-            )
+            prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
 
             self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
 
-    def trivial_prompts_and_answers(self, prompts, answers):
-        return self.problem.trivial_prompts_and_answers(prompts, answers)
-
     def generate_prompts_and_answers(self, nb):
+        if self.queue is None:
+            return self.generate_prompts_and_answers_(nb)
+
         if self.rest is not None:
             prompts, answers = rest
         else:
@@ -105,3 +88,6 @@ class MultiThreadProblem:
             prompts, answers = prompts[:-k], answers[:-k]
 
         return prompts, answers
+
+    def save_some_examples(self, result_dir):
+        pass
index f0fb408..bc468d3 100755 (executable)
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
 
 import torch, torchvision
 
@@ -15,6 +15,38 @@ from torch.nn import functional as F
 import mygpt
 from mygpt import BracketedSequence
 
+import threading
+
+######################################################################
+# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
+# | X != Y)
+
+
+# output is NxCxT and target is NxT
+def confusion(output, target, reduction="mean"):
+    N, C, T = output.shape
+    output = output.permute(0, 2, 1).reshape(-1, C)
+    target = target.flatten()
+    all_t = torch.arange(N * T, device=output.device)
+    output = output.log_softmax(dim=-1)
+    result = -output[all_t, target]
+
+    output[all_t, target] = float("-inf")
+    output = output.log_softmax(dim=-1)
+    e = output.exp()
+    output[all_t, target] = 0
+    result = result - (output * e).sum(-1)
+
+    if reduction == "none":
+        return result.reshape(N, T)
+    elif reduction == "mean":
+        return result.reshape(N, T).mean()
+    elif reduction == "sum":
+        return result.reshape(N, T).sum()
+    else:
+        raise ValueError(f"unknown reduction '{reduction}'.")
+
+
 ######################################################################
 
 # ar_mask is a tensor with 0s and 1s, of same shape as input, with
@@ -235,32 +267,18 @@ class QuizMachine:
         self.prompt_len = None
         self.answer_len = None
 
-        self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
-        self.reverse_random_half_in_place(self.train_w_quizzes)
-        self.train_w_quizzes = self.train_w_quizzes.to(device)
-
-        self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
-        self.reverse_random_half_in_place(self.test_w_quizzes)
-        self.test_w_quizzes = self.test_w_quizzes.to(device)
-
+        self.LOCK_C_QUIZZES = threading.Lock()
         self.train_c_quizzes = []
         self.test_c_quizzes = []
 
-        if result_dir is not None:
-            self.save_quizzes(
-                result_dir,
-                "culture_w_quizzes",
-                self.train_w_quizzes[:72],
-            )
-
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
         quizzes,
         mistakes=None,
     ):
-        quizzes = quizzes.clone()
+        quizzes = quizzes.clone().to("cpu")
         n_forward = quizzes[quizzes[:, 0] == self.token_forward]
         n_backward = quizzes[:, 0] == self.token_backward
         backward = quizzes[n_backward]
@@ -271,14 +289,14 @@ class QuizMachine:
         predicted_answers = 1 - predicted_prompts
         if mistakes is not None:
             # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
-            predicted_prompts *= mistakes
-            predicted_answers *= mistakes
+            predicted_prompts *= mistakes.to("cpu")
+            predicted_answers *= mistakes.to("cpu")
         else:
             # 0/2 ~ not-to-predict / to predict
             predicted_prompts *= 2
             predicted_answers *= 2
 
-        self.problem.save_quizzes(
+        self.problem.save_quiz_illustrations(
             result_dir,
             filename_prefix,
             quizzes[:, 1 : 1 + self.prompt_len],
@@ -287,34 +305,41 @@ class QuizMachine:
             predicted_answers,
         )
 
-    def batches(self, split="train", desc=None):
-        assert split in {"train", "test"}
-        if split == "train":
-            w_quizzes = self.train_w_quizzes
-            c_quizzes = self.train_c_quizzes
-        else:
-            w_quizzes = self.test_w_quizzes
-            c_quizzes = self.test_c_quizzes
+    def vocabulary_size(self):
+        return self.nb_token_values
 
-        if len(c_quizzes) > 0:
-            c_quizzes = torch.cat(c_quizzes, dim=0)
-            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                c_quizzes = c_quizzes[i]
+    ######################################################################
 
-            i = torch.randperm(w_quizzes.size(0))[
-                : w_quizzes.size(0) - c_quizzes.size(0)
-            ]
-            w_quizzes = w_quizzes[i]
+    def batches(self, model, split="train", desc=None):
+        assert split in {"train", "test"}
 
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = c_quizzes.size(0)
+        with self.LOCK_C_QUIZZES:
+            if split == "train":
+                w_quizzes = model.train_w_quizzes
+                c_quizzes = self.train_c_quizzes
+            else:
+                w_quizzes = model.test_w_quizzes
+                c_quizzes = self.test_c_quizzes
+
+            if len(c_quizzes) > 0:
+                c_quizzes = torch.cat(c_quizzes, dim=0)
+                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                    c_quizzes = c_quizzes[i]
+
+                i = torch.randperm(w_quizzes.size(0))[
+                    : w_quizzes.size(0) - c_quizzes.size(0)
+                ]
+                w_quizzes = w_quizzes[i]
 
-            input = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            input = w_quizzes
-            self.nb_batch_w_quizzes = w_quizzes.size(0)
-            self.nb_batch_c_quizzes = 0
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = c_quizzes.size(0)
+
+                input = torch.cat([w_quizzes, c_quizzes], dim=0)
+            else:
+                input = w_quizzes
+                self.nb_batch_w_quizzes = w_quizzes.size(0)
+                self.nb_batch_c_quizzes = 0
 
         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -326,13 +351,13 @@ class QuizMachine:
         ):
             yield batch
 
-    def vocabulary_size(self):
-        return self.nb_token_values
+    ######################################################################
 
     def produce_results(
         self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
     ):
         def compute_accuracy(input, log_prefix=None):
+            input = input.to(self.device)
             ar_mask = self.make_ar_mask(input)
             result = input.clone() * (1 - ar_mask)
             seq_logproba = torch.empty(input.size(0), device=self.device)
@@ -373,19 +398,15 @@ class QuizMachine:
                 backward_nb_total = correct[n_backward].size(0)
 
                 self.logger(
-                    f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
-                )
-
-                self.logger(
-                    f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+                    f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
                 )
 
             return result, correct
 
-        compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+        # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
 
         test_result, test_correct = compute_accuracy(
-            self.test_w_quizzes[:nmax], log_prefix="test"
+            model.test_w_quizzes[:nmax], log_prefix="test"
         )
 
         main_test_accuracy = test_correct.sum() / test_correct.size(0)
@@ -393,7 +414,7 @@ class QuizMachine:
 
         ##############################
 
-        self.save_quizzes(
+        self.save_quiz_illustrations(
             result_dir,
             f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             quizzes=test_result[:72],
@@ -402,19 +423,65 @@ class QuizMachine:
 
         return main_test_accuracy
 
-    def renew_w_quizzes(self, nb, for_train=True):
-        input = self.train_w_quizzes if for_train else self.test_w_quizzes
+    ######################################################################
+
+    def renew_w_quizzes(self, model, nb, for_train=True):
+        input = model.train_w_quizzes if for_train else model.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         fresh_w_quizzes = self.generate_token_sequences(nb)
         self.reverse_random_half_in_place(fresh_w_quizzes)
-        input[-nb:] = fresh_w_quizzes.to(self.device)
+        input[-nb:] = fresh_w_quizzes.to("cpu")
+
+    ######################################################################
 
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        if for_train:
-            self.train_c_quizzes.append(new_c_quizzes)
-        else:
-            self.test_c_quizzes.append(new_c_quizzes)
+        with self.LOCK_C_QUIZZES:
+            if for_train:
+                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+            else:
+                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
+
+    def save_c_quizzes(self, filename):
+        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+    def load_c_quizzes(self, filename):
+        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
+    ######################################################################
+
+    def solution_token_logprobas(self, models, c_quizzes):
+        logproba = c_quizzes.new_zeros(
+            c_quizzes.size(0),
+            len(models),
+            c_quizzes.size(1),
+            device=self.device,
+            dtype=torch.float32,
+        )
+
+        for model in models:
+            with torch.autograd.no_grad():
+                t = model.training
+                model.eval()
+
+                for input, l in zip(
+                    c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+                ):
+                    input = input.to(self.device)
+                    ar_mask = self.make_ar_mask(input)
+                    output = model(mygpt.BracketedSequence(input)).x
+                    l[:, model.id] = (
+                        -F.cross_entropy(
+                            output.transpose(1, 2), input, reduction="none"
+                        )
+                        * ar_mask
+                    )
+
+                model.train(t)
+
+        return logproba.to("cpu")
+
+    ###############################################################
 
     def compute_correctness(
         self,
@@ -488,7 +555,10 @@ class QuizMachine:
 
     def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
         c_quizzes = torch.empty(
-            nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+            nb,
+            self.prompt_len + self.answer_len + 2,
+            device=self.device,
+            dtype=torch.int64,
         )
 
         seq_logproba = torch.zeros(nb, device=self.device)
@@ -538,4 +608,4 @@ class QuizMachine:
             device=self.device,
         )
 
-        return c_quizzes
+        return c_quizzes.to("cpu")
diff --git a/sky.py b/sky.py
index ed440d3..cc5bd4f 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -50,7 +50,11 @@ class Sky(problem.Problem):
         speed=2,
         nb_iterations=2,
         avoid_collision=True,
+        max_nb_cached_chunks=None,
+        chunk_size=None,
+        nb_threads=-1,
     ):
+        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
         self.height = height
         self.width = width
         self.nb_birds = nb_birds
@@ -296,7 +300,7 @@ class Sky(problem.Problem):
 
         return prompts, answers
 
-    def save_quizzes(
+    def save_quiz_illustrations(
         self,
         result_dir,
         filename_prefix,
@@ -327,7 +331,7 @@ if __name__ == "__main__":
     predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
     predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
-    sky.save_quizzes(
+    sky.save_quiz_illustrations(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
     )