Update.
author     François Fleuret <francois@fleuret.org>
Tue, 1 Oct 2024 05:12:59 +0000 (07:12 +0200)
committer  François Fleuret <francois@fleuret.org>
Tue, 1 Oct 2024 05:12:59 +0000 (07:12 +0200)
tinygen.py [new file with mode: 0755]

diff --git a/tinygen.py b/tinygen.py
new file mode 100755 (executable)
index 0000000..66c005c
--- /dev/null
@@ -0,0 +1,520 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_GET: *.png
+
+import sys, argparse, time, math
+
+import torch, torchvision
+
+from torch import optim, nn
+from torch.nn import functional as F
+from tqdm import tqdm
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+
+parser.add_argument("--nb_epochs", type=int, default=250)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--data_dir", type=str, default="./data/")
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--embedding_dim", type=int, default=64)
+
+parser.add_argument("--nb_channels", type=int, default=64)
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, "w")
+
+######################################################################
+
+
+def log_string(s):
+    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    def forward(self, x):
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2  # parity of the feature index (modulo works on float tensors)
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+        y = x + pe
+        return y
+
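+# For an even feature index j (k = 0) the encoding above is sin(t / len_max^(j/d));
+# for an odd j (k = 1) the pi/2 phase shift turns it into cos(t / len_max^((j-1)/d)).
+# This is the Vaswani et al. sinusoidal encoding, with base len_max in place of 10000.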
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
+
+######################################################################
+
+
+def vanilla_attention(q, k, v):
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+    a = a.softmax(dim=3)
+    y = torch.einsum("nhts,nhsd->nhtd", a, v)
+    return y
+
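+# Shape sketch (hypothetical sizes): with q of shape (N, H, T, D) and k, v of
+# shape (N, H, S, D), e.g. (2, 4, 10, 8) and (2, 4, 12, 8), the attention
+# matrix a is (N, H, T, S) and the output y is (N, H, T, D).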
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention = attention
+        self.attention_dropout = attention_dropout  # stored but not applied by vanilla_attention
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
+
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+        y = self.attention(q, k, v)
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return y
+
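+# Usage sketch (hypothetical sizes):
+#
+#   mha = MHAttention(dim_model=64, dim_qk=16, dim_v=16, nb_heads=4)
+#   y = mha(torch.randn(2, 50, 64))  # -> (2, 50, 64)
+#
+# The per-head projections are stored as (nb_heads, dim, dim_model) tensors
+# and applied with einsum, so no explicit reshaping into heads is needed.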
+
+######################################################################
+
+
+class AttentionAE(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            nn.Linear(2, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=vanilla_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=1)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1)
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = self.readout(x).reshape(-1, 1, 28, 28)
+        return x
+
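+# The 2-channel 28x28 input is flattened into a 784-token sequence, one
+# 2-dimensional token per pixel, and the readout produces one scalar per
+# pixel, reshaped back into a 1-channel image.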
+
+######################################################################
+
+
+class WithMaskedResidual(nn.Module):
+    def __init__(self, masker, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+        self.masker = masker
+        self.mask = None
+
+    def forward(self, x):
+        # The mask is computed once and cached, which assumes the sequence
+        # length does not change across calls
+        if self.mask is None:
+            self.mask = self.masker(x)
+        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        nb_work_tokens=100,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.nb_work_tokens = nb_work_tokens
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        def no_peek_attention(q, k, v):
+            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+            n = self.nb_work_tokens
+            s = (q.size(2) - n) // 2
+            a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a = a.softmax(dim=3)
+            y = torch.einsum("nhts,nhsd->nhtd", a, v)
+            return y
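+
+        # In no_peek_attention, the two s-token halves that follow the work
+        # tokens cannot attend to each other: information flows between them
+        # only through the work tokens, which themselves attend to everything.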
+
+        def masker(x):
+            m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+            return m[None, :, None]
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_model=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention=no_peek_attention,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithMaskedResidual(
+                    masker,
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+        x = self.readout(x)
+        return x
+
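+# F.pad with pad=(0, 0, n, 0) on a (N, T, C) tensor prepends n zero "work
+# token" slots along the sequence dimension; the negative padding after the
+# trunk trims them off again.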
+
+######################################################################
+
+
+class FullAveragePooling(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1)
+        return x
+
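+# (Global average pooling that keeps a (N, C, 1, 1) shape; a utility, not
+# used below.)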
+
+class ResNetBlock(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
+
+        self.conv2 = nn.Conv2d(
+            nb_channels,
+            nb_channels,
+            kernel_size=kernel_size,
+            padding=(kernel_size - 1) // 2,
+        )
+
+    def forward(self, x):
+        y = F.relu(self.conv1(x))
+        y = F.relu(x + self.conv2(y))
+        return y
+
+
+######################################################################
+
+
+class ResAutoEncoder(nn.Module):
+    def __init__(self, nb_channels, kernel_size):
+        super().__init__()
+
+        self.encoder = nn.Conv2d(
+            2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+        self.core = nn.Sequential(
+            *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)]
+        )
+        self.decoder = nn.Conv2d(
+            nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.core(x)  # residual trunk
+        x = self.decoder(x)
+        return x
+
+
+######################################################################
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, nb_channels, embedding_dim):
+        super().__init__()
+
+        self.encoder = nn.Sequential(
+            nn.Conv2d(1, nb_channels, kernel_size=5),  # to 24x24
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),
+        )
+
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=3, stride=2
+            ),  # from 4x4
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=4, stride=2
+            ),  # from 9x9
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, 1, kernel_size=5),  # from 24x24
+        )
+
+    def encode(self, x):
+        return self.encoder(x).view(x.size(0), -1)
+
+    def decode(self, z):
+        return self.decoder(z.view(z.size(0), -1, 1, 1))
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
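+# Shape check: a (N, 1, 28, 28) input goes through spatial sizes
+# 28 -> 24 -> 20 -> 9 -> 4 -> 1 in the encoder, and the transposed
+# convolutions retrace them back to 28; encode() returns (N, embedding_dim).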
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=True, download=True
+)
+train_input = train_set.data.view(-1, 1, 28, 28).float()
+
+test_set = torchvision.datasets.MNIST(
+    args.data_dir + "/mnist/", train=False, download=True
+)
+test_input = test_set.data.view(-1, 1, 28, 28).float()
+
+######################################################################
+
+model = AutoEncoder(args.nb_channels, args.embedding_dim)
+
+# model = AttentionAE(
+#     dim_model=16,
+#     dim_keys=16,
+#     dim_hidden=16,
+#     nb_heads=4,
+#     nb_blocks=4,
+#     dropout=0.0,
+#     len_max=1e5,
+# )
+
+# model = ResAutoEncoder(nb_channels=128, kernel_size=9)
+
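+# Nota bene: AttentionAE and ResAutoEncoder as defined above expect 2-channel
+# 28x28 inputs, while the training loop below feeds 1-channel tensors, so
+# switching to them requires adapting the input.
+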
+print(model)
+
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+model.to(device)
+
+train_input, test_input = train_input.to(device), test_input.to(device)
+
+mu, std = train_input.mean(), train_input.std()
+train_input.sub_(mu).div_(std)
+test_input.sub_(mu).div_(std)
+
+nb_iterations = 10
+
+######################################################################
+
+
+def dist(u, v):
+    # Per-sample Euclidean distance, kept with shape (N, 1, 1, 1) so that it
+    # broadcasts against image batches
+    return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
+
+
+def pb(e, desc):
+    return tqdm(
+        e,
+        dynamic_ncols=True,
+        desc=desc,
+        total=train_input.size(0) // args.batch_size,
+        delay=10,
+    )
+
+
+for n_epoch in range(args.nb_epochs):
+    acc_loss = 0
+
+    model.train()  # the sampling pass below leaves the model in eval mode
+
+    for targets in pb(train_input.split(args.batch_size), "train"):
+        input = torch.randn(targets.size(), device=targets.device)
+
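+        # Starting from pure noise, the model must bring the output within a
+        # distance of the target that shrinks by a factor (nb_remain - 1) /
+        # nb_remain at each step, reaching zero at the last one. a measures
+        # the achieved contraction (clamped at 1), the loss penalizes the
+        # shortfall, and the next input is the prediction, pulled onto the
+        # tolerated radius around the target whenever the model falls short.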
+        loss = 0
+        for n in range(nb_iterations):
+            output = model(input)
+            current_d = dist(targets, output)
+            nb_remain = nb_iterations - n
+            tolerated_d = dist(targets, input) * (nb_remain - 1) / nb_remain
+            a = (tolerated_d / (current_d + 1e-6)).clamp(max=1)
+            loss += (1 - a).mean() / nb_iterations
+            input = targets - a * (targets - output.detach())
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        acc_loss += loss.item()
+
+    log_string(f"acc_loss {n_epoch} {acc_loss}")
+
+    ######################################################################
+
+    model.eval()
+
+    # Sampling: start from pure noise (test_input is used only for its shape
+    # and device) and apply the model nb_iterations times
+    input = torch.randn(test_input[:256].size(), device=test_input.device)
+    with torch.no_grad():
+        for _ in range(nb_iterations):
+            output = model(input)
+            input = output
+
+    output = (output * std + mu) / 255  # undo the normalization, then map 0-255 to [0, 1]
+
+    torchvision.utils.save_image(
+        1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8
+    )
+
+
+######################################################################