--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_GET: *.png
+
+import sys, argparse, time, math
+
+import torch, torchvision
+
+from torch import optim, nn
+from torch.nn import functional as F
+from tqdm import tqdm
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(description="Tiny auto-encoder trained by iterative denoising of MNIST.")
+
+parser.add_argument("--nb_epochs", type=int, default=250)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--nb_train_samples", type=int, default=-1)
+
+parser.add_argument("--nb_test_samples", type=int, default=-1)
+
+parser.add_argument("--data_dir", type=str, default="./data/")
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--embedding_dim", type=int, default=64)
+
+parser.add_argument("--nb_channels", type=int, default=64)
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, "w")
+
+######################################################################
+
+
+def log_string(s):
+ t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
+
+ if log_file is not None:
+ log_file.write(t + s + "\n")
+ log_file.flush()
+
+ print(t + s)
+ sys.stdout.flush()
+
+
+######################################################################
+
+
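+# Sinusoidal positional encoding of Vaswani et al. (2017): even feature
+# indices j get sin(t / len_max^(j/D)), odd ones the matching cosine,
+# obtained by shifting the sine's phase by pi/2.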
+class VaswaniPositionalEncoding(nn.Module):
+ def __init__(self, len_max):
+ super().__init__()
+ self.len_max = len_max
+
+ def forward(self, x):
+ t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+ j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2  # 1 for odd feature indices, 0 for even (works on float tensors)
+ pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+ y = x + pe
+ return y
+
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+ def __init__(self, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+ def forward(self, x):
+ return x + self.f(x)
+
+
+######################################################################
+
+
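+# Standard scaled dot-product attention. Einsum index convention:
+# n = batch, h = head, t = query position, s = key position, d = feature.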
+def vanilla_attention(q, k, v):
+ a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+ a = a.softmax(dim=3)
+ y = torch.einsum("nhts,nhsd->nhtd", a, v)
+ return y
+
+
+######################################################################
+
+
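+# Multi-head attention with explicit per-head weight tensors instead of
+# nn.Linear modules. Note that attention_dropout is recorded but not
+# applied by vanilla_attention.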
+class MHAttention(nn.Module):
+ def __init__(
+ self,
+ dim_model,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ attention=vanilla_attention,
+ attention_dropout=0.0,
+ ):
+ super().__init__()
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.attention = attention
+ self.attention_dropout = attention_dropout
+ self.w_q = randw(nb_heads, dim_qk, dim_model)
+ self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(nb_heads, dim_v, dim_model)
+
+ def forward(self, x_q, x_kv=None):
+ if x_kv is None:
+ x_kv = x_q
+
+ q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+ k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+ v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+ y = self.attention(q, k, v)
+ y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+ return y
+
+
+######################################################################
+
+
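+# Pre-norm Transformer auto-encoder that flattens a 2-channel 28x28 image
+# into a sequence of 784 two-dimensional tokens. Note that the training
+# loop below feeds 256-channel one-hot images, so this variant would need
+# a matching embedding.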
+class AttentionAE(nn.Module):
+ def __init__(
+ self,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = nn.Sequential(
+ nn.Linear(2, dim_model),
+ nn.Dropout(dropout),
+ )
+
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+ trunk_blocks = []
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithResidual(
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithResidual(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = nn.Linear(in_features=dim_model, out_features=1)
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, x):
+ x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1)
+ x = self.embedding(x)
+ x = self.positional_encoding(x)
+ x = self.trunk(x)
+ x = self.readout(x).reshape(-1, 1, 28, 28)
+ return x
+
+
+######################################################################
+
+
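+# Residual wrapper that zeroes the skip connection at the positions
+# rejected by masker; the mask is cached and recomputed only when the
+# sequence length changes.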
+class WithMaskedResidual(nn.Module):
+ def __init__(self, masker, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+ self.masker = masker
+ self.mask = None
+
+    def forward(self, x):
+        if self.mask is None or self.mask.size(1) != x.size(1):
+            self.mask = self.masker(x)
+        return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
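+# Transformer auto-encoder that prepends nb_work_tokens "work" tokens and
+# prevents the two halves of the remaining sequence from attending to each
+# other (see no_peek_attention below).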
+class FunctionalAttentionAE(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ nb_work_tokens=100,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.nb_work_tokens = nb_work_tokens
+
+ self.embedding = nn.Sequential(
+ nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Dropout(dropout),
+ )
+
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+ trunk_blocks = []
+
+ def no_peek_attention(q, k, v):
+ a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+ n = self.nb_work_tokens
+ s = (q.size(2) - n) // 2
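+            # mask out attention between the two halves of length s that
+            # follow the n work tokens, in both directions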
+ a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf")
+ a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf")
+ a = a.softmax(dim=3)
+ y = torch.einsum("nhts,nhsd->nhtd", a, v)
+ return y
+
+ def masker(x):
+ m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+ return m[None, :, None]
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithMaskedResidual(
+ masker,
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=no_peek_attention,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithMaskedResidual(
+ masker,
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, x):
+ x = self.embedding(x)
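+        # prepend nb_work_tokens all-zero work tokens; the negative
+        # padding after the trunk removes them again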
+ x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+ x = self.positional_encoding(x)
+ x = self.trunk(x)
+ x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+ x = self.readout(x)
+ return x
+
+
+######################################################################
+
+
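+# Global average pooling down to a 1x1 map, keeping the channel dimension
+# (not used by the models below).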
+class FullAveragePooling(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1)
+ return x
+
+
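+# Two-convolution residual block at constant resolution, with a ReLU
+# after each convolution.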
+class ResNetBlock(nn.Module):
+ def __init__(self, nb_channels, kernel_size):
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ nb_channels,
+ nb_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ )
+
+ self.conv2 = nn.Conv2d(
+ nb_channels,
+ nb_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ )
+
+ def forward(self, x):
+ y = F.relu(self.conv1(x))
+ y = F.relu(x + self.conv2(y))
+ return y
+
+
+######################################################################
+
+
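+# Fully convolutional auto-encoder whose core is a stack of 20 residual
+# blocks; like AttentionAE, it expects 2-channel input images.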
+class ResAutoEncoder(nn.Module):
+ def __init__(self, nb_channels, kernel_size):
+ super().__init__()
+
+ self.encoder = nn.Conv2d(
+ 2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2
+ )
+ self.core = nn.Sequential(
+ *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)]
+ )
+ self.decoder = nn.Conv2d(
+ nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2
+ )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.core(x)
+        x = self.decoder(x)
+        return x
+
+
+######################################################################
+
+
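+# Convolutional auto-encoder. The 256 input and output channels are a
+# one-hot encoding of the 8-bit pixel intensities; the inline comments
+# track the spatial resolution from 28x28 down to 1x1 and back.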
+class AutoEncoder(nn.Module):
+ def __init__(self, nb_channels, embedding_dim):
+ super().__init__()
+
+ self.encoder = nn.Sequential(
+ nn.Conv2d(256, nb_channels, kernel_size=5), # to 24x24
+ nn.ReLU(inplace=True),
+ nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 20x20
+ nn.ReLU(inplace=True),
+ nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2), # to 9x9
+ nn.ReLU(inplace=True),
+ nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2), # to 4x4
+ nn.ReLU(inplace=True),
+ nn.Conv2d(nb_channels, embedding_dim, kernel_size=4),
+ )
+
+ self.decoder = nn.Sequential(
+ nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(
+ nb_channels, nb_channels, kernel_size=3, stride=2
+ ), # from 4x4
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(
+ nb_channels, nb_channels, kernel_size=4, stride=2
+ ), # from 9x9
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5), # from 20x20
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(nb_channels, 256, kernel_size=5), # from 24x24
+ )
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x
+
+
+######################################################################
+
+train_set = torchvision.datasets.MNIST(
+ args.data_dir + "/mnist/", train=True, download=True
+)
+train_input = train_set.data.view(-1, 28, 28).long()
+
+test_set = torchvision.datasets.MNIST(
+ args.data_dir + "/mnist/", train=False, download=True
+)
+test_input = test_set.data.view(-1, 28, 28).long()
+
+if args.nb_train_samples > 0:
+ train_input = train_input[: args.nb_train_samples]
+
+if args.nb_test_samples > 0:
+ test_input = test_input[: args.nb_test_samples]
+
+######################################################################
+
+model = AutoEncoder(args.nb_channels, args.embedding_dim)
+
+# model = AttentionAE(
+# dim_model=16,
+# dim_keys=16,
+# dim_hidden=16,
+# nb_heads=4,
+# nb_blocks=4,
+# dropout=0.0,
+# len_max=1e5,
+# )
+
+# model = ResAutoEncoder(nb_channels=128, kernel_size=9)
+
+print(model)
+
+model.to(device)
+
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+train_input, test_input = train_input.to(device), test_input.to(device)
+
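+# number of refinement iterations, used both for training and sampling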
+nb_iterations = 10
+
+######################################################################
+
+
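+# Per-sample Euclidean distance, shaped (N, 1, 1, 1) for broadcasting
+# (not used below).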
+def dist(u, v):
+ return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
+
+
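+# Per-sample KL divergence KL(v || u), each sample's whole tensor being
+# treated as the logits of a single flattened categorical distribution;
+# the result is shaped (N, 1, 1, 1) for broadcasting.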
+def kl(u, v):
+ s = (u.size(0),) + (1,) * (u.dim() - 1)
+ u = u.flatten(1).log_softmax(dim=1)
+ v = v.flatten(1).log_softmax(dim=1)
+ return (
+ F.kl_div(u, v, log_target=True, reduction="none")
+ .sum(dim=1, keepdim=True)
+ .clamp(min=0)
+ .reshape(*s)
+ )
+
+
+def pb(e, desc):
+ return tqdm(
+ e,
+ dynamic_ncols=True,
+ desc=desc,
+ total=train_input.size(0) // args.batch_size,
+ delay=10,
+ )
+
+
+for n_epoch in range(args.nb_epochs):
+    model.train()  # undo the model.eval() of the previous epoch's sampling
+    acc_loss = 0
+
+ for targets in pb(train_input.split(args.batch_size), "train"):
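+        # turn the 8-bit pixel values into one-hot logits: 10 on the true
+        # intensity, 1 everywhere else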
+ targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
+ input = torch.randn(targets.size(), device=targets.device) * 1e-3
+
+ loss = 0
+ for n in range(nb_iterations):
+ output = model(input).clamp(min=-10, max=10)
+ # assert not output.isnan().any()
+ current_kl = kl(output, targets)
+
+ # assert not current_kl.isnan().any()
+
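+            # the input is allowed to keep a fraction (nb_remain - 1) /
+            # nb_remain of its current KL to the targets, so the tolerated
+            # divergence decays to zero over the remaining iterations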
+ nb_remain = nb_iterations - n
+ tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain
+
+ # assert not tolerated_kl.isnan().any()
+
+ with torch.no_grad():
+ a = input.new_full((input.size(0), 1, 1, 1), 0.0)
+ b = input.new_full((input.size(0), 1, 1, 1), 1.0)
+
+                # bisection over the mixing coefficient c in [a, b]:
+                # at c = a = 0 the mixture is the targets (KL = 0), at
+                # c = b = 1 it is the raw output (KL >> 0); converge to
+                # the c whose KL matches tolerated_kl
+
+ for _ in range(10):
+ c = (a + b) / 2
+ kl_c = kl(c * output + (1 - c) * targets, targets)
+ # assert kl_c.min() >= 0
+ # assert not kl_c.isnan().any()
+ m = (kl_c >= tolerated_kl).long()
+ a = m * a + (1 - m) * c
+ b = m * c + (1 - m) * b
+
+ c = (a + b) / 2
+
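+            # move the input toward the output as far as the tolerated KL
+            # allows, and penalize the divergence between the prediction
+            # and this refreshed, less degraded, input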
+ input = c * output + (1 - c) * targets
+ loss += kl(output, input).mean()
+ # assert not loss.isnan()
+ input = input.detach()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ acc_loss += loss.item()
+
+ log_string(f"acc_loss {n_epoch} {acc_loss}")
+
+ ######################################################################
+
+    model.eval()
+
+    # sample: start from pure Gaussian noise shaped like a batch of
+    # one-hot encoded images and iterate the model
+    input = torch.randn(
+        (min(256, test_input.size(0)), 256, 28, 28), device=device
+    )
+ for _ in range(nb_iterations):
+ output = model(input)
+ input = output.detach()
+
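+    # convert the logits to the expected pixel intensity under the
+    # predicted distribution, rescaled to [0, 1]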
+ output = (
+ output.softmax(dim=1)
+ * torch.arange(output.size(1), device=output.device)[None, :, None, None]
+ ).sum(dim=1)
+ output = output[:, None, :, :] / 255
+
+ torchvision.utils.save_image(
+ 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8
+ )
+
+
+######################################################################