Update. diffusion
author     François Fleuret <francois@fleuret.org>
           Thu, 19 Sep 2024 11:20:34 +0000 (13:20 +0200)
committer  François Fleuret <francois@fleuret.org>
           Thu, 19 Sep 2024 11:20:34 +0000 (13:20 +0200)
grids.py
main.py
mygpt.py [deleted file]
sky.py [deleted file]
wireworld.py [deleted file]

diff --git a/grids.py b/grids.py
index 4254b32..5e623cb 100755
--- a/grids.py
+++ b/grids.py
@@ -384,6 +384,9 @@ class Grids(problem.Problem):
 
     ######################################################################
 
+    def vocabulary_size(self):
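+        # Accessor for the vocabulary size, so that callers (e.g.
+        # main.py below) do not read the nb_token_values attribute
+        # directly.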
+        return self.nb_token_values
+
     def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
         y = self.colors[x * m].permute(0, 3, 1, 2)
diff --git a/main.py b/main.py
index 0c40f95..ef340ea 100755
--- a/main.py
+++ b/main.py
@@ -11,10 +11,7 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
-import ffutils
-
-import mygpt
-import sky, grids
+import ffutils, grids, attae
 
 import threading, subprocess
 
@@ -313,7 +310,7 @@ def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 
 log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
 
-vocabulary_size = problem.nb_token_values
+vocabulary_size = problem.vocabulary_size()
 
 log_string(f"vocabulary_size {vocabulary_size}")
 
@@ -640,8 +637,6 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
 ######################################################################
 
-import attae
-
 models = []
 
 for i in range(args.nb_models):
diff --git a/mygpt.py b/mygpt.py
deleted file mode 100755
index 5b56264..0000000
--- a/mygpt.py
+++ /dev/null
@@ -1,475 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-# This is an implementation from scratch of a "GPT", that is, a model
-# composed of several causal self-attention blocks. It is equipped
-# with a caching mechanism for keys and values to avoid an O(N^3) cost
-# for auto-regression.
-
-import math
-
-import torch
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-
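-# Binary spherical quantizer: normalizes its input onto the unit
-# sphere and binarizes it to +-1/sqrt(L) components, training through
-# the quantization with a straight-through estimator (hat_u + u -
-# u.detach()). With indexes=True it returns integer codes instead.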
-class BSQ(nn.Module):
-    def __init__(self, L):
-        super().__init__()
-        self.L = L
-
-    def forward(self, input, indexes=False):
-        norm = input.pow(2).sum(dim=2, keepdim=True).sqrt()
-        u = input / norm
-
-        if indexes:
-            return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1)
-
-        hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1)
-        if self.training:
-            self.loss += u.mean(dim=0).tanh().pow(2).mean()
-            return hat_u + u - u.detach()
-        else:
-            return hat_u
-
-
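-# At train time, with probability p and independently for each sample
-# of the batch, replaces the output of the wrapped module m with the
-# input itself; at test time m is always applied.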
-class RandomBypass(nn.Module):
-    def __init__(self, m, p):
-        super().__init__()
-        self.m = m
-        self.p = p
-
-    def forward(self, x):
-        y = self.m(x)
-
-        if self.training:
-            u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None]
-            return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size())
-        else:
-            return y
-
-
-######################################################################
-
-# A BracketedSequence is a BxTx... tensor together with the index of a
-# first time step and a number of time steps to compute.
-
-# Modules able to process it expect that they will have to process a
-# first bracket starting at t=0, followed by a succession of brackets
-# that move forward in time, do not overlap, and cover the axis T with
-# no holes.
-#
-# Although the mechanism is more general, for a classical
-# prompt-conditioned auto-regressive process this will be a first
-# bracket of arbitrary length starting at 0 for the "prompt", followed
-# by brackets of length 1 for the successive tokens.
-#
-# Modules able to process brackets may implement a cache that is
-# reset when the input bracket starts at t=0.
-
-
-class BracketedSequence:
-    def __init__(self, x, first=None, nb=None):
-        self.x = x
-        self.first = 0 if first is None else first
-        self.nb = x.size(1) if nb is None else nb
-
-    def slice(self):
-        return self.x[:, self.first : self.first + self.nb]
-
-    def complete(self):
-        return self.first == 0 and self.nb == self.x.size(1)
-
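-# For instance, generating three tokens after a prompt of length 5
-# amounts to processing BracketedSequence(x, 0, 5), then
-# BracketedSequence(x, 5, 1), BracketedSequence(x, 6, 1), and
-# BracketedSequence(x, 7, 1).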
-
-######################################################################
-
-
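-# Applies the wrapped module to the current bracket only, and caches
-# the results so that the values computed for previous brackets remain
-# available in the returned BracketedSequence.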
-class CacheWrapper(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        if bs.first == 0:
-            y = self.f(bs.slice())
-            self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:]))
-            self.cache_y[:, bs.first : bs.first + bs.nb] = y
-        else:
-            self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
-
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class CachedWithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
-
-
-##############################
-
-
-class CachedVaswaniPositionalEncoding(nn.Module):
-    def __init__(self, len_max):
-        super().__init__()
-        self.len_max = len_max
-
-    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
-    def forward(self, bs):
-        if bs.first == 0:
-            t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[
-                :, None
-            ]
-            j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[
-                None, :
-            ]
-            k = j % 2
-            self.pe = torch.sin(
-                t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k
-            )
-            self.cache_y = bs.x.new(bs.x.size())
-
-        self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[bs.first : bs.first + bs.nb]
-        )
-
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class TrainablePositionalEncoding(nn.Module):
-    def __init__(self, dim, len_max):
-        super().__init__()
-        self.len_max = len_max
-        self.pe = nn.Parameter(torch.randn(1, len_max, dim) / math.sqrt(dim))
-
-    def forward(self, bs):
-        if bs.first == 0:
-            self.cache_y = bs.x.new(bs.x.size())
-
-        self.cache_y[:, bs.first : bs.first + bs.nb] = (
-            bs.slice() + self.pe[:, bs.first : bs.first + bs.nb, :]
-        )
-
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
-
-
-##############################
-
-
-class EncoderHead(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.fc = nn.Linear(dim_in, dim_out)
-
-    def forward(self, bs):
-        z = self.fc(bs.x).mean(dim=1)
-        return z, bs.x.shape
-
-
-class DecoderBottom(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.fc = nn.Linear(dim_in, dim_out)
-
-    def forward(self, z_shape):
-        z, shape = z_shape
-        y = self.fc(z)[:, None, :].expand(shape)
-        return BracketedSequence(y)
-
-
-##############################
-
-
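-# Multi-head attention with a cache for keys and values. If bs_kv is
-# None the module performs self-attention, and compute_attzero(t_q,
-# t_k), when provided, returns True for the (query, key) pairs whose
-# attention must be masked out (e.g. a causal mask).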
-class QKVAttention(nn.Module):
-    def __init__(
-        self,
-        dim_in,
-        dim_qk,
-        dim_v,
-        nb_heads=1,
-        compute_attzero=None,
-        attention_dropout=0.0,
-    ):
-        super().__init__()
-
-        def randw(*d):
-            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
-
-        self.compute_attzero = compute_attzero
-        self.attention_dropout = attention_dropout
-        self.record_attention = False
-
-        self.w_q = randw(nb_heads, dim_qk, dim_in)
-        self.w_k = randw(nb_heads, dim_qk, dim_in)
-        self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(dim_v * nb_heads, dim_in)
-
-    def forward(self, bs_q, bs_kv=None):
-        if bs_kv is None:
-            bs_kv = bs_q
-
-        x_q = bs_q.x
-        x_kv = bs_kv.x
-
-        if bs_kv.first == 0:
-            self.cache_k = x_kv.new_zeros(
-                x_kv.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1)
-            )
-            self.cache_v = x_kv.new_zeros(
-                x_kv.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1)
-            )
-
-        if bs_q.first == 0:
-            self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
-
-        q = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
-        )
-
-        self.cache_k[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_k
-        )
-        self.cache_v[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_v
-        )
-
-        a = torch.einsum(
-            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
-        ) / math.sqrt(self.w_q.size(1))
-
-        if self.compute_attzero is not None:
-            if bs_q.first == 0:
-                self.cache_attzero = self.compute_attzero(
-                    torch.arange(x_q.size(1), device=q.device)[:, None],
-                    torch.arange(x_kv.size(1), device=q.device)[None, :],
-                )[None, None, :, :]
-            a = a.masked_fill(
-                self.cache_attzero[
-                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
-                ],
-                float("-inf"),
-            )
-
-        a = a.softmax(dim=3)
-
-        if self.record_attention:
-            self.a = a
-
-        a = F.dropout(a, self.attention_dropout, self.training)
-
-        y = torch.einsum(
-            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_kv.first + bs_kv.nb]
-        ).flatten(2)
-
-        self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
-
-        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
-
-
-##############################
-
-
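-# Multiplicative "noise": randomly flips the sign of each component of
-# its input with probability noise_std (a no-op when noise_std is 0).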
-class NoiseInjector(nn.Module):
-    def __init__(self, identifier=None):
-        super().__init__()
-        self.noise_std = 0.0
-        self.identifier = identifier
-
-    def forward(self, x):
-        if self.noise_std > 0:
-            x = x * (
-                1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long()
-            )
-        return x
-
-
-##############################
-
-
-class BlockSummarizer(nn.Module):
-    # Unfinished module (forward is not implemented): the intent was to
-    # summarize each block of tokens with nb_tokens static queries.
-    def __init__(
-        self, nb_blocks, nb_tokens, dim_keys, dim_model, nb_heads=1, dropout=0.0
-    ):
-        super().__init__()
-        self.nb_blocks = nb_blocks
-        self.static_q = nn.Parameter(torch.randn(nb_blocks - 1, nb_tokens, dim_keys))
-
-        def compute_block_attzero(t_q, t_k):
-            block_size = t_q.size(0)
-            return (t_q // block_size) <= (t_k // block_size)
-
-        self.qkv = QKVAttention(
-            dim_in=dim_model,
-            dim_qk=dim_keys,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            compute_attzero=compute_block_attzero,
-            attention_dropout=dropout,
-        )
-
-    def forward(self, bs):
-        pass
-
-
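-# Shifts the token sequence one step to the right, padding with token
-# 0, so that the prediction at time t depends only on the tokens at
-# times strictly before t.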
-class ShiftByOne(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, bs):
-        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-
-
-class MyGPT(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        compute_attzero=None,
-        dropout=0.0,
-        len_max=1e5,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.temperature = 1.0
-
-        self.shifter = ShiftByOne()
-
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-        )
-
-        self.positional_encoding = CachedVaswaniPositionalEncoding(len_max)
-
-        trunk_blocks = []
-
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                CachedWithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                        NoiseInjector(identifier=("attention", b)),
-                    ),
-                    QKVAttention(
-                        dim_in=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        compute_attzero=compute_attzero,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                CachedWithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                        NoiseInjector(identifier=("ffw", b)),
-                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                        nn.ReLU(),
-                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                        nn.Dropout(dropout),
-                    ),
-                ),
-            ]
-
-        self.trunk = nn.Sequential(*trunk_blocks)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        for m in self.modules():
-            m.loss = 0
-
-        bs = self.shifter(bs)
-        bs = self.embedding(bs)
-        bs = self.positional_encoding(bs)
-        bs = self.trunk(bs)
-        bs = self.readout(bs)
-        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
-
-        for m in self.modules():
-            self.loss += m.loss
-
-        return bs
-
-    def reset_transformations(self):
-        self.temperature = 1.0
-        for m in self.modules():
-            if isinstance(m, NoiseInjector):
-                m.noise_std = 0.0
-
-    def set_noise_injection(self, noise_std, identifier=None):
-        for m in self.modules():
-            if isinstance(m, NoiseInjector):
-                if identifier is None or identifier == m.identifier:
-                    m.noise_std = noise_std
-
-    def record_attention(self, v=True):
-        for m in self.modules():
-            if isinstance(m, QKVAttention):
-                m.record_attention = v
-
-    def retrieve_attention(self):
-        a = []
-        for m in self.modules():
-            if isinstance(m, QKVAttention):
-                a.append(m.a)
-        return a
-
-
-######################################################################
-
-if __name__ == "__main__":
-    print("Basic check.")
-
-    vocabulary_size = 3
-    x = torch.randint(vocabulary_size, (1, 5))
-
-    model = MyGPT(
-        vocabulary_size=vocabulary_size,
-        dim_model=4,
-        dim_keys=2,
-        dim_hidden=2,
-        nb_heads=2,
-        nb_blocks=2,
-        dropout=0.1,
-    )
-
-    model.eval()
-    y1 = model(BracketedSequence(x)).x
-    y2 = torch.randn_like(y1)
-    for s in range(x.size(1)):
-        z = model(BracketedSequence(x, s, 1))
-        y2[:, s] = z.slice()
-
-    print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}")
-
-######################################################################
diff --git a/sky.py b/sky.py
deleted file mode 100755
index cc5bd4f..0000000
--- a/sky.py
+++ /dev/null
@@ -1,364 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os, warnings
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Sky(problem.Problem):
-    colors = torch.tensor(
-        [
-            [255, 255, 255],
-            [255, 0, 0],
-            [0, 192, 0],
-            [0, 0, 255],
-            [255, 192, 0],
-            [0, 255, 255],
-            [255, 0, 255],
-            [192, 255, 192],
-            [255, 192, 192],
-            [192, 192, 255],
-            [192, 192, 192],
-        ]
-    )
-
-    token_background = 0
-    first_bird_token = 1
-    nb_bird_tokens = colors.size(0) - 1
-
-    token2char = (
-        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
-    )
-
-    def __init__(
-        self,
-        height=6,
-        width=8,
-        nb_birds=3,
-        speed=2,
-        nb_iterations=2,
-        avoid_collision=True,
-        max_nb_cached_chunks=None,
-        chunk_size=None,
-        nb_threads=-1,
-    ):
-        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
-        self.height = height
-        self.width = width
-        self.nb_birds = nb_birds
-        self.speed = speed
-        self.nb_iterations = nb_iterations
-        self.avoid_collision = avoid_collision
-
-    def generate_frame_sequences(self, nb):
-        frame_sequences = []
-
-        for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"):
-            i, j, vi, vj = (
-                torch.empty(self.nb_birds, dtype=torch.int64),
-                torch.empty(self.nb_birds, dtype=torch.int64),
-                torch.empty(self.nb_birds, dtype=torch.int64),
-                torch.empty(self.nb_birds, dtype=torch.int64),
-            )
-
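-            # Each bird occupies three cells: (i, j), (i - vi, j), and
-            # (i, j - vj); collision_okay() checks that no cell is
-            # shared by two birds.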
-            def collision_okay():
-                if not self.avoid_collision:
-                    return True
-
-                count = torch.zeros(self.height, self.width, dtype=torch.int64)
-
-                for n in range(self.nb_birds):
-                    count[i[n], j[n]] += 1
-                    count[i[n] - vi[n], j[n]] += 1
-                    count[i[n], j[n] - vj[n]] += 1
-
-                return count.max() <= 1
-
-            col = (
-                torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values
-                + 1
-            )
-
-            while True:
-                while True:
-                    for n in range(self.nb_birds):
-                        while True:
-                            i[n] = torch.randint(self.height, (1,))
-                            j[n] = torch.randint(self.width, (1,))
-                            vm = torch.randint(4, (1,))
-                            vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
-                            if (
-                                i[n] - vi[n] >= 0
-                                and i[n] - vi[n] < self.height
-                                and j[n] - vj[n] >= 0
-                                and j[n] - vj[n] < self.width
-                            ):
-                                break
-
-                    if collision_okay():
-                        break
-
-                result = torch.zeros(
-                    self.nb_iterations * self.speed,
-                    self.height,
-                    self.width,
-                    dtype=torch.int64,
-                )
-
-                fine = torch.empty(self.nb_iterations * self.speed)
-
-                t_to_keep = (
-                    torch.arange(self.nb_iterations, device=result.device) * self.speed
-                )
-
-                for l in range(self.nb_iterations * self.speed):
-                    fine[l] = collision_okay()
-                    for n in range(self.nb_birds):
-                        c = col[n]
-                        result[l, i[n], j[n]] = c
-                        result[l, i[n] - vi[n], j[n]] = c
-                        result[l, i[n], j[n] - vj[n]] = c
-
-                        if (i[n] == 0 and vi[n] == -1) or (
-                            i[n] == self.height - 1 and vi[n] == 1
-                        ):
-                            vi[n] = -vi[n]
-
-                        if (j[n] == 0 and vj[n] == -1) or (
-                            j[n] == self.width - 1 and vj[n] == 1
-                        ):
-                            vj[n] = -vj[n]
-
-                        i[n] += vi[n]
-                        j[n] += vj[n]
-
-                result = result[t_to_keep]
-                fine = fine[t_to_keep]
-
-                if fine[-1]:
-                    break
-
-            frame_sequences.append(result)
-
-        return frame_sequences
-
-    ######################################################################
-
-    def frame2img(self, x, scale=15):
-        x = x.reshape(x.size(0), self.height, -1)
-        m = torch.logical_and(
-            x >= 0, x < self.first_bird_token + self.nb_bird_tokens
-        ).long()
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
-
-        for n in range(m.size(0)):
-            for i in range(m.size(1)):
-                for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
-                        for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
-
-        return x
-
-    def seq2str(self, seq):
-        result = []
-        for s in seq:
-            result.append("".join([self.token2char[v] for v in s]))
-        return result
-
-    def save_image(
-        self,
-        result_dir,
-        filename,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        if predicted_prompts is None:
-            predicted_prompts = 255
-
-        if predicted_answers is None:
-            predicted_answers = 255
-
-        def add_frame(x, c, margin, bottom=False):
-            if bottom:
-                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
-            else:
-                h, w, di, dj = (
-                    x.size(2) + 2 * margin,
-                    x.size(3) + 2 * margin,
-                    margin,
-                    margin,
-                )
-
-            y = x.new_full((x.size(0), x.size(1), h, w), 0)
-
-            if type(c) is int:
-                y[...] = c
-            else:
-                c = c.long()[:, None]
-                c = (
-                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
-                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
-                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
-                )
-                y[...] = c[:, :, None, None]
-
-            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
-
-            return y
-
-        margin = 4
-
-        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
-        h = img_prompts.size(2)
-        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
-
-        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
-        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
-
-        img_prompts = add_frame(
-            img_prompts, c=predicted_prompts, margin=margin, bottom=True
-        )
-        img_answers = add_frame(
-            img_answers, c=predicted_answers, margin=margin, bottom=True
-        )
-
-        marker_size = 16
-
-        separator = img_prompts.new_full(
-            (
-                img_prompts.size(0),
-                img_prompts.size(1),
-                img_prompts.size(2),
-                marker_size,
-            ),
-            255,
-        )
-
-        separator[:, :, 0] = 0
-        separator[:, :, h - 1] = 0
-
-        for k in range(1, 2 * marker_size - 8):
-            i = k - (marker_size - 4)
-            j = marker_size - 5 - abs(i)
-            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
-            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
-
-        img = torch.cat([img_prompts, separator, img_answers], dim=3)
-
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
-        )
-
-    ######################################################################
-
-    def nb_token_values(self):
-        return len(self.colors)
-
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-
-        # warnings.warn("dirty test with longer answer", RuntimeWarning)
-        # answers = torch.cat(
-        # [
-        # frame_sequences[:, frame_sequences.size(1) // 2 :],
-        # frame_sequences[:, frame_sequences.size(1) // 2 :],
-        # ],
-        # dim=3,
-        # ).flatten(1)
-
-        return prompts, answers
-
-    def save_quiz_illustrations(
-        self,
-        result_dir,
-        filename_prefix,
-        prompts,
-        answers,
-        predicted_prompts=None,
-        predicted_answers=None,
-    ):
-        self.save_image(
-            result_dir,
-            filename_prefix + ".png",
-            prompts,
-            answers,
-            predicted_prompts,
-            predicted_answers,
-        )
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import time
-
-    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
-
-    prompts, answers = sky.generate_prompts_and_answers(4)
-
-    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
-    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
-
-    sky.save_quiz_illustrations(
-        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
-    )
-
-    # start_time = time.perf_counter()
-    # token_sequences = sky.generate_token_sequences(nb=64)
-    # delay = time.perf_counter() - start_time
-    # print(f"{token_sequences.size(0)/delay:02f} seq/s")
-
-    # print(sky.seq2str(seq[:4]))
-
-    # for t in range(len(it[0])):
-    # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0)
-    # torchvision.utils.save_image(
-    # img.float() / 255.0,
-    # f"/tmp/frame_{t:03d}.png",
-    # nrow=8,
-    # padding=6,
-    # pad_value=0,
-    # )
-
-    # m = (torch.rand(seq.size()) < 0.05).long()
-    # seq = (1 - m) * seq + m * 23
-
-    # print(seq.size())
-    # img = sky.seq2img(token_sequences)
-    # print(img.size())
-
-    # torchvision.utils.save_image(
-    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
-    # )
diff --git a/wireworld.py b/wireworld.py
deleted file mode 100755
index 8257cad..0000000
--- a/wireworld.py
+++ /dev/null
@@ -1,357 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys, tqdm, os
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-import problem
-
-
-class Wireworld(problem.Problem):
-    colors = torch.tensor(
-        [
-            [128, 128, 128],
-            [128, 128, 255],
-            [255, 0, 0],
-            [255, 255, 0],
-        ]
-    )
-
-    token_empty = 0
-    token_head = 1
-    token_tail = 2
-    token_conductor = 3
-    token_forward = 4
-    token_backward = 5
-
-    token2char = (
-        "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
-    )
-
-    def __init__(
-        self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4
-    ):
-        self.height = height
-        self.width = width
-        self.nb_objects = nb_objects
-        self.nb_walls = nb_walls
-        self.speed = speed
-        self.nb_iterations = nb_iterations
-
-    def direction_tokens(self):
-        return self.token_forward, self.token_backward
-
-    def generate_frame_sequences(self, nb):
-        result = []
-        N = 100
-        for _ in tqdm.tqdm(
-            range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
-        ):
-            result.append(self.generate_frame_sequences_hard(N))
-        return torch.cat(result, dim=0)[:nb]
-
-    def generate_frame_sequences_hard(self, nb):
-        nb_frames = (self.nb_iterations - 1) * self.speed + 1
-
-        result = torch.full(
-            (nb * 4, nb_frames, self.height, self.width),
-            self.token_empty,
-        )
-
-        for n in range(result.size(0)):
-            while True:
-                i = torch.randint(self.height, (1,))
-                j = torch.randint(self.width, (1,))
-                v = torch.randint(2, (2,))
-                vi = v[0] * (v[1] * 2 - 1)
-                vj = (1 - v[0]) * (v[1] * 2 - 1)
-                while True:
-                    if i < 0 or i >= self.height or j < 0 or j >= self.width:
-                        break
-                    o = 0
-                    if i > 0:
-                        o += (result[n, 0, i - 1, j] == self.token_conductor).long()
-                    if i < self.height - 1:
-                        o += (result[n, 0, i + 1, j] == self.token_conductor).long()
-                    if j > 0:
-                        o += (result[n, 0, i, j - 1] == self.token_conductor).long()
-                    if j < self.width - 1:
-                        o += (result[n, 0, i, j + 1] == self.token_conductor).long()
-                    if o > 1:
-                        break
-                    result[n, 0, i, j] = self.token_conductor
-                    i += vi
-                    j += vj
-                if (
-                    result[n, 0] == self.token_conductor
-                ).long().sum() > self.width and torch.rand(1) < 0.5:
-                    break
-
-            while True:
-                for _ in range(self.height * self.width):
-                    i = torch.randint(self.height, (1,))
-                    j = torch.randint(self.width, (1,))
-                    v = torch.randint(2, (2,))
-                    vi = v[0] * (v[1] * 2 - 1)
-                    vj = (1 - v[0]) * (v[1] * 2 - 1)
-                    if (
-                        i + vi >= 0
-                        and i + vi < self.height
-                        and j + vj >= 0
-                        and j + vj < self.width
-                        and result[n, 0, i, j] == self.token_conductor
-                        and result[n, 0, i + vi, j + vj] == self.token_conductor
-                    ):
-                        result[n, 0, i, j] = self.token_head
-                        result[n, 0, i + vi, j + vj] = self.token_tail
-                        break
-
-                # if torch.rand(1) < 0.75:
-                break
-
-        weight = torch.full((1, 1, 3, 3), 1.0)
-
-        mask = (torch.rand(result[:, 0].size()) < 0.01).long()
-        rand = torch.randint(4, mask.size())
-        result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
-
-        # Wireworld update rule:
-        #   empty -> empty
-        #   head -> tail
-        #   tail -> conductor
-        #   conductor -> head if it has 1 or 2 heads in its 3x3
-        #   neighborhood, otherwise it remains a conductor
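-        # For instance, on a straight horizontal wire, a head H trailed
-        # by its tail T moves away from the tail:
-        #   C C H T C  ->  C H T C C  ->  H T C C C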
-
-        nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
-        valid = nb_heads > 0
-
-        for l in range(nb_frames - 1):
-            nb_head_neighbors = (
-                F.conv2d(
-                    input=(result[:, l] == self.token_head).float()[:, None, :, :],
-                    weight=weight,
-                    padding=1,
-                )
-                .long()
-                .squeeze(1)
-            )
-            mask_1_or_2_heads = (nb_head_neighbors == 1).long() + (
-                nb_head_neighbors == 2
-            ).long()
-            result[:, l + 1] = (
-                (result[:, l] == self.token_empty).long() * self.token_empty
-                + (result[:, l] == self.token_head).long() * self.token_tail
-                + (result[:, l] == self.token_tail).long() * self.token_conductor
-                + (result[:, l] == self.token_conductor).long()
-                * (
-                    mask_1_or_2_heads * self.token_head
-                    + (1 - mask_1_or_2_heads) * self.token_conductor
-                )
-            )
-            prev_nb_heads = nb_heads
-            nb_heads = (
-                (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
-            )
-            valid = torch.logical_and(valid, (nb_heads >= prev_nb_heads))
-
-        result = result[valid]
-
-        result = result[
-            :, torch.arange(self.nb_iterations, device=result.device) * self.speed
-        ]
-
-        i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
-        result = result[i]
-
-        # print(f"{result.size(0)=} {nb=}")
-
-        if result.size(0) < nb:
-            # print(result.size(0))
-            result = torch.cat(
-                [result, self.generate_frame_sequences(nb - result.size(0))], dim=0
-            )
-
-        return result[:nb]
-
-    def generate_token_sequences(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-
-        result = []
-
-        for frame_sequence in frame_sequences:
-            a = []
-            if torch.rand(1) < 0.5:
-                for frame in frame_sequence:
-                    if len(a) > 0:
-                        a.append(torch.tensor([self.token_forward]))
-                    a.append(frame.flatten())
-            else:
-                for frame in reversed(frame_sequence):
-                    if len(a) > 0:
-                        a.append(torch.tensor([self.token_backward]))
-                    a.append(frame.flatten())
-
-            result.append(torch.cat(a, dim=0)[None, :])
-
-        return torch.cat(result, dim=0)
-
-    ######################################################################
-
-    def frame2img(self, x, scale=15):
-        x = x.reshape(-1, self.height, self.width)
-        m = torch.logical_and(x >= 0, x < 4).long()
-
-        x = self.colors[x * m].permute(0, 3, 1, 2)
-        s = x.shape
-        x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
-        x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
-        x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
-        x[:, :, torch.arange(0, x.size(2), scale), :] = 0
-        x = x[:, :, 1:, 1:]
-
-        for n in range(m.size(0)):
-            for i in range(m.size(1)):
-                for j in range(m.size(2)):
-                    if m[n, i, j] == 0:
-                        for k in range(2, scale - 2):
-                            for l in [0, 1]:
-                                x[n, :, i * scale + k, j * scale + k - l] = 0
-                                x[
-                                    n, :, i * scale + scale - 1 - k, j * scale + k - l
-                                ] = 0
-
-        return x
-
-    def seq2img(self, seq, scale=15):
-        all = [
-            self.frame2img(
-                seq[:, : self.height * self.width].reshape(-1, self.height, self.width),
-                scale,
-            )
-        ]
-
-        separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0)
-
-        t = self.height * self.width
-
-        while t < seq.size(1):
-            direction_tokens = seq[:, t]
-            t += 1
-
-            direction_images = self.colors[
-                torch.full(
-                    (direction_tokens.size(0), self.height * scale - 1, scale), 0
-                )
-            ].permute(0, 3, 1, 2)
-
-            for n in range(direction_tokens.size(0)):
-                if direction_tokens[n] == self.token_forward:
-                    for k in range(scale):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                3 + scale // 2 - abs(k - scale // 2),
-                            ] = 0
-                elif direction_tokens[n] == self.token_backward:
-                    for k in range(scale):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                3 + abs(k - scale // 2),
-                            ] = 0
-                else:
-                    for k in range(2, scale - 2):
-                        for l in [0, 1]:
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                k,
-                            ] = 0
-                            direction_images[
-                                n,
-                                :,
-                                (self.height * scale) // 2 - scale // 2 + k - l,
-                                scale - 1 - k,
-                            ] = 0
-
-            all += [
-                separator,
-                direction_images,
-                separator,
-                self.frame2img(
-                    seq[:, t : t + self.height * self.width].reshape(
-                        -1, self.height, self.width
-                    ),
-                    scale,
-                ),
-            ]
-
-            t += self.height * self.width
-
-        return torch.cat(all, dim=3)
-
-    def seq2str(self, seq):
-        result = []
-        for s in seq:
-            result.append("".join([self.token2char[v] for v in s]))
-        return result
-
-    def save_image(self, input, result_dir, filename):
-        img = self.seq2img(input.to("cpu"))
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-
-    def save_quizzes(self, input, result_dir, filename_prefix):
-        self.save_image(input, result_dir, filename_prefix + ".png")
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import time
-
-    wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1)
-
-    start_time = time.perf_counter()
-    frame_sequences = wireworld.generate_frame_sequences(nb=96)
-    delay = time.perf_counter() - start_time
-    print(f"{frame_sequences.size(0)/delay:02f} seq/s")
-
-    # print(wireworld.seq2str(seq[:4]))
-
-    for t in range(frame_sequences.size(1)):
-        img = wireworld.seq2img(frame_sequences[:, t])
-        torchvision.utils.save_image(
-            img.float() / 255.0,
-            f"/tmp/frame_{t:03d}.png",
-            nrow=8,
-            padding=6,
-            pad_value=0,
-        )
-
-    # m = (torch.rand(seq.size()) < 0.05).long()
-    # seq = (1 - m) * seq + m * 23
-
-    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
-    token_sequences = wireworld.generate_token_sequences(32)
-    wireworld.save_quizzes(token_sequences, "/tmp", "seq")
-    # img = wireworld.seq2img(frame_sequences[:60])
-
-    # torchvision.utils.save_image(
-    # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1
-    # )