From 1399a4f5270eba99e8557125408826f26910b539 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 1 Oct 2024 21:38:51 +0200 Subject: [PATCH] Update. --- difdis.py | 575 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 575 insertions(+) create mode 100755 difdis.py diff --git a/difdis.py b/difdis.py new file mode 100755 index 0000000..5ec408f --- /dev/null +++ b/difdis.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_EXEC: python +# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate +# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data +# @XREMOTE_GET: *.png + +import sys, argparse, time, math + +import torch, torchvision + +from torch import optim, nn +from torch.nn import functional as F +from tqdm import tqdm + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### + +parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") + +parser.add_argument("--nb_epochs", type=int, default=250) + +parser.add_argument("--batch_size", type=int, default=100) + +parser.add_argument("--nb_train_samples", type=int, default=-1) + +parser.add_argument("--nb_test_samples", type=int, default=-1) + +parser.add_argument("--data_dir", type=str, default="./data/") + +parser.add_argument("--log_filename", type=str, default="train.log") + +parser.add_argument("--embedding_dim", type=int, default=64) + +parser.add_argument("--nb_channels", type=int, default=64) + +args = parser.parse_args() + +log_file = open(args.log_filename, "w") + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +###################################################################### + + +class VaswaniPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + def forward(self, x): + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 # works with float, weird + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) + y = x + pe + return y + + +###################################################################### + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return x + self.f(x) + + +###################################################################### + + +def vanilla_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + +###################################################################### + + +class MHAttention(nn.Module): + def __init__( + self, + dim_model, + dim_qk, + dim_v, + nb_heads=1, + attention=vanilla_attention, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.attention = attention + 
self.attention_dropout = attention_dropout + self.w_q = randw(nb_heads, dim_qk, dim_model) + self.w_k = randw(nb_heads, dim_qk, dim_model) + self.w_v = randw(nb_heads, dim_v, dim_model) + self.w_o = randw(nb_heads, dim_v, dim_model) + + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q + + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k) + v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v) + y = self.attention(q, k, v) + y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) + + return y + + +###################################################################### + + +class AttentionAE(nn.Module): + def __init__( + self, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + nn.Linear(2, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=dropout, + ), + ), + WithResidual( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=1) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1) + x = self.embedding(x) + x = self.positional_encoding(x) + x = self.trunk(x) + x = self.readout(x).reshape(-1, 1, 28, 28) + return x + + +###################################################################### + + +class WithMaskedResidual(nn.Module): + def __init__(self, masker, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.masker = masker + self.mask = None + + def forward(self, x): + if self.mask is None: + self.mask = self.masker(x) + return self.mask * x + self.f(x) + + +###################################################################### + + +class FunctionalAttentionAE(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + nb_work_tokens=100, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.nb_work_tokens = nb_work_tokens + + self.embedding = nn.Sequential( + nn.Embedding(2 * vocabulary_size, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + def no_peek_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + n = self.nb_work_tokens + s = (q.size(2) - n) // 2 + a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf") + a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf") + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + def masker(x): + m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens + return m[None, :, None] + + for b in range(nb_blocks): + trunk_blocks += [ + 
WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=no_peek_attention, + attention_dropout=dropout, + ), + ), + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = self.embedding(x) + x = F.pad(x, (0, 0, self.nb_work_tokens, 0)) + x = self.positional_encoding(x) + x = self.trunk(x) + x = F.pad(x, (0, 0, -self.nb_work_tokens, 0)) + x = self.readout(x) + return x + + +###################################################################### + + +class FullAveragePooling(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1) + return x + + +class ResNetBlock(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.conv1 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + self.conv2 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + def forward(self, x): + y = F.relu(self.conv1(x)) + y = F.relu(x + self.conv2(y)) + return y + + +###################################################################### + + +class ResAutoEncoder(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.encoder = nn.Conv2d( + 2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + self.core = nn.Sequential( + *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)] + ) + self.decoder = nn.Conv2d( + nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + + +class AutoEncoder(nn.Module): + def __init__(self, nb_channels, embedding_dim): + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(256, nb_channels, kernel_size=5), # to 24x24 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 20x20 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2), # to 9x9 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2), # to 4x4 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, embedding_dim, kernel_size=4), + ) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=3, stride=2 + ), # from 4x4 + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=4, stride=2 + ), # from 9x9 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5), # from 20x20 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, 256, kernel_size=5), # from 24x24 + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) 
+ return x + + +###################################################################### + +train_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=True, download=True +) +train_input = train_set.data.view(-1, 28, 28).long() + +test_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=False, download=True +) +test_input = test_set.data.view(-1, 28, 28).long() + +if args.nb_train_samples > 0: + train_input = train_input[: args.nb_train_samples] + +if args.nb_test_samples > 0: + test_input = test_input[: args.nb_test_samples] + +###################################################################### + +model = AutoEncoder(args.nb_channels, args.embedding_dim) + +# model = AttentionAE( +# dim_model=16, +# dim_keys=16, +# dim_hidden=16, +# nb_heads=4, +# nb_blocks=4, +# dropout=0.0, +# len_max=1e5, +# ) + +# model = ResAutoEncoder(nb_channels=128, kernel_size=9) + +print(model) + +optimizer = optim.Adam(model.parameters(), lr=1e-3) + +model.to(device) + +train_input, test_input = train_input.to(device), test_input.to(device) + +nb_iterations = 10 + +###################################################################### + + +def dist(u, v): + return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + + +def kl(u, v): + s = (u.size(0),) + (1,) * (u.dim() - 1) + u = u.flatten(1).log_softmax(dim=1) + v = v.flatten(1).log_softmax(dim=1) + return ( + F.kl_div(u, v, log_target=True, reduction="none") + .sum(dim=1, keepdim=True) + .clamp(min=0) + .reshape(*s) + ) + + +def pb(e, desc): + return tqdm( + e, + dynamic_ncols=True, + desc=desc, + total=train_input.size(0) // args.batch_size, + delay=10, + ) + + +for n_epoch in range(args.nb_epochs): + acc_loss = 0 + + for targets in pb(train_input.split(args.batch_size), "train"): + targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0 + input = torch.randn(targets.size(), device=targets.device) * 1e-3 + + # print(f"-----------------") + + loss = 0 + for n in range(nb_iterations): + output = model(input).clamp(min=-10, max=10) + # assert not output.isnan().any() + current_kl = kl(output, targets) + + # assert not current_kl.isnan().any() + + nb_remain = nb_iterations - n + tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain + + # assert not tolerated_kl.isnan().any() + + with torch.no_grad(): + a = input.new_full((input.size(0), 1, 1, 1), 0.0) + b = input.new_full((input.size(0), 1, 1, 1), 1.0) + + # KL(a * output + (1-a) * targets) = 0 + # KL(b * output + (1-b) * targets) >> 0 + + for _ in range(10): + c = (a + b) / 2 + kl_c = kl(c * output + (1 - c) * targets, targets) + # print(f"{c.size()=} {output.size()=} {targets.size()=} {kl_c.size()=}") + # print(f"{kl_c}") + # assert kl_c.min() >= 0 + # assert not kl_c.isnan().any() + m = (kl_c >= tolerated_kl).long() + a = m * a + (1 - m) * c + b = m * c + (1 - m) * b + # print(f"{a.size()=} {b.size()=} {m.size()=}") + + c = (a + b) / 2 + + # print() + # print(tolerated_kl.flatten()) + # print(kl_c.flatten()) + # print(c.flatten()) + + input = c * output + (1 - c) * targets + loss += kl(output, input).mean() + # assert not loss.isnan() + input = input.detach() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_loss += loss.item() + + log_string(f"acc_loss {n_epoch} {acc_loss}") + + ###################################################################### + + targets = test_input[:256] + targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0 + input = torch.randn(targets.size(), device=targets.device) * 
1e-3 + model.eval() + + input = torch.randn(input.size(), device=input.device) + for _ in range(nb_iterations): + output = model(input) + input = output.detach() + + output = ( + output.softmax(dim=1) + * torch.arange(output.size(1), device=output.device)[None, :, None, None] + ).sum(dim=1) + output = output[:, None, :, :] / 255 + + torchvision.utils.save_image( + 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8 + ) + + +###################################################################### -- 2.39.5
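A minimal standalone sketch (not part of the patch) of the per-sample bisection the
training loop uses to pick the interpolation coefficient c: it reuses the kl()
helper defined in difdis.py, and the function name bisect_mixing_coefficient is
made up for the illustration.

import torch
from torch.nn import functional as F

def kl(u, v):
    # Same helper as in difdis.py: per-sample divergence between the flattened,
    # softmax-normalized versions of u and v, returned with shape (N, 1, 1, 1).
    s = (u.size(0),) + (1,) * (u.dim() - 1)
    u = u.flatten(1).log_softmax(dim=1)
    v = v.flatten(1).log_softmax(dim=1)
    return (
        F.kl_div(u, v, log_target=True, reduction="none")
        .sum(dim=1, keepdim=True)
        .clamp(min=0)
        .reshape(*s)
    )

def bisect_mixing_coefficient(output, targets, tolerated_kl, nb_bisection_steps=10):
    # Find, per sample, approximately the largest c in [0, 1] such that
    # kl(c * output + (1 - c) * targets, targets) stays below tolerated_kl.
    # c = 0 gives the targets themselves (divergence 0), c = 1 the raw model output.
    a = output.new_full((output.size(0), 1, 1, 1), 0.0)
    b = output.new_full((output.size(0), 1, 1, 1), 1.0)
    for _ in range(nb_bisection_steps):
        c = (a + b) / 2
        m = (kl(c * output + (1 - c) * targets, targets) >= tolerated_kl).long()
        # Divergence too large: lower the upper bound; otherwise raise the lower bound.
        a = m * a + (1 - m) * c
        b = m * c + (1 - m) * b
    return (a + b) / 2

# In the training loop above this corresponds to
#   tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain
#   c = bisect_mixing_coefficient(output, targets, tolerated_kl)
#   input = (c * output + (1 - c) * targets).detach()
# so the tolerated divergence shrinks linearly with the number of remaining
# refinement steps, and the inputs drift from the one-hot targets towards the
# model's own predictions over the nb_iterations steps.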