Update. master
author François Fleuret <fleuret@meta.com>
Tue, 17 Jun 2025 15:14:20 +0000 (17:14 +0200)
committer François Fleuret <fleuret@meta.com>
Tue, 17 Jun 2025 15:14:20 +0000 (17:14 +0200)
bit_mlp.py
difdis.py [new file with mode: 0755]
distributed.py [new file with mode: 0755]
grid.py [new file with mode: 0755]
picocrafter.py
redshift.py [new file with mode: 0755]
tinyae.py
tinygen.py [new file with mode: 0755]
tinymnist.py

diff --git a/bit_mlp.py b/bit_mlp.py
index 90409f2..6f7f92e 100755 (executable)
--- a/bit_mlp.py
+++ b/bit_mlp.py
@@ -116,33 +116,40 @@ for linear_layer in errors.keys():
 
         ######################################################################
 
-        errors[linear_layer].append((nb_hidden, test_error))
+        errors[linear_layer].append(
+            (nb_hidden, test_error * 100, acc_train_loss / train_input.size(0))
+        )
 
 import matplotlib.pyplot as plt
 
-fig = plt.figure()
-fig.set_figheight(6)
-fig.set_figwidth(8)
 
-ax = fig.add_subplot(1, 1, 1)
+def save_fig(filename, ymax, ylabel, index):
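+    # index selects which entry of the recorded (nb_hidden, test_error,
+    # train_loss) tuples to plot on the y axis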
+    fig = plt.figure()
+    fig.set_figheight(6)
+    fig.set_figwidth(8)
 
-ax.set_ylim(0, 1)
-ax.spines.right.set_visible(False)
-ax.spines.top.set_visible(False)
-ax.set_xscale("log")
-ax.set_xlabel("Nb hidden units")
-ax.set_ylabel("Test error (%)")
+    ax = fig.add_subplot(1, 1, 1)
 
-X = torch.tensor([x[0] for x in errors[nn.Linear]])
-Y = torch.tensor([x[1] for x in errors[nn.Linear]])
-ax.plot(X, Y, color="gray", label="nn.Linear")
+    ax.set_ylim(0, ymax)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
+    ax.set_xscale("log")
+    ax.set_xlabel("Nb hidden units")
+    ax.set_ylabel(ylabel)
 
-X = torch.tensor([x[0] for x in errors[QLinear]])
-Y = torch.tensor([x[1] for x in errors[QLinear]])
-ax.plot(X, Y, color="red", label="QLinear")
+    X = torch.tensor([x[0] for x in errors[nn.Linear]])
+    Y = torch.tensor([x[index] for x in errors[nn.Linear]])
+    ax.plot(X, Y, color="gray", label="nn.Linear")
 
-ax.legend(frameon=False, loc=1)
+    X = torch.tensor([x[0] for x in errors[QLinear]])
+    Y = torch.tensor([x[index] for x in errors[QLinear]])
+    ax.plot(X, Y, color="red", label="QLinear")
 
-filename = f"bit_mlp.pdf"
-print(f"saving {filename}")
-fig.savefig(filename, bbox_inches="tight")
+    ax.legend(frameon=False, loc=1)
+
+    print(f"saving {filename}")
+    fig.savefig(filename, bbox_inches="tight")
+
+
+save_fig("bit_mlp_err.pdf", ymax=15, ylabel="Test error (%)", index=1)
+save_fig("bit_mlp_loss.pdf", ymax=1.25, ylabel="Train loss", index=2)
diff --git a/difdis.py b/difdis.py
new file mode 100755 (executable)
index 0000000..34fef17
--- /dev/null
+++ b/difdis.py
@@ -0,0 +1,568 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_GET: *.png
+
+import sys, argparse, time, math
+
+import torch, torchvision
+
+from torch import optim, nn
+from torch.nn import functional as F
+from tqdm import tqdm
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+
+parser.add_argument("--nb_epochs", type=int, default=250)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--nb_train_samples", type=int, default=-1)
+
+parser.add_argument("--nb_test_samples", type=int, default=-1)
+
+parser.add_argument("--data_dir", type=str, default="./data/")
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--embedding_dim", type=int, default=64)
+
+parser.add_argument("--nb_channels", type=int, default=64)
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, "w")
+
+######################################################################
+
+
+def log_string(s):
+    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    def forward(self, x):
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2  # works with float, weird
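+        # on odd dimensions (k = 1) the sin argument is shifted by pi / 2,
+        # which turns it into a cosine, as in the original sinusoidal encoding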
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+        y = x + pe
+        return y
+
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
+
+######################################################################
+
+
+def vanilla_attention(q, k, v):
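+    # q, k, v are (batch, head, token, dim); the softmax normalizes each
+    # query's attention weights over the source tokens s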
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+    a = a.softmax(dim=3)
+    y = torch.einsum("nhts,nhsd->nhtd", a, v)
+    return y
+
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention = attention
+        self.attention_dropout = attention_dropout
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
+
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+        y = self.attention(q, k, v)
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return y
+
+
+######################################################################
+
+
+class AttentionAE(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            nn.Linear(2, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=vanilla_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=1)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1)
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = self.readout(x).reshape(-1, 1, 28, 28)
+        return x
+
+
+######################################################################
+
+
+class WithMaskedResidual(nn.Module):
+    def __init__(self, masker, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.masker = masker
+        self.mask = None
+
+    def forward(self, x):
+        if self.mask is None:
+            self.mask = self.masker(x)
+        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        nb_work_tokens=100,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.nb_work_tokens = nb_work_tokens
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        def no_peek_attention(q, k, v):
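+            # the s tokens following the work tokens form two halves;
+            # attention across the two halves is masked out in both directions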
+            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+            n = self.nb_work_tokens
+            s = (q.size(2) - n) // 2
+            a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a = a.softmax(dim=3)
+            y = torch.einsum("nhts,nhsd->nhtd", a, v)
+            return y
+
+        def masker(x):
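+            # the work-token positions get mask 0, so they have no residual
+            # pass-through and act as pure scratch slots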
+            m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+            return m[None, :, None]
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=no_peek_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+        x = self.readout(x)
+        return x
+
+
+######################################################################
+
+
+class FullAveragePooling(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1)
+        return x
+
+
+class ResNetBlock(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
+
+        self.conv2 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
+
+    def forward(self, x):
+        y = F.relu(self.conv1(x))
+        y = F.relu(x + self.conv2(y))
+        return y
+
+
+######################################################################
+
+
+class ResAutoEncoder(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.encoder = nn.Conv2d(
+            2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+        self.core = nn.Sequential(
+            *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)]
+        )
+        self.decoder = nn.Conv2d(
+            nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+
+######################################################################
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, nb_channels, embedding_dim):
+        super().__init__()
+
+        self.encoder = nn.Sequential(
+            nn.Conv2d(256, nb_channels, kernel_size=5),  # to 24x24
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),
+        )
+
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=3, stride=2
+            ),  # from 4x4
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=4, stride=2
+            ),  # from 9x9
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, 256, kernel_size=5),  # from 24x24
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x * 1e-3
+
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=True, download=True
+)
+train_input = train_set.data.view(-1, 28, 28).long()
+
+test_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=False, download=True
+)
+test_input = test_set.data.view(-1, 28, 28).long()
+
+if args.nb_train_samples > 0:
+    train_input = train_input[: args.nb_train_samples]
+
+if args.nb_test_samples > 0:
+    test_input = test_input[: args.nb_test_samples]
+
+######################################################################
+
+model = AutoEncoder(args.nb_channels, args.embedding_dim)
+
+# model = AttentionAE(
+# dim_model=16,
+# dim_keys=16,
+# dim_hidden=16,
+# nb_heads=4,
+# nb_blocks=4,
+# dropout=0.0,
+# len_max=1e5,
+# )
+
+# model = ResAutoEncoder(nb_channels=128, kernel_size=9)
+
+print(model)
+
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+model.to(device)
+
+train_input, test_input = train_input.to(device), test_input.to(device)
+
+nb_iterations = 10
+
+######################################################################
+
+
+def dist(u, v):
+    return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
+
+
+def kl(u, v):
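+    # per-sample KL(v || u) between the flattened logit maps, returned with
+    # singleton trailing dimensions so it broadcasts like the input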
+    s = (u.size(0),) + (1,) * (u.dim() - 1)
+    u = u.flatten(1).log_softmax(dim=1)
+    v = v.flatten(1).log_softmax(dim=1)
+    return (
+        F.kl_div(u, v, log_target=True, reduction="none")
+        .sum(dim=1, keepdim=True)
+        .clamp(min=0)
+        .reshape(*s)
+    )
+
+
+def pb(e, desc):
+    return tqdm(
+        e,
+        dynamic_ncols=True,
+        desc=desc,
+        total=train_input.size(0) // args.batch_size,
+        delay=10,
+    )
+
+
+def save_logit_image(logit_image, filename):
+    image = (
+        logit_image.softmax(dim=1)
+        * torch.arange(logit_image.size(1), device=logit_image.device)[
+            None, :, None, None
+        ]
+    ).sum(dim=1, keepdim=True)
+    image = image / 255
+
+    torchvision.utils.save_image(1 - image, filename, nrow=16, pad_value=0.8)
+
+
+for n_epoch in range(args.nb_epochs):
+    model.train()  # undo the model.eval() set before sampling at the end of each epoch
+    acc_loss = 0
+
+    for targets in pb(train_input.split(args.batch_size), "train"):
+        targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2)
+        targets = (targets * 0.9 + 0.1 / targets.size(1)).log()
+        input = torch.randn(targets.size(), device=targets.device) * 1e-3
+
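+        # each step must reduce KL(targets || input) along a linear schedule;
+        # the bisection below finds the smallest mixture weight c toward the
+        # targets that meets the step's tolerated KL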
+        loss = 0
+        for n in range(nb_iterations):
+            input = input.log_softmax(dim=1)
+            output = model(input).clamp(min=-10, max=10)
+            current_kl = kl(output, targets)
+
+            nb_remain = nb_iterations - n
+            tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain
+
+            with torch.no_grad():
+                a = input.new_full((input.size(0), 1, 1, 1), 0.0)
+                b = input.new_full((input.size(0), 1, 1, 1), 1.0)
+
+                # at c = a = 0 the mixture (1-c)*output + c*targets is the raw
+                # model output, so its KL to the targets is large; at
+                # c = b = 1 it is the targets themselves, so the KL is 0
+
+                for _ in range(20):
+                    c = (a + b) / 2
+                    kl_c = kl((1 - c) * output + c * targets, targets)
+                    m = (kl_c >= tolerated_kl).long()
+                    # m=1 -> kl > tolerated => a = c
+                    # m=0 -> kl < tolerated => b = c
+                    a = m * c + (1 - m) * a
+                    b = m * b + (1 - m) * c
+
+                c = (a + b) / 2
+
+            # print((kl_c / (tolerated_kl+1e-6)).flatten())
+
+            input = (1 - c) * output + c * targets
+            loss += kl(output, input).mean()
+            input = input.detach()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        acc_loss += loss.item()
+
+    log_string(f"acc_loss {n_epoch} {acc_loss}")
+
+    save_logit_image(output, f"train_output_{n_epoch:04d}.png")
+
+    ######################################################################
+
+    model.eval()
+
+    input = torch.randn((256, 256, 28, 28), device=targets.device) * 1e-3
+    input = input.log_softmax(dim=1)
+
+    for _ in range(nb_iterations):
+        input = input.log_softmax(dim=1)
+        output = model(input).clamp(min=-10, max=10)
+        input = output.detach()
+
+    save_logit_image(output, f"test_output_{n_epoch:04d}.png")
+
+######################################################################
diff --git a/distributed.py b/distributed.py
new file mode 100755 (executable)
index 0000000..648756e
--- /dev/null
+++ b/distributed.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+
+import time, socket, threading, struct, pickle
+
+import argparse
+
+######################################################################
+
+parser = argparse.ArgumentParser(
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
+
+parser.add_argument("--server", type=str, default=None)
+
+parser.add_argument("--port", type=int, default=12021)
+
+args = parser.parse_args()
+
+######################################################################
+
+
+class SocketConnection:
+    def __init__(self, established_socket, read_len=16384):
+        self.read_len = read_len
+        self.socket = established_socket
+        self.socket.setblocking(1)
+        self.buffer = b""
+        self.SEND_LOCK = threading.Lock()
+        self.RECEIVE_LOCK = threading.Lock()
+        self.failed = False
+
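+    # messages are framed as a 4-byte big-endian length followed by the
+    # pickle payload, so receive() can reassemble whole objects from the stream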
+    def send(self, x):
+        with self.SEND_LOCK:
+            data = pickle.dumps(x)
+            self.socket.sendall(struct.pack("!i", len(data)))
+            self.socket.sendall(data)
+
+    def raw_read(self, l):
+        while len(self.buffer) < l:
+            d = self.socket.recv(self.read_len)
+            if d:
+                self.buffer += d
+            else:
+                raise EOFError()
+
+        x = self.buffer[:l]
+        self.buffer = self.buffer[l:]
+        return x
+
+    def receive(self):
+        with self.RECEIVE_LOCK:
+            l = struct.unpack("!i", self.raw_read(4))[0]
+            return pickle.loads(self.raw_read(l))
+
+
+######################################################################
+
+
+def start_server(port, core, reader):
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("0.0.0.0", port))
+    s.listen(5)
+    nb_accepts = 0
+
+    def threadable_reader(reader, receiver, client_id):
+        try:
+            while True:
+                # use the receiver bound at thread creation rather than the
+                # enclosing `link`, which is rebound at every accept
+                reader(receiver(), client_id)
+        except EOFError:
+            print(f"** closing reader #{client_id} on EOFError **")
+
+    def threadable_core(sender, client_id):
+        try:
+            core(sender, client_id)
+        except BrokenPipeError:
+            print(f"** closing core #{client_id} on BrokenPipeError **")
+
+    while True:
+        client_socket, ip_and_port = s.accept()
+        link = SocketConnection(client_socket)
+
+        threading.Thread(
+            target=threadable_core,
+            kwargs={"sender": link.send, "client_id": nb_accepts},
+            daemon=True,
+        ).start()
+
+        threading.Thread(
+            target=threadable_reader,
+            kwargs={
+                "reader": reader,
+                "receiver": link.receive,
+                "client_id": nb_accepts,
+            },
+            daemon=True,
+        ).start()
+
+        nb_accepts += 1
+
+
+######################################################################
+
+
+def create_client(servername, port, reader):
+    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    server_socket.connect((servername, port))
+    link = SocketConnection(server_socket)
+
+    def threadable_reader(reader):
+        while True:
+            reader(link.receive())
+
+    def writer(x):
+        link.send(x)
+
+    threading.Thread(
+        target=threadable_reader, kwargs={"reader": reader}, daemon=True
+    ).start()
+
+    return writer
+
+
+######################################################################
+
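+# demo: launch once without --server to run the server, then run clients
+# with --server <host>; server and clients exchange PING / PONG messages
+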
+if args.server is None:
+    print(f"Starting server port {args.port}")
+
+    def reader(obj, client_id):
+        print(f'Server received from client #{client_id} "{obj}"')
+
+    def core(writer, client_id):
+        writer(f"HELLO {client_id}")
+        while True:
+            writer(f"PONG {time.localtime().tm_sec}")
+            time.sleep(3)
+
+    start_server(port=args.port, core=core, reader=reader)
+
+else:
+    print(f"Starting client connecting to {args.server}:{args.port}")
+
+    def reader(obj):
+        print(f'Client received from server "{obj}"')
+
+    writer = create_client(args.server, args.port, reader)
+
+    while True:
+        writer(f"PING {time.localtime().tm_sec}")
+        time.sleep(3)
+
+######################################################################
diff --git a/grid.py b/grid.py
new file mode 100755 (executable)
index 0000000..ac0ebe0
--- /dev/null
+++ b/grid.py
@@ -0,0 +1,337 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# This code implements a simple system to manipulate formal
+# specifications of tokens on a grid.
+
+import math, re, random
+
+import torch
+
+from torch.nn import functional as F
+
+######################################################################
+
+
+class FormalGrid:
+    def __init__(
+        self, grid_height=8, grid_width=8, nb_symbols=4, device=torch.device("cpu")
+    ):
+        self.grid_height = grid_height
+        self.grid_width = grid_width
+        self.nb_symbols = nb_symbols
+        self.nb_configs = (self.grid_height * self.grid_width) ** self.nb_symbols
+
+        self.row = torch.empty(
+            self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+        )
+        self.col = torch.empty(
+            self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+        )
+
+        i = torch.arange(self.nb_configs, device=device)
+
+        k = 1
+        for s in range(self.nb_symbols):
+            self.row[:, s] = (i // k) % self.grid_height
+            k *= self.grid_height
+            self.col[:, s] = (i // k) % self.grid_width
+            k *= self.grid_width
+
+        nb = 0
+        for s in range(self.nb_symbols):
+            nb += F.one_hot(
+                (self.row[:, s] * self.grid_width + self.col[:, s]).long(),
+                num_classes=self.grid_height * self.grid_width,
+            ).to(torch.int8)
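+        # a configuration is valid when no two symbols share a cell, i.e. the
+        # summed one-hot position maps never exceed 1 anywhere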
+        self.master_grid_set = nb.max(dim=1).values <= 1
+
+    def new_grid_set(self, constraints=None):
+        g = self.master_grid_set.clone()
+        if constraints:
+            self.add_constraints(g, constraints)
+        return g
+
+    ######################################################################
+
+    def constraint_to_fun(self, constraint):
+        a, b, c = None, None, None
+        col, row = self.col, self.row
+
+        def match(pattern):
+            nonlocal a, b, c
+            r = re.search("^" + pattern + "$", constraint)
+            if r:
+                g = tuple(int(x) - 1 for x in r.groups())
+                a, b, c = g + (None,) * (3 - len(g))
+                return True
+            else:
+                return False
+
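+        # match() binds the digits captured by the pattern (1-based in the
+        # constraint string) to a, b, c as 0-based symbol indices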
+        if match("([1-9]) is_in_top_half"):
+            return row[:, a] < self.grid_height // 2
+
+        elif match("([1-9]) is_in_bottom_half"):
+            return row[:, a] >= self.grid_height // 2
+
+        elif match("([1-9]) is_on_left_side"):
+            return col[:, a] < self.grid_width // 2
+
+        elif match("([1-9]) is_on_right_side"):
+            return col[:, a] >= self.grid_width // 2
+
+        elif match("([1-9]) next_to ([1-9])"):
+            return (row[:, a] - row[:, b]).abs() + (col[:, a] - col[:, b]).abs() == 1
+
+        elif match("([1-9]) is_below ([1-9])"):
+            return row[:, a] > row[:, b]
+
+        elif match("([1-9]) is_above ([1-9])"):
+            return row[:, a] < row[:, b]
+
+        elif match("([1-9]) is_left_of ([1-9])"):
+            return col[:, a] < col[:, b]
+
+        elif match("([1-9]) is_right_of ([1-9])"):
+            return col[:, a] > col[:, b]
+
+        elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"):
+            return (col[:, a] - col[:, b]).abs() == (row[:, a] - row[:, b]).abs()
+
+        elif match("([1-9]) ([1-9]) is_vertical"):
+            return col[:, a] == col[:, b]
+
+        elif match("([1-9]) ([1-9]) is_horizontal"):
+            return row[:, a] == row[:, b]
+
+        elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
+            return (col[:, a] - col[:, b]) * (row[:, a] - row[:, c]) - (
+                row[:, a] - row[:, b]
+            ) * (col[:, a] - col[:, c]) == 0
+
+        elif match("([1-9]) middle_of ([1-9]) ([1-9])"):
+            return (col[:, b] + col[:, a] == 2 * col[:, b]) & (
+                row[:, b] + row[:, a] == 2 * row[:, b]
+            )
+
+        elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"):
+            return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 == (
+                col[:, a] - col[:, c]
+            ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+        elif match("([1-9]) is_further_from ([1-9]) than_from ([1-9])"):
+            return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 > (
+                col[:, a] - col[:, c]
+            ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+        elif match("([1-9]) is_closer_to ([1-9]) than_to ([1-9])"):
+            return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 < (
+                col[:, a] - col[:, c]
+            ) ** 2 + (row[:, a] - row[:, c]) ** 2
+
+        elif match("([1-9]) ([1-9]) ([1-9]) form_a_right_angle"):
+            return (col[:, a] - col[:, b]) * (col[:, c] - col[:, b]) + (
+                row[:, a] - row[:, b]
+            ) * (row[:, c] - row[:, b]) == 0
+
+        else:
+            raise ValueError(f"Unknown type of constraint {constraint}")
+
+    ######################################################################
+
+    def check_reasonning(self, steps):
+        grid_set = self.new_grid_set()
+        for step in steps:
+            if step[0] == "=>":
+                f = self.constraint_to_fun(step[1:])
+                if (grid_set & torch.logical_not(f)).any():
+                    return False, step[1:]
+            else:
+                grid_set[...] = grid_set & self.constraint_to_fun(step)
+
+        return True, None
+
+    def add_constraints(self, grid_set, constraints):
+        for constraint in constraints:
+            grid_set[...] = grid_set & self.constraint_to_fun(constraint)
+
+    ######################################################################
+
+    def views(self, grid_set):
+        g = torch.empty(self.grid_height, self.grid_width, dtype=torch.int64)
+        row, col = self.row[grid_set], self.col[grid_set]
+        i = torch.randperm(row.size(0))
+        row, col = row[i], col[i]
+        for r, c in zip(row, col):
+            g.zero_()
+            for k in range(self.nb_symbols):
+                g[r[k], c[k]] = k + 1
+            v = ""
+            for line in g:
+                v += " ".join(["-" if n == 0 else str(n.item()) for n in line]) + "\n"
+            yield v
+
+    def random_property(self):
+        a, b, c = random.sample(list(range(1, self.nb_symbols + 1)), 3)
+
+        sb, sc = min(b, c), max(b, c)
+
+        ta, tb, tc = sorted((a, b, c))
+
+        l = (
+            [
+                f"{a} is_in_top_half",
+                f"{a} is_in_bottom_half",
+                f"{a} is_on_left_side",
+                f"{a} is_on_right_side",
+            ]
+            + [
+                f"{a} is_below {b}",
+                f"{a} is_above {b}",
+                f"{a} is_left_of {b}",
+                f"{a} is_right_of {b}",
+                f"{sb} next_to {sc}",
+            ]
+            + [
+                f"{sb} {sc} is_parallel_to_diagonal",
+                f"{sb} {sc} is_vertical",
+                f"{sb} {sc} is_horizontal",
+            ]
+            * 2
+            + [
+                f"{ta} {tb} {tc} are_aligned",
+                f"{a} middle_of {sb} {sc}",
+                f"{ta} {tb} {tc} form_a_right_angle",
+            ]
+            * 3
+            + [
+                f"{a} is_equidistant_from {sb} and {sc}",
+                f"{a} is_further_from {b} than_from {c}",
+                f"{a} is_closer_to {b} than_to {c}",
+            ]
+        )
+
+        return random.choice(l)
+
+
+######################################################################
+
+if __name__ == "__main__":
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    # grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device)
+    # grid_set = grid.new_grid_set(["4 is_equidistant_from 2 and 3", "2 4 is_parallel_to_diagonal"])
+    # print(next(iter(grid.views(grid_set))))
+    # exit(0)
+
+    def proof_depth(steps, c):
+        a = steps.get(c)
+        if a is None:
+            return 0
+        else:
+            c1, c2 = a
+            # count the deduction itself, otherwise the depth is always 0
+            return 1 + max(proof_depth(steps, c1), proof_depth(steps, c2))
+
+    def generate_proof(grid):
+        while True:
+            constraints = [grid.random_property() for _ in range(10)]
+            grid_set = grid.new_grid_set(constraints)
+            if grid_set.any():
+                break
+
+        mg = grid.master_grid_set
+
+        print(constraints)
+
+        initial = constraints.copy()
+
+        steps = {}
+
+        for _ in range(1000):
+            c1, c2 = random.sample(constraints, 2)
+            f1, f2 = grid.constraint_to_fun(c1), grid.constraint_to_fun(c2)
+            for _ in range(100):
+                c = grid.random_property()
+                if c not in constraints:
+                    f = grid.constraint_to_fun(c)
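+                    # keep c only if neither premise alone implies it but
+                    # their conjunction does: a genuine two-premise deduction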
+                    if (
+                        (mg & f1 & ~f).any()
+                        and (mg & f2 & ~f).any()
+                        and (mg & f1 & f2 & f).any()
+                        and not (mg & f1 & f2 & ~f).any()
+                    ):
+                        constraints.append(c)
+                        print(c1, "and", c2, "=>", c)
+                        steps[c] = (c1, c2)
+                        # print(next(iter(grid.views(grid.new_grid_set([c1, c2])))))
+                        # print("we have", constraints)
+                        # proof.append(c1 + " and " + c2 + " hence " + c)
+                        break
+
+            if steps and max(proof_depth(steps, c) for c in steps) >= 3:
+                break
+
+        return initial, steps
+
+    grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device)
+
+    initial, steps = generate_proof(grid)
+
+    print(" ; ".join(initial))
+
+    def proof(c, indent=""):
+        a = steps.get(c)
+        if a is None:
+            print(f"{indent}{c} is given")
+        else:
+            print(f"{indent}{c} since")
+            c1, c2 = a
+            proof(c1, indent + "  ")
+            proof(c2, indent + "  ")
+
+    print(" ; ".join(initial))
+
+    for c in steps.keys():
+        proof(c)
+        print()
+
+    exit(0)
+
+    # grid_set = grid.new_grid_set(
+    # [
+    # "1 2 3 form_a_right_angle",
+    # "2 3 4 form_a_right_angle",
+    # "3 4 1 form_a_right_angle",
+    # "2 is_equidistant_from 1 and 3",
+    # "1 is_above 4",
+    # ],
+    # )
+
+    grid_set = grid.new_grid_set(
+        [
+            "1 2 3 are_aligned",
+            "2 3 is_parallel_to_diagonal",
+            "4 1 is_vertical",
+            "3 4 is_horizontal",
+            "3 is_left_of 2",
+            "1 is_below 4",
+            "2 is_right_of 4",
+        ],
+    )
+
+    print(f"There are {grid_set.long().sum().item()} configurations")
+
+    for v in grid.views(grid_set):
+        print(v)
diff --git a/picocrafter.py b/picocrafter.py
index 23d93b2..001bb81 100755 (executable)
--- a/picocrafter.py
+++ b/picocrafter.py
@@ -227,9 +227,9 @@ class PicroCrafterEnvironment:
         r = u.sort(dim=-1, descending=True).indices[:, : len(z)]
 
         q *= self.tile2id["#"]
-        q[
-            torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r
-        ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
+        q[torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r] = (
+            torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :]
+        )
 
         if world_margin > 0:
             r = m.new_full(
diff --git a/redshift.py b/redshift.py
new file mode 100755 (executable)
index 0000000..2ed1e52
--- /dev/null
+++ b/redshift.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+torch.set_default_dtype(torch.float64)
+
+nb_hidden = 5
+hidden_dim = 100
+
+res = 256
+
+input = torch.cat(
+    [
+        torch.linspace(-1, 1, res)[None, :, None].expand(res, res, 1),
+        torch.linspace(-1, 1, res)[:, None, None].expand(res, res, 1),
+    ],
+    dim=-1,
+).reshape(-1, 2)
+
+
+class Angles(nn.Module):
+    def forward(self, x):
+        return x.clamp(min=-0.5, max=0.5)
+
+
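+# push a 2D coordinate grid through randomly initialized MLPs and render the
+# resulting scalar field for several activations and weight scales s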
+for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
+    for s in [1.0, 10.0]:
+        layers = [nn.Linear(2, hidden_dim), activation()]
+        for k in range(nb_hidden - 1):
+            layers += [nn.Linear(hidden_dim, hidden_dim), activation()]
+        layers += [nn.Linear(hidden_dim, 2)]
+        model = nn.Sequential(*layers)
+
+        with torch.no_grad():
+            for p in model.parameters():
+                p *= s
+
+        output = model(input)
+
+        img = (output[:, 1] - output[:, 0]).reshape(1, 1, res, res)
+
+        img = (img - img.mean()) / img.std()
+
+        img = img.clamp(min=-1, max=1)
+
+        img = torch.cat(
+            [
+                (1 + img).clamp(max=1),
+                (1 - img.abs()).clamp(min=0),
+                (1 - img).clamp(max=1),
+            ],
+            dim=1,
+        )
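+        # color map: zero is white, positive values shade toward red,
+        # negative values toward blue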
+
+        name_activation = {
+            nn.ReLU: "relu",
+            nn.Tanh: "tanh",
+            nn.Softplus: "softplus",
+            Angles: "angles",
+        }[activation]
+
+        torchvision.utils.save_image(img, f"result-{name_activation}-{s}.png")
diff --git a/tinyae.py b/tinyae.py
index b4f3aba..806559e 100755 (executable)
--- a/tinyae.py
+++ b/tinyae.py
@@ -92,8 +92,8 @@ class AutoEncoder(nn.Module):
         return self.decoder(z.view(z.size(0), -1, 1, 1))
 
     def forward(self, x):
-        x = self.encoder(x)
-        x = self.decoder(x)
+        x = self.encode(x)
+        x = self.decode(x)
         return x
 
 
@@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std)
 
 ######################################################################
 
-for epoch in range(args.nb_epochs):
-    acc_loss = 0
+for n_epoch in range(args.nb_epochs):
+    acc_train_loss = 0
 
     for input in train_input.split(args.batch_size):
         output = model(input)
-        loss = 0.5 * (output - input).pow(2).sum() / input.size(0)
+        train_loss = F.mse_loss(output, input)
 
         optimizer.zero_grad()
-        loss.backward()
+        train_loss.backward()
         optimizer.step()
 
-        acc_loss += loss.item()
+        acc_train_loss += train_loss.detach().item() * input.size(0)
 
-    log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss))
+    train_loss = acc_train_loss / train_input.size(0)
+
+    log_string(f"train_loss {n_epoch} {train_loss}")
 
 ######################################################################
 
diff --git a/tinygen.py b/tinygen.py
new file mode 100755 (executable)
index 0000000..66c005c
--- /dev/null
+++ b/tinygen.py
@@ -0,0 +1,520 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_GET: *.png
+
+import sys, argparse, time, math
+
+import torch, torchvision
+
+from torch import optim, nn
+from torch.nn import functional as F
+from tqdm import tqdm
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+
+parser.add_argument("--nb_epochs", type=int, default=250)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--data_dir", type=str, default="./data/")
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--embedding_dim", type=int, default=64)
+
+parser.add_argument("--nb_channels", type=int, default=64)
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, "w")
+
+######################################################################
+
+
+def log_string(s):
+    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    def forward(self, x):
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2  # works with float, weird
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+        y = x + pe
+        return y
+
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
+
+######################################################################
+
+
+def vanilla_attention(q, k, v):
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+    a = a.softmax(dim=3)
+    y = torch.einsum("nhts,nhsd->nhtd", a, v)
+    return y
+
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention = attention
+        self.attention_dropout = attention_dropout
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
+
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+        y = self.attention(q, k, v)
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return y
+
+
+######################################################################
+
+
+class AttentionAE(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            nn.Linear(2, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=vanilla_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=1)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1)
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = self.readout(x).reshape(-1, 1, 28, 28)
+        return x
+
+
+######################################################################
+
+
+class WithMaskedResidual(nn.Module):
+    def __init__(self, masker, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.masker = masker
+        self.mask = None
+
+    def forward(self, x):
+        if self.mask is None:
+            self.mask = self.masker(x)
+        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        nb_work_tokens=100,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.nb_work_tokens = nb_work_tokens
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        def no_peek_attention(q, k, v):
+            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+            n = self.nb_work_tokens
+            s = (q.size(2) - n) // 2
+            a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a = a.softmax(dim=3)
+            y = torch.einsum("nhts,nhsd->nhtd", a, v)
+            return y
+
+        def masker(x):
+            m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+            return m[None, :, None]
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=no_peek_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+        x = self.readout(x)
+        return x
+
+
+######################################################################
+
+
+class FullAveragePooling(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1)
+        return x
+
+
+class ResNetBlock(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
+
+        self.conv2 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
+
+    def forward(self, x):
+        y = F.relu(self.conv1(x))
+        y = F.relu(x + self.conv2(y))
+        return y
+
+
+######################################################################
+
+
+class ResAutoEncoder(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.encoder = nn.Conv2d(
+            2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+        self.core = nn.Sequential(
+            *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)]
+        )
+        self.decoder = nn.Conv2d(
+            nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+
+######################################################################
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, nb_channels, embedding_dim):
+        super().__init__()
+
+        self.encoder = nn.Sequential(
+            nn.Conv2d(1, nb_channels, kernel_size=5),  # to 24x24
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),
+        )
+
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=3, stride=2
+            ),  # from 4x4
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=4, stride=2
+            ),  # from 9x9
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, 1, kernel_size=5),  # from 24x24
+        )
+
+    def encode(self, x):
+        return self.encoder(x).view(x.size(0), -1)
+
+    def decode(self, z):
+        return self.decoder(z.view(z.size(0), -1, 1, 1))
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=True, download=True
+)
+train_input = train_set.data.view(-1, 1, 28, 28).float()
+
+test_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=False, download=True
+)
+test_input = test_set.data.view(-1, 1, 28, 28).float()
+
+######################################################################
+
+model = AutoEncoder(args.nb_channels, args.embedding_dim)
+
+# model = AttentionAE(
+# dim_model=16,
+# dim_keys=16,
+# dim_hidden=16,
+# nb_heads=4,
+# nb_blocks=4,
+# dropout=0.0,
+# len_max=1e5,
+# )
+
+# model = ResAutoEncoder(nb_channels=128, kernel_size=9)
+
+print(model)
+
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+model.to(device)
+
+train_input, test_input = train_input.to(device), test_input.to(device)
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+nb_iterations = 10
+
+######################################################################
+
+
+def dist(u, v):
+    return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
+
+
+def pb(e, desc):
+    return tqdm(
+        e,
+        dynamic_ncols=True,
+        desc=desc,
+        total=train_input.size(0) // args.batch_size,
+        delay=10,
+    )
+
+
+for n_epoch in range(args.nb_epochs):
+    model.train()  # undo the model.eval() from the previous epoch's sampling
+    acc_loss = 0
+
+    for targets in pb(train_input.split(args.batch_size), "train"):
+        input = torch.randn(targets.size(), device=targets.device)
+
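+        # iterative generation: each step must shrink the distance to the
+        # targets along the linear schedule tolerated_d, and the loss
+        # penalizes steps that fall short (a < 1)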
+        loss = 0
+        for n in range(nb_iterations):
+            output = model(input)
+            current_d = dist(targets, output)
+            nb_remain = nb_iterations - n
+            tolerated_d = dist(targets, input) * (nb_remain - 1) / nb_remain
+            a = (tolerated_d / (current_d + 1e-6)).clamp(max=1)
+            loss += (1 - a).mean() / nb_iterations
+            input = targets - a * (targets - output.detach())
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        acc_loss += loss.item()
+
+    log_string(f"acc_loss {n_epoch} {acc_loss}")
+
+    ######################################################################
+
+    input = test_input[:256]
+    model.eval()
+
+    input = torch.randn(input.size(), device=input.device)
+    for _ in range(nb_iterations):
+        output = model(input)
+        input = output.detach()
+
+    output = (output * std + mu) / 255
+
+    torchvision.utils.save_image(
+        1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8
+    )
+
+
+######################################################################
diff --git a/tinymnist.py b/tinymnist.py
index 896477e..19b9387 100755 (executable)
--- a/tinymnist.py
+++ b/tinymnist.py
@@ -1,5 +1,8 @@
 #!/usr/bin/env python
 
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_PRE: source ~/venv/pytorch/bin/activate
+
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
@@ -14,7 +17,12 @@ lr, nb_epochs, batch_size = 1e-1, 10, 100
 
 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/"
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
 
 ######################################################################
 
@@ -52,7 +60,7 @@ model = SomeLeNet()
 
 nb_parameters = sum(p.numel() for p in model.parameters())
 
-print(f"nb_parameters {nb_parameters}")
+print(f"device {device} nb_parameters {nb_parameters}")
 
 optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 criterion = nn.CrossEntropyLoss()
@@ -70,28 +78,36 @@ test_input.sub_(mu).div_(std)
 start_time = time.perf_counter()
 
 for k in range(nb_epochs):
-    acc_loss = 0.0
+    acc_train_loss = 0.0
 
     for input, targets in zip(
         train_input.split(batch_size), train_targets.split(batch_size)
     ):
         output = model(input)
         loss = criterion(output, targets)
-        acc_loss += loss.item()
+        acc_train_loss += loss.item() * input.size(0)
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
+    acc_test_loss = 0.0
     nb_test_errors = 0
     for input, targets in zip(
         test_input.split(batch_size), test_targets.split(batch_size)
     ):
-        wta = model(input).argmax(1)
+        output = model(input)
+        loss = criterion(output, targets)
+        acc_test_loss += loss.item() * input.size(0)
+
+        wta = output.argmax(1)
         nb_test_errors += (wta != targets).long().sum()
+
     test_error = nb_test_errors / test_input.size(0)
     duration = time.perf_counter() - start_time
 
-    print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
+    print(
+        f"loss {k} {duration:.02f}s acc_train_loss {acc_train_loss/train_input.size(0):.02f} test_loss {acc_test_loss/test_input.size(0):.02f} test_error {test_error*100:.02f}%"
+    )
 
 ######################################################################